Sfoglia il codice sorgente

Add support for multiple outgoing in-flight TCP packets.

whitequark 8 anni fa
parent
commit
c3eee36b8f
1 file modificato con 73 aggiunte e 21 eliminazioni
  1. 73 21
      src/socket/tcp.rs

+ 73 - 21
src/socket/tcp.rs

@@ -68,8 +68,8 @@ impl<'a> SocketBuffer<'a> {
         dest.copy_from_slice(data);
     }
 
-    fn clamp_reader(&self, mut size: usize) -> (usize, usize) {
-        let read_at = self.read_at;
+    fn clamp_reader(&self, offset: usize, mut size: usize) -> (usize, usize) {
+        let read_at = (self.read_at + offset) % self.storage.len();
         // We can't dequeue more than was queued.
         if size > self.length { size = self.length }
         // We can't contiguously dequeue past the end of the storage.
@@ -79,23 +79,24 @@ impl<'a> SocketBuffer<'a> {
         (read_at, size)
     }
 
-    fn peek(&self, size: usize) -> &[u8] {
-        let (read_at, size) = self.clamp_reader(size);
+    #[allow(dead_code)] // only used in tests
+    fn dequeue(&mut self, size: usize) -> &[u8] {
+        let (read_at, size) = self.clamp_reader(0, size);
+        self.read_at = (self.read_at + size) % self.storage.len();
+        self.length -= size;
         &self.storage[read_at..read_at + size]
     }
 
-    fn advance(&mut self, size: usize) {
-        let (read_at, size) = self.clamp_reader(size);
-        self.read_at = (read_at + size) % self.storage.len();
-        self.length -= size;
+    fn peek(&self, offset: usize, size: usize) -> &[u8] {
+        if offset > self.length { panic!("peeking {} octets past free space", offset) }
+        let (read_at, size) = self.clamp_reader(offset, size);
+        &self.storage[read_at..read_at + size]
     }
 
-    #[allow(dead_code)] // only used in tests
-    fn dequeue(&mut self, size: usize) -> &[u8] {
-        let (read_at, size) = self.clamp_reader(size);
+    fn advance(&mut self, size: usize) {
+        if size > self.length { panic!("advancing {} octets into free space", size) }
         self.read_at = (self.read_at + size) % self.storage.len();
         self.length -= size;
-        &self.storage[read_at..read_at + size]
     }
 }
 
@@ -162,13 +163,33 @@ impl Retransmit {
 /// A Transmission Control Protocol data stream.
 #[derive(Debug)]
 pub struct TcpSocket<'a> {
+    /// State of the socket.
     state:           State,
+    /// Address passed to `listen()`. `listen_address` is set when `listen()` is called and
+    /// used every time the socket is reset back to the `LISTEN` state.
     listen_address:  IpAddress,
+    /// Current local endpoint. This is used for both filtering the incoming packets and
+    /// setting the source address. When listening or initiating connection on/from
+    /// an unspecified address, this field is updated with the chosen source address before
+    /// any packets are sent.
     local_endpoint:  IpEndpoint,
+    /// Current remote endpoint. This is used for both filtering the incoming packets and
+    /// setting the destination address.
     remote_endpoint: IpEndpoint,
+    /// The sequence number corresponding to the beginning of the transmit buffer.
+    /// I.e. an ACK(local_seq_no+n) packet removes n bytes from the transmit buffer.
     local_seq_no:    i32,
+    /// The sequence number corresponding to the beginning of the receive buffer.
+    /// I.e. userspace reading n bytes adds n to remote_seq_no.
     remote_seq_no:   i32,
+    /// The last sequence number sent.
+    /// I.e. in an idle socket, local_seq_no+tx_buffer.len().
+    remote_last_seq: i32,
+    /// The last acknowledgement number sent.
+    /// I.e. in an idle socket, remote_seq_no+rx_buffer.len().
     remote_last_ack: i32,
+    /// The speculative remote window size.
+    /// I.e. the actual remote window size minus the count of in-flight octets.
     remote_win_len:  usize,
     retransmit:      Retransmit,
     rx_buffer:       SocketBuffer<'a>,
@@ -192,8 +213,9 @@ impl<'a> TcpSocket<'a> {
             remote_endpoint: IpEndpoint::default(),
             local_seq_no:    0,
             remote_seq_no:   0,
-            remote_win_len:  0,
+            remote_last_seq: 0,
             remote_last_ack: 0,
+            remote_win_len:  0,
             retransmit:      Retransmit::new(),
             tx_buffer:       tx_buffer.into(),
             rx_buffer:       rx_buffer.into()
@@ -252,7 +274,7 @@ impl<'a> TcpSocket<'a> {
     pub fn send(&mut self, size: usize) -> &mut [u8] {
         let buffer = self.tx_buffer.enqueue(size);
         if buffer.len() > 0 {
-            net_trace!("tcp:{}:{}: buffer to send {} octets",
+            net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets",
                        self.local_endpoint, self.remote_endpoint, buffer.len());
         }
         buffer
@@ -431,6 +453,7 @@ impl<'a> TcpSocket<'a> {
 
             // ACK packets in the SYN_RECEIVED state change it to ESTABLISHED.
             (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => {
+                self.remote_last_seq = self.local_seq_no + 1;
                 self.set_state(State::Established);
                 self.retransmit.reset()
             }
@@ -500,24 +523,36 @@ impl<'a> TcpSocket<'a> {
                 repr.control = TcpControl::Syn;
                 net_trace!("tcp:{}:{}: sending SYN|ACK",
                            self.local_endpoint, self.remote_endpoint);
-                self.remote_last_ack = self.remote_seq_no;
             }
 
             State::Established => {
-                if self.tx_buffer.len() > 0 && self.remote_win_len > 0 {
-                    if !self.retransmit.check() { return Err(Error::Exhausted) }
+                // See if we should send data to the remote end because:
+                //   1. the retransmit timer has expired, or...
+                let mut may_send = self.retransmit.check();
+                //   2. we've got new data in the transmit buffer.
+                let remote_next_seq = self.local_seq_no + self.tx_buffer.len() as i32;
+                if self.remote_last_seq != remote_next_seq {
+                    may_send = true;
+                }
 
+                if self.tx_buffer.len() > 0 && self.remote_win_len > 0 && may_send {
                     // We can send something, so let's do that.
+                    let offset = self.remote_last_seq - self.local_seq_no;
                     let mut size = self.remote_win_len;
                     // Clamp to MSS. Currently we only support the default MSS value.
                     if size > 536 { size = 536 }
                     // Extract data from the buffer. This may return less than what we want,
                     // in case it's not possible to extract a contiguous slice.
-                    let data = self.tx_buffer.peek(size);
-
-                    net_trace!("tcp:{}:{}: sending {} octets",
+                    let data = self.tx_buffer.peek(offset as usize, size);
+                    // Send the extracted data.
+                    net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets",
                                self.local_endpoint, self.remote_endpoint, data.len());
                     repr.payload = data;
+                    // Speculatively shrink the remote window. This will get updated the next
+                    // time we receive a packet.
+                    self.remote_win_len -= data.len();
+                    // Advance the in-flight sequence number.
+                    self.remote_last_seq += data.len() as i32;
                 } else if self.remote_last_ack != ack_number {
                     // We don't have anything to send, or can't because the remote end does not
                     // have any space to accept it, but we haven't yet acknowledged everything
@@ -822,12 +857,14 @@ mod test {
         s.remote_endpoint = REMOTE_END;
         s.local_seq_no    = LOCAL_SEQ + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1;
+        s.remote_last_seq = LOCAL_SEQ + 1;
+        s.remote_last_ack = REMOTE_SEQ + 1;
         s.remote_win_len  = 128;
         s
     }
 
     #[test]
-    fn test_established_receive() {
+    fn test_established_recv() {
         let mut s = socket_established();
         send!(s, [TcpRepr {
             seq_number: REMOTE_SEQ + 1,
@@ -847,6 +884,7 @@ mod test {
     #[test]
     fn test_established_send() {
         let mut s = socket_established();
+        // First roundtrip after establishing.
         s.tx_buffer.enqueue_slice(b"abcdef");
         recv!(s, [TcpRepr {
             seq_number: LOCAL_SEQ + 1,
@@ -861,6 +899,20 @@ mod test {
             ..SEND_TEMPL
         }]);
         assert_eq!(s.tx_buffer.len(), 0);
+        // Second roundtrip.
+        s.tx_buffer.enqueue_slice(b"foobar");
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload: &b"foobar"[..],
+            ..RECV_TEMPL
+        }]);
+        send!(s, [TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 6 + 6),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.tx_buffer.len(), 0);
     }
 
     #[test]