Browse Source

Implement reassembly of out-of-order TCP segments.

whitequark 7 years ago
parent
commit
394ea4d633
3 changed files with 159 additions and 107 deletions
  1. 1 1
      README.md
  2. 105 69
      src/socket/tcp.rs
  3. 53 37
      src/storage/assembler.rs

+ 1 - 1
README.md

@@ -57,12 +57,12 @@ The TCP protocol is supported over IPv4. Server and client sockets are supported
   * Header checksum is generated and validated.
   * Maximum segment size is negotiated.
   * Multiple packets are transmitted without waiting for an acknowledgement.
+  * Reassembly of out-of-order segments is supported, with no more than 4 missing sequence ranges.
   * Lost packets are retransmitted with exponential backoff, starting at a fixed delay of 100 ms.
   * Sending keep-alive packets is supported, with a configurable interval.
   * Connection, retransmission and keep-alive timeouts are supported, with a configurable duration.
   * After arriving at the TIME-WAIT state, sockets close after a fixed delay of 10 s.
   * Urgent pointer is **not** supported; any urgent octets will be received alongside data octets.
-  * Reassembly of out-of-order segments is **not** supported.
   * Silly window syndrome avoidance is **not** supported for either transmission or reception.
   * Congestion control is **not** implemented.
   * Delayed acknowledgements are **not** implemented.

+ 105 - 69
src/socket/tcp.rs

@@ -7,7 +7,7 @@ use {Error, Result};
 use phy::DeviceLimits;
 use wire::{IpProtocol, IpAddress, IpEndpoint, TcpSeqNumber, TcpRepr, TcpControl};
 use socket::{Socket, IpRepr};
-use storage::RingBuffer;
+use storage::{Assembler, RingBuffer};
 
 pub type SocketBuffer<'a> = RingBuffer<'a, u8>;
 
@@ -171,6 +171,7 @@ pub struct TcpSocket<'a> {
     debug_id:        usize,
     state:           State,
     timer:           Timer,
+    assembler:       Assembler,
     rx_buffer:       SocketBuffer<'a>,
     tx_buffer:       SocketBuffer<'a>,
     /// Interval after which, if no inbound packets are received, the connection is aborted.
