Browse Source

tcp: ensure we always accept the segment at offset=0 even if the assembler is full.

Fixes #452
Dario Nieuwenhuis 2 years ago
parent
commit
8bd28ab66a
2 changed files with 87 additions and 23 deletions
  1. 19 23
      src/socket/tcp.rs
  2. 68 0
      src/storage/assembler.rs

+ 19 - 23
src/socket/tcp.rs

@@ -1798,30 +1798,26 @@ impl<'a> Socket<'a> {
         let assembler_was_empty = self.assembler.is_empty();
 
         // Try adding payload octets to the assembler.
-        match self.assembler.add(payload_offset, payload_len) {
-            Ok(()) => {
-                // Place payload octets into the buffer.
-                tcp_trace!(
-                    "rx buffer: receiving {} octets at offset {}",
-                    payload_len,
-                    payload_offset
-                );
-                let len_written = self
-                    .rx_buffer
-                    .write_unallocated(payload_offset, repr.payload);
-                debug_assert!(len_written == payload_len);
-            }
-            Err(_) => {
-                net_debug!(
-                    "assembler: too many holes to add {} octets at offset {}",
-                    payload_len,
-                    payload_offset
-                );
-                return None;
-            }
-        }
+        let Ok(contig_len) = self.assembler.add_then_remove_front(payload_offset, payload_len) else {
+            net_debug!(
+                "assembler: too many holes to add {} octets at offset {}",
+                payload_len,
+                payload_offset
+            );
+            return None;
+        };
+
+        // Place payload octets into the buffer.
+        tcp_trace!(
+            "rx buffer: receiving {} octets at offset {}",
+            payload_len,
+            payload_offset
+        );
+        let len_written = self
+            .rx_buffer
+            .write_unallocated(payload_offset, repr.payload);
+        debug_assert!(len_written == payload_len);
 
-        let contig_len = self.assembler.remove_front();
         if contig_len != 0 {
             // Enqueue the contiguous data octets in front of the buffer.
             tcp_trace!(

+ 68 - 0
src/storage/assembler.rs

@@ -269,6 +269,29 @@ impl Assembler {
         }
     }
 
+    /// Add a segment, then remove_front.
+    ///
+    /// This is equivalent to calling `add` then `remove_front` individually,
+    /// except it's guaranteed to not fail when offset = 0.
+    /// This is required for TCP: we must never drop the next expected segment, or
+    /// the protocol might get stuck.
+    pub fn add_then_remove_front(
+        &mut self,
+        offset: usize,
+        size: usize,
+    ) -> Result<usize, TooManyHolesError> {
+        // This is the only case where a segment at offset=0 would cause the
+        // total amount of contigs to rise (and therefore can potentially cause
+        // a TooManyHolesError). Handle it in a way that is guaranteed to succeed.
+        if offset == 0 && size < self.contigs[0].hole_size {
+            self.contigs[0].hole_size -= size;
+            return Ok(size);
+        }
+
+        self.add(offset, size)?;
+        Ok(self.remove_front())
+    }
+
     /// Iterate over all of the contiguous data ranges.
     ///
     /// This is used in calculating what data ranges have been received. The offset indicates the
@@ -598,6 +621,51 @@ mod test {
         assert_eq!(assr.add(1, 1), Ok(()));
     }
 
+    #[test]
+    fn test_add_then_remove_front() {
+        let mut assr = Assembler::new();
+        assert_eq!(assr.add(50, 10), Ok(()));
+        assert_eq!(assr.add_then_remove_front(10, 10), Ok(0));
+        assert_eq!(assr, contigs![(10, 10), (30, 10)]);
+    }
+
+    #[test]
+    fn test_add_then_remove_front_at_front() {
+        let mut assr = Assembler::new();
+        assert_eq!(assr.add(50, 10), Ok(()));
+        assert_eq!(assr.add_then_remove_front(0, 10), Ok(10));
+        assert_eq!(assr, contigs![(40, 10)]);
+    }
+
+    #[test]
+    fn test_add_then_remove_front_at_front_touch() {
+        let mut assr = Assembler::new();
+        assert_eq!(assr.add(50, 10), Ok(()));
+        assert_eq!(assr.add_then_remove_front(0, 50), Ok(60));
+        assert_eq!(assr, contigs![]);
+    }
+
+    #[test]
+    fn test_add_then_remove_front_at_front_full() {
+        let mut assr = Assembler::new();
+        for c in 1..=CONTIG_COUNT {
+            assert_eq!(assr.add(c * 10, 3), Ok(()));
+        }
+        // Maximum of allowed holes is reached
+        let assr_before = assr.clone();
+        assert_eq!(assr.add_then_remove_front(1, 3), Err(TooManyHolesError));
+        assert_eq!(assr_before, assr);
+    }
+
+    #[test]
+    fn test_add_then_remove_front_at_front_full_offset_0() {
+        let mut assr = Assembler::new();
+        for c in 1..=CONTIG_COUNT {
+            assert_eq!(assr.add(c * 10, 3), Ok(()));
+        }
+        assert_eq!(assr.add_then_remove_front(0, 3), Ok(3));
+    }
+
     // Test against an obviously-correct but inefficient bitmap impl.
     #[test]
     fn test_random() {