@@ -219,7 +220,7 @@ impl<'a> TcpSocket<'a> {
     /// Create a socket using the given buffers.
     pub fn new<T>(rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
             where T: Into<SocketBuffer<'a>> {
-        let rx_buffer = rx_buffer.into();
+        let (rx_buffer, tx_buffer) = (rx_buffer.into(), tx_buffer.into());
         if rx_buffer.capacity() > <u16>::max_value() as usize {
             panic!("buffers larger than {} require window scaling, which is not implemented",
                    <u16>::max_value())
@@ -229,8 +230,9 @@ impl<'a> TcpSocket<'a> {
             debug_id:        0,
             state:           State::Closed,
             timer:           Timer::default(),
-            tx_buffer:       tx_buffer.into(),
-            rx_buffer:       rx_buffer.into(),
+            assembler:       Assembler::new(rx_buffer.capacity()),
+            tx_buffer:       tx_buffer,
+            rx_buffer:       rx_buffer,
             timeout:         None,
             keep_alive:      None,
             listen_address:  IpAddress::default(),
@@ -332,6 +334,9 @@ impl<'a> TcpSocket<'a> {
     fn reset(&mut self) {
         self.state           = State::Closed;
         self.timer           = Timer::default();
+        self.assembler       = Assembler::new(self.rx_buffer.capacity());
+        self.tx_buffer.clear();
+        self.rx_buffer.clear();
         self.keep_alive      = None;
         self.timeout         = None;
         self.listen_address  = IpAddress::default();
@@ -345,8 +350,6 @@ impl<'a> TcpSocket<'a> {
         self.remote_win_len  = 0;
         self.remote_mss      = DEFAULT_MSS;
         self.remote_last_ts  = None;
-        self.tx_buffer.clear();
-        self.rx_buffer.clear();
     }
 
     /// Start listening on the given endpoint.
@@ -860,13 +863,14 @@ impl<'a> TcpSocket<'a> {
             }
         }
 
+        let payload_offset;
         match self.state {
             // In LISTEN and SYN-SENT states, we have not yet synchronized with the remote end.
-            State::Listen  => (),
-            State::SynSent => (),
+            State::Listen | State::SynSent =>
+                payload_offset = 0,
             // In all other states, segments must occupy a valid portion of the receive window.
             _ => {
-                let mut send_challenge_ack = false;
+                let mut segment_in_window = true;
 
                 let window_start  = self.remote_seq_no + self.rx_buffer.len();
                 let window_end    = self.remote_seq_no + self.rx_buffer.capacity();
@@ -877,34 +881,22 @@ impl<'a> TcpSocket<'a> {
                     net_debug!("[{}]{}:{}: non-zero-length segment with zero receive window, \
                                 will only send an ACK",
                                self.debug_id, self.local_endpoint, self.remote_endpoint);
-                    send_challenge_ack = true;
+                    segment_in_window = false;
                 }
 
-                if !((window_start <= segment_start && segment_start <= window_end) ||
+                if !((window_start <= segment_start && segment_start <= window_end) &&
                      (window_start <= segment_end   && segment_end <= window_end)) {
                     net_debug!("[{}]{}:{}: segment not in receive window \
                                 ({}..{} not intersecting {}..{}), will send challenge ACK",
                                self.debug_id, self.local_endpoint, self.remote_endpoint,
                                segment_start, segment_end, window_start, window_end);
-                    send_challenge_ack = true;
+                    segment_in_window = false;
                 }
 
-                // For now, do not actually try to reassemble out-of-order segments.
-                if segment_start != window_start {
-                    net_debug!("[{}]{}:{}: out-of-order SEQ ({} not equal to {}), \
-                                will send challenge ACK",
-                               self.debug_id, self.local_endpoint, self.remote_endpoint,
-                               segment_start, window_start);
-                    // Some segments between what we have last received and this segment
-                    // went missing. Send a duplicate ACK; RFC 793 does not specify the behavior
-                    // required when receiving a duplicate ACK, but in practice (see RFC 1122
-                    // section 4.2.2.21) most congestion control algorithms implement what's called
-                    // a "fast retransmit", where a threshold amount of duplicate ACKs triggers
-                    // retransmission.
-                    send_challenge_ack = true;
-                }
-
-                if send_challenge_ack {
+                if segment_in_window {
+                    // We've checked that segment_start >= window_start above.
+                    payload_offset = (segment_start - window_start) as usize;
+                } else {
                     // If we're in the TIME-WAIT state, restart the TIME-WAIT timeout, since
                     // the remote end may not have realized we've closed the connection.
                     if self.state == State::TimeWait {
@@ -1087,28 +1079,64 @@ impl<'a> TcpSocket<'a> {
 
         if ack_len > 0 {
             // Dequeue acknowledged octets.
+            debug_assert!(self.tx_buffer.len() >= ack_len);
             net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})",
                        self.debug_id, self.local_endpoint, self.remote_endpoint,
                        ack_len, self.tx_buffer.len() - ack_len);
-            let acked = self.tx_buffer.dequeue_many(ack_len);
-            debug_assert!(acked.len() == ack_len);
+            self.tx_buffer.dequeue_many(ack_len);
         }
 
-        // We've processed everything in the incoming segment, so advance the local
-        // sequence number past it.
         if let Some(ack_number) = repr.ack_number {
+            // We've processed everything in the incoming segment, so advance the local
+            // sequence number past it.
             self.local_seq_no = ack_number;
         }
 
-        if repr.payload.len() > 0 {
-            // Enqueue payload octets, which are guaranteed to be in order.
+        let payload_len = repr.payload.len();
+        if payload_len == 0 { return Ok(None) }
+
+        // Try adding payload octets to the assembler.
+        match self.assembler.add(payload_offset, payload_len) {
+            Ok(()) => {
+                debug_assert!(self.assembler.total_size() == self.rx_buffer.capacity());
+                // Place payload octets into the buffer.
+                net_trace!("[{}]{}:{}: rx buffer: writing {} octets at offset {}",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint,
+                           payload_len, payload_offset);
+                self.rx_buffer.get_unallocated(payload_offset, payload_len)
+                              .copy_from_slice(repr.payload);
+            }
+            Err(()) => {
+                net_debug!("[{}]{}:{}: assembler: too many holes to add {} octets at offset {}",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint,
+                           payload_len, payload_offset);
+                return Err(Error::Dropped)
+            }
+        }
+
+        if let Some(contig_len) = self.assembler.remove_front() {
+            debug_assert!(self.assembler.total_size() == self.rx_buffer.capacity());
+            // Enqueue the contiguous data octets in front of the buffer.
             net_trace!("[{}]{}:{}: rx buffer: enqueueing {} octets (now {})",
                        self.debug_id, self.local_endpoint, self.remote_endpoint,
-                       repr.payload.len(), self.rx_buffer.len() + repr.payload.len());
-            self.rx_buffer.enqueue_slice(repr.payload);
+                       contig_len, self.rx_buffer.len() + contig_len);
+            self.rx_buffer.enqueue_many(contig_len);
         }
 
-        Ok(None)
+        if self.assembler.is_empty() {
+            Ok(None)
+        } else {
+            // If the assembler isn't empty, some segments at the start of our window got lost.
+            // Send a reply acknowledging the data we already have; RFC 793 does not specify
+            // the behavior triggerd by such a reply, but RFC 1122 section 4.2.2.21 states that
+            // most congestion control algorithms implement what's called a "fast retransmit",
+            // where a threshold amount of duplicate ACKs triggers retransmission without
+            // the need to wait for a timeout to expire.
+            net_trace!("[{}]{}:{}: assembler: {}",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint,
+                       self.assembler);
+            Ok(Some(self.ack_reply(ip_repr, &repr)))
+        }
     }
 
     fn timed_out(&self, timestamp: u64) -> bool {
@@ -1262,7 +1290,7 @@ impl<'a> TcpSocket<'a> {
         }
 
         if repr.payload.len() > 0 {
-            net_trace!("[{}]{}:{}: tx buffer: peeking at {} octets (from {})",
+            net_trace!("[{}]{}:{}: tx buffer: reading {} octets at offset {}",
                        self.debug_id, self.local_endpoint, self.remote_endpoint,
                        repr.payload.len(), self.remote_last_seq - self.local_seq_no);
         } else {
@@ -2685,34 +2713,6 @@ mod test {
         })));
     }
 
-    #[test]
-    fn test_missing_segment() {
-        let mut s = socket_established();
-        send!(s, TcpRepr {
-            seq_number: REMOTE_SEQ + 1,
-            ack_number: Some(LOCAL_SEQ + 1),
-            payload:    &b"abcdef"[..],
-            ..SEND_TEMPL
-        });
-        recv!(s, [TcpRepr {
-            seq_number: LOCAL_SEQ + 1,
-            ack_number: Some(REMOTE_SEQ + 1 + 6),
-            window_len: 58,
-            ..RECV_TEMPL
-        }]);
-        send!(s, TcpRepr {
-            seq_number: REMOTE_SEQ + 1 + 6 + 6,
-            ack_number: Some(LOCAL_SEQ + 1),
-            payload:    &b"mnopqr"[..],
-            ..SEND_TEMPL
-        }, Ok(Some(TcpRepr {
-            seq_number: LOCAL_SEQ + 1,
-            ack_number: Some(REMOTE_SEQ + 1 + 6),
-            window_len: 58,
-            ..RECV_TEMPL
-        })));
-    }
-
     #[test]
     fn test_data_retransmit() {
         let mut s = socket_established();
@@ -3013,6 +3013,7 @@ mod test {
     fn test_zero_window_ack() {
         let mut s = socket_established();
         s.rx_buffer = SocketBuffer::new(vec![0; 6]);
+        s.assembler = Assembler::new(s.rx_buffer.capacity());
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1),
@@ -3042,6 +3043,7 @@ mod test {
     fn test_zero_window_ack_on_window_growth() {
         let mut s = socket_established();
         s.rx_buffer = SocketBuffer::new(vec![0; 6]);
+        s.assembler = Assembler::new(s.rx_buffer.capacity());
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1),
@@ -3096,7 +3098,7 @@ mod test {
     }
 
     // =========================================================================================//
-    // Tests for timeouts
+    // Tests for timeouts.
     // =========================================================================================//
 
     #[test]
@@ -3193,7 +3195,7 @@ mod test {
     }
 
     // =========================================================================================//
-    // Tests for keep-alive
+    // Tests for keep-alive.
     // =========================================================================================//
 
     #[test]
@@ -3258,13 +3260,47 @@ mod test {
     }
 
     // =========================================================================================//
-    // Tests for packet filtering
+    // Tests for reassembly.
+    // =========================================================================================//
+
+    #[test]
+    fn test_out_of_order() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1 + 3,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"def"[..],
+            ..SEND_TEMPL
+        }, Ok(Some(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        })));
+        assert_eq!(s.recv(10), Ok(&b""[..]));
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..SEND_TEMPL
+        });
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 6),
+            window_len: 58,
+            ..RECV_TEMPL
+        }]);
+        assert_eq!(s.recv(10), Ok(&b"abcdef"[..]));
+    }
+
+    // =========================================================================================//
+    // Tests for packet filtering.
     // =========================================================================================//
 
     #[test]
     fn test_doesnt_accept_wrong_port() {
         let mut s = socket_established();
         s.rx_buffer = SocketBuffer::new(vec![0; 6]);
+        s.assembler = Assembler::new(s.rx_buffer.capacity());
 
         let tcp_repr = TcpRepr {
             seq_number: REMOTE_SEQ + 1,

+ 53 - 37
src/storage/assembler.rs

@@ -3,14 +3,15 @@ use core::fmt;
 /// A contiguous chunk of absent data, followed by a contiguous chunk of present data.
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 struct Contig {
-    hole_size: u32,
-    data_size: u32
+    hole_size: usize,
+    data_size: usize
 }
 
 impl fmt::Display for Contig {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         if self.has_hole() { write!(f, "({})", self.hole_size)?; }
-        if self.has_data() { write!(f, " {}",  self.data_size)?; }
+        if self.has_hole() && self.has_data() { write!(f, " ")?; }
+        if self.has_data() { write!(f, "{}",   self.data_size)?; }
         Ok(())
     }
 }
@@ -20,11 +21,11 @@ impl Contig {
         Contig { hole_size: 0, data_size: 0 }
     }
 
-    fn hole(size: u32) -> Contig {
+    fn hole(size: usize) -> Contig {
         Contig { hole_size: size, data_size: 0 }
     }
 
-    fn hole_and_data(hole_size: u32, data_size: u32) -> Contig {
+    fn hole_and_data(hole_size: usize, data_size: usize) -> Contig {
         Contig { hole_size, data_size }
     }
 
@@ -36,7 +37,7 @@ impl Contig {
         self.data_size != 0
     }
 
-    fn total_size(&self) -> u32 {
+    fn total_size(&self) -> usize {
         self.hole_size + self.data_size
     }
 
@@ -44,15 +45,15 @@ impl Contig {
         self.total_size() == 0
     }
 
-    fn expand_data_by(&mut self, size: u32) {
+    fn expand_data_by(&mut self, size: usize) {
         self.data_size += size;
     }
 
-    fn shrink_hole_by(&mut self, size: u32) {
+    fn shrink_hole_by(&mut self, size: usize) {
         self.hole_size -= size;
     }
 
-    fn shrink_hole_to(&mut self, size: u32) {
+    fn shrink_hole_to(&mut self, size: usize) {
         assert!(self.hole_size >= size);
 
         let total_size = self.total_size();
@@ -86,13 +87,13 @@ impl fmt::Display for Assembler {
 
 impl Assembler {
     /// Create a new buffer assembler for buffers of the given size.
-    pub fn new(size: u32) -> Assembler {
+    pub fn new(size: usize) -> Assembler {
         let mut contigs = [Contig::empty(); CONTIG_COUNT];
         contigs[0] = Contig::hole(size);
         Assembler { contigs }
     }
 
-    pub(crate) fn total_size(&self) -> u32 {
+    pub(crate) fn total_size(&self) -> usize {
         self.contigs
             .iter()
             .map(|contig| contig.total_size())
@@ -107,13 +108,20 @@ impl Assembler {
         self.contigs[self.contigs.len() - 1]
     }
 
-    /// Remove a contig at the given index, and return a pointer to the the first empty contig.
+    /// Return whether the assembler contains no data.
+    pub fn is_empty(&self) -> bool {
+        !self.front().has_data()
+    }
+
+    /// Remove a contig at the given index, and return a pointer to the first contig
+    /// without data.
     fn remove_contig_at(&mut self, at: usize) -> &mut Contig {
         debug_assert!(!self.contigs[at].is_empty());
 
         for i in at..self.contigs.len() - 1 {
             self.contigs[i] = self.contigs[i + 1];
-            if self.contigs[i].is_empty() {
+            if !self.contigs[i].has_data() {
+                self.contigs[i + 1] = Contig::empty();
                 return &mut self.contigs[i]
             }
         }
@@ -139,7 +147,7 @@ impl Assembler {
 
     /// Add a new contiguous range to the assembler, and return `Ok(())`,
     /// or return `Err(())` if too many discontiguities are already recorded.
-    pub fn add(&mut self, mut offset: u32, mut size: u32) -> Result<(), ()> {
+    pub fn add(&mut self, mut offset: usize, mut size: usize) -> Result<(), ()> {
         let mut index = 0;
         while index != self.contigs.len() && size != 0 {
             let contig = self.contigs[index];
@@ -172,8 +180,8 @@ impl Assembler {
                 // The range being added covers a part of the hole but not of the data
                 // in this contig, add a new contig containing the range.
                 self.contigs[index].shrink_hole_by(offset + size);
-                let empty = self.add_contig_at(index)?;
-                *empty = Contig::hole_and_data(offset, size);
+                let inserted = self.add_contig_at(index)?;
+                *inserted = Contig::hole_and_data(offset, size);
                 index += 2;
             } else {
                 unreachable!()
@@ -192,30 +200,18 @@ impl Assembler {
         Ok(())
     }
 
-    /// Return `Ok(size)` with the size of a contiguous range in the front of the assembler,
-    /// or return `Err(())` if there is no such range.
-    pub fn front_len(&self) -> u32 {
-        let front = self.front();
-        if front.has_hole() {
-            0
-        } else {
-            debug_assert!(front.data_size > 0);
-            front.data_size
-        }
-    }
-
-    /// Remove a contiguous range from the front of the assembler and `Ok(data_size)`,
-    /// or return `Err(())` if there is no such range.
-    pub fn front_remove(&mut self) -> u32 {
+    /// Remove a contiguous range from the front of the assembler and `Some(data_size)`,
+    /// or return `None` if there is no such range.
+    pub fn remove_front(&mut self) -> Option<usize> {
         let front = self.front();
         if front.has_hole() {
-            0
+            None
         } else {
-            let empty = self.remove_contig_at(0);
-            *empty = Contig::hole(front.data_size);
+            let last_hole = self.remove_contig_at(0);
+            last_hole.hole_size += front.data_size;
 
             debug_assert!(front.data_size > 0);
-            front.data_size
+            Some(front.data_size)
         }
     }
 }
@@ -225,8 +221,8 @@ mod test {
     use std::vec::Vec;
     use super::*;
 
-    impl From<Vec<(u32, u32)>> for Assembler {
-        fn from(vec: Vec<(u32, u32)>) -> Assembler {
+    impl From<Vec<(usize, usize)>> for Assembler {
+        fn from(vec: Vec<(usize, usize)>) -> Assembler {
             let mut contigs = [Contig::empty(); CONTIG_COUNT];
             for (i, &(hole_size, data_size)) in vec.iter().enumerate() {
                 contigs[i] = Contig { hole_size, data_size };
@@ -331,4 +327,24 @@ mod test {
         assert_eq!(assr.add(2, 12), Ok(()));
         assert_eq!(assr, contigs![(2, 12), (2, 0)]);
     }
+
+    #[test]
+    fn test_empty_remove_front() {
+        let mut assr = contigs![(12, 0)];
+        assert_eq!(assr.remove_front(), None);
+    }
+
+    #[test]
+    fn test_trailing_hole_remove_front() {
+        let mut assr = contigs![(0, 4), (8, 0)];
+        assert_eq!(assr.remove_front(), Some(4));
+        assert_eq!(assr, contigs![(12, 0)]);
+    }
+
+    #[test]
+    fn test_trailing_data_remove_front() {
+        let mut assr = contigs![(0, 4), (4, 4)];
+        assert_eq!(assr.remove_front(), Some(4));
+        assert_eq!(assr, contigs![(4, 4), (4, 0)]);
+    }
 }