Selaa lähdekoodia

Rework TCP retransmit logic to be much more robust.

Before this commit, if the amount of data in the buffer caused it
to be split among many outgoing packets, and retransmit timer
was active, the socket would behave very erratically and flood
the peer.
whitequark 7 vuotta sitten
vanhempi
commit
79d8470f41
1 muutettua tiedostoa jossa 94 lisäystä ja 44 poistoa
  1. 94 44
      src/socket/tcp.rs

+ 94 - 44
src/socket/tcp.rs

@@ -176,10 +176,6 @@ const RETRANSMIT_DELAY: u64 = 100;
 const CLOSE_DELAY:      u64 = 10_000;
 const CLOSE_DELAY:      u64 = 10_000;
 
 
 impl Timer {
 impl Timer {
-    fn is_idle(&self) -> bool {
-        *self == Timer::Idle
-    }
-
     fn should_retransmit(&self, timestamp: u64) -> Option<u64> {
     fn should_retransmit(&self, timestamp: u64) -> Option<u64> {
         match *self {
         match *self {
             Timer::Retransmit { expires_at, delay }
             Timer::Retransmit { expires_at, delay }
@@ -212,7 +208,7 @@ impl Timer {
         *self = Timer::Idle
         *self = Timer::Idle
     }
     }
 
 
-    fn set_for_data(&mut self, timestamp: u64) {
+    fn set_for_retransmit(&mut self, timestamp: u64) {
         match *self {
         match *self {
             Timer::Idle => {
             Timer::Idle => {
                 *self = Timer::Retransmit {
                 *self = Timer::Retransmit {
@@ -273,7 +269,7 @@ pub struct TcpSocket<'a> {
     remote_seq_no:   TcpSeqNumber,
     remote_seq_no:   TcpSeqNumber,
     /// The last sequence number sent.
     /// The last sequence number sent.
     /// I.e. in an idle socket, local_seq_no+tx_buffer.len().
     /// I.e. in an idle socket, local_seq_no+tx_buffer.len().
-    remote_next_seq: TcpSeqNumber,
+    remote_last_seq: TcpSeqNumber,
     /// The last acknowledgement number sent.
     /// The last acknowledgement number sent.
     /// I.e. in an idle socket, remote_seq_no+rx_buffer.len().
     /// I.e. in an idle socket, remote_seq_no+rx_buffer.len().
     remote_last_ack: TcpSeqNumber,
     remote_last_ack: TcpSeqNumber,
@@ -307,7 +303,7 @@ impl<'a> TcpSocket<'a> {
             remote_endpoint: IpEndpoint::default(),
             remote_endpoint: IpEndpoint::default(),
             local_seq_no:    TcpSeqNumber::default(),
             local_seq_no:    TcpSeqNumber::default(),
             remote_seq_no:   TcpSeqNumber::default(),
             remote_seq_no:   TcpSeqNumber::default(),
-            remote_next_seq: TcpSeqNumber::default(),
+            remote_last_seq: TcpSeqNumber::default(),
             remote_last_ack: TcpSeqNumber::default(),
             remote_last_ack: TcpSeqNumber::default(),
             remote_win_len:  0,
             remote_win_len:  0,
             remote_mss:      DEFAULT_MSS,
             remote_mss:      DEFAULT_MSS,
@@ -353,7 +349,7 @@ impl<'a> TcpSocket<'a> {
         self.remote_endpoint = IpEndpoint::default();
         self.remote_endpoint = IpEndpoint::default();
         self.local_seq_no    = TcpSeqNumber::default();
         self.local_seq_no    = TcpSeqNumber::default();
         self.remote_seq_no   = TcpSeqNumber::default();
         self.remote_seq_no   = TcpSeqNumber::default();
-        self.remote_next_seq = TcpSeqNumber::default();
+        self.remote_last_seq = TcpSeqNumber::default();
         self.remote_last_ack = TcpSeqNumber::default();
         self.remote_last_ack = TcpSeqNumber::default();
         self.remote_win_len  = 0;
         self.remote_win_len  = 0;
         self.remote_mss      = DEFAULT_MSS;
         self.remote_mss      = DEFAULT_MSS;
@@ -421,7 +417,7 @@ impl<'a> TcpSocket<'a> {
         self.local_endpoint  = local_endpoint;
         self.local_endpoint  = local_endpoint;
         self.remote_endpoint = remote_endpoint;
         self.remote_endpoint = remote_endpoint;
         self.local_seq_no    = local_seq_no;
         self.local_seq_no    = local_seq_no;
-        self.remote_next_seq = local_seq_no;
+        self.remote_last_seq = local_seq_no;
         self.set_state(State::SynSent);
         self.set_state(State::SynSent);
         Ok(())
         Ok(())
     }
     }
@@ -742,7 +738,7 @@ impl<'a> TcpSocket<'a> {
         // [...] an empty acknowledgment segment containing the current send-sequence number
         // [...] an empty acknowledgment segment containing the current send-sequence number
         // and an acknowledgment indicating the next sequence number expected
         // and an acknowledgment indicating the next sequence number expected
         // to be received.
         // to be received.
-        reply_repr.seq_number = self.remote_next_seq;
+        reply_repr.seq_number = self.remote_last_seq;
         reply_repr.ack_number = Some(self.remote_last_ack);
         reply_repr.ack_number = Some(self.remote_last_ack);
         reply_repr.window_len = self.rx_buffer.window() as u16;
         reply_repr.window_len = self.rx_buffer.window() as u16;
 
 
@@ -944,8 +940,8 @@ impl<'a> TcpSocket<'a> {
                 self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), repr.src_port);
                 self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), repr.src_port);
                 // FIXME: use something more secure here
                 // FIXME: use something more secure here
                 self.local_seq_no    = TcpSeqNumber(-repr.seq_number.0);
                 self.local_seq_no    = TcpSeqNumber(-repr.seq_number.0);
-                self.remote_next_seq = self.local_seq_no;
                 self.remote_seq_no   = repr.seq_number + 1;
                 self.remote_seq_no   = repr.seq_number + 1;
+                self.remote_last_seq = self.local_seq_no;
                 if let Some(max_seg_size) = repr.max_seg_size {
                 if let Some(max_seg_size) = repr.max_seg_size {
                     self.remote_mss = max_seg_size as usize
                     self.remote_mss = max_seg_size as usize
                 }
                 }
@@ -973,8 +969,8 @@ impl<'a> TcpSocket<'a> {
                 net_trace!("[{}]{}:{}: received SYN|ACK",
                 net_trace!("[{}]{}:{}: received SYN|ACK",
                            self.debug_id, self.local_endpoint, self.remote_endpoint);
                            self.debug_id, self.local_endpoint, self.remote_endpoint);
                 self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
                 self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
-                self.remote_next_seq = self.local_seq_no + 1;
                 self.remote_seq_no   = repr.seq_number + 1;
                 self.remote_seq_no   = repr.seq_number + 1;
+                self.remote_last_seq = self.local_seq_no + 1;
                 self.remote_last_ack = repr.seq_number;
                 self.remote_last_ack = repr.seq_number;
                 if let Some(max_seg_size) = repr.max_seg_size {
                 if let Some(max_seg_size) = repr.max_seg_size {
                     self.remote_mss = max_seg_size as usize;
                     self.remote_mss = max_seg_size as usize;
@@ -1087,24 +1083,34 @@ impl<'a> TcpSocket<'a> {
         }
         }
     }
     }
 
 
+    fn seq_to_transmit(&self, control: TcpControl) -> bool {
+        self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len()
+    }
+
+    fn ack_to_transmit(&self) -> bool {
+        self.remote_last_ack < self.remote_seq_no + self.rx_buffer.len()
+    }
+
     pub(crate) fn dispatch<F>(&mut self, timestamp: u64, limits: &DeviceLimits,
     pub(crate) fn dispatch<F>(&mut self, timestamp: u64, limits: &DeviceLimits,
                               emit: F) -> Result<()>
                               emit: F) -> Result<()>
             where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> {
             where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> {
         if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) }
         if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) }
 
 
-        if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) {
-            // If a retransmit timer expired, we should resend data starting at the last ACK.
-            net_debug!("[{}]{}:{}: retransmitting at t+{}ms",
-                       self.debug_id, self.local_endpoint, self.remote_endpoint,
-                       retransmit_delta);
-            self.remote_next_seq = self.local_seq_no;
+        if !self.seq_to_transmit(TcpControl::None) {
+            if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) {
+                // If a retransmit timer expired, we should resend data starting at the last ACK.
+                net_debug!("[{}]{}:{}: retransmitting at t+{}ms",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint,
+                           retransmit_delta);
+                self.remote_last_seq = self.local_seq_no;
+            }
         }
         }
 
 
         let mut repr = TcpRepr {
         let mut repr = TcpRepr {
             src_port:     self.local_endpoint.port,
             src_port:     self.local_endpoint.port,
             dst_port:     self.remote_endpoint.port,
             dst_port:     self.remote_endpoint.port,
             control:      TcpControl::None,
             control:      TcpControl::None,
-            seq_number:   self.remote_next_seq,
+            seq_number:   self.remote_last_seq,
             ack_number:   Some(self.remote_seq_no + self.rx_buffer.len()),
             ack_number:   Some(self.remote_seq_no + self.rx_buffer.len()),
             window_len:   self.rx_buffer.window() as u16,
             window_len:   self.rx_buffer.window() as u16,
             max_seg_size: None,
             max_seg_size: None,
@@ -1136,7 +1142,7 @@ impl<'a> TcpSocket<'a> {
             State::Established | State::FinWait1 | State::CloseWait | State::LastAck => {
             State::Established | State::FinWait1 | State::CloseWait | State::LastAck => {
                 // Extract as much data as the remote side can receive in this packet
                 // Extract as much data as the remote side can receive in this packet
                 // from the transmit buffer.
                 // from the transmit buffer.
-                let offset = self.remote_next_seq - self.local_seq_no;
+                let offset = self.remote_last_seq - self.local_seq_no;
                 let size = cmp::min(self.remote_win_len, self.remote_mss);
                 let size = cmp::min(self.remote_win_len, self.remote_mss);
                 repr.payload = self.tx_buffer.peek(offset, size);
                 repr.payload = self.tx_buffer.peek(offset, size);
                 // If we've sent everything we had in the buffer, follow it with the PSH or FIN
                 // If we've sent everything we had in the buffer, follow it with the PSH or FIN
@@ -1171,14 +1177,14 @@ impl<'a> TcpSocket<'a> {
             }
             }
         }
         }
 
 
-        if self.timer.should_retransmit(timestamp).is_some() {
+        if self.seq_to_transmit(repr.control) && repr.segment_len() > 0 {
+            // If we have data to transmit and it fits into partner's window, do it.
+        } else if self.ack_to_transmit() {
+            // If we have data to acknowledge, do it.
+        } else if self.timer.should_retransmit(timestamp).is_some() {
             // If we have packets to retransmit, do it.
             // If we have packets to retransmit, do it.
-        } else if repr.segment_len() > 0 && self.timer.is_idle() {
-            // If we have something new to transmit, do it.
         } else if repr.control == TcpControl::Rst {
         } else if repr.control == TcpControl::Rst {
             // If we need to abort the connection, do it.
             // If we need to abort the connection, do it.
-        } else if self.remote_seq_no + self.rx_buffer.len() != self.remote_last_ack {
-            // If we have something to acknowledge, do it.
         } else {
         } else {
             return Err(Error::Exhausted)
             return Err(Error::Exhausted)
         }
         }
@@ -1197,7 +1203,7 @@ impl<'a> TcpSocket<'a> {
             if repr.payload.len() > 0 {
             if repr.payload.len() > 0 {
                 net_trace!("[{}]{}:{}: tx buffer: peeking at {} octets (from {})",
                 net_trace!("[{}]{}:{}: tx buffer: peeking at {} octets (from {})",
                            self.debug_id, self.local_endpoint, self.remote_endpoint,
                            self.debug_id, self.local_endpoint, self.remote_endpoint,
-                           repr.payload.len(), self.remote_next_seq - self.local_seq_no);
+                           repr.payload.len(), self.remote_last_seq - self.local_seq_no);
             } else {
             } else {
                 net_debug!("[{}]{}:{}: sending {}",
                 net_debug!("[{}]{}:{}: sending {}",
                            self.debug_id, self.local_endpoint, self.remote_endpoint,
                            self.debug_id, self.local_endpoint, self.remote_endpoint,
@@ -1245,14 +1251,13 @@ impl<'a> TcpSocket<'a> {
         emit((ip_repr, repr))?;
         emit((ip_repr, repr))?;
 
 
         // We've sent a packet successfully, so we can update the internal state now.
         // We've sent a packet successfully, so we can update the internal state now.
-        self.remote_next_seq = repr.seq_number + repr.segment_len();
+        self.remote_last_seq = repr.seq_number + repr.segment_len();
         self.remote_last_ack = repr.ack_number.unwrap_or_default();
         self.remote_last_ack = repr.ack_number.unwrap_or_default();
 
 
-        if self.remote_next_seq - self.local_seq_no >= self.tx_buffer.len() &&
-                repr.segment_len() > 0 {
-            // If we've transmitted all we could (and there was something to transmit),
-            // wind up the retransmit timer.
-            self.timer.set_for_data(timestamp);
+        if !self.seq_to_transmit(repr.control) && repr.segment_len() > 0 {
+            // If we've transmitted all data could (and there was something at all,
+            // data or flag, to transmit, not just an ACK), wind up the retransmit timer.
+            self.timer.set_for_retransmit(timestamp);
         }
         }
 
 
         if repr.control == TcpControl::Rst {
         if repr.control == TcpControl::Rst {
@@ -1324,11 +1329,11 @@ mod test {
     fn test_timer_retransmit() {
     fn test_timer_retransmit() {
         let mut r = Timer::Idle;
         let mut r = Timer::Idle;
         assert_eq!(r.should_retransmit(1000), None);
         assert_eq!(r.should_retransmit(1000), None);
-        r.set_for_data(1000);
+        r.set_for_retransmit(1000);
         assert_eq!(r.should_retransmit(1000), None);
         assert_eq!(r.should_retransmit(1000), None);
         assert_eq!(r.should_retransmit(1050), None);
         assert_eq!(r.should_retransmit(1050), None);
         assert_eq!(r.should_retransmit(1101), Some(101));
         assert_eq!(r.should_retransmit(1101), Some(101));
-        r.set_for_data(1101);
+        r.set_for_retransmit(1101);
         assert_eq!(r.should_retransmit(1101), None);
         assert_eq!(r.should_retransmit(1101), None);
         assert_eq!(r.should_retransmit(1150), None);
         assert_eq!(r.should_retransmit(1150), None);
         assert_eq!(r.should_retransmit(1200), None);
         assert_eq!(r.should_retransmit(1200), None);
@@ -1442,7 +1447,7 @@ mod test {
             assert_eq!(s1.remote_endpoint,  s2.remote_endpoint, "remote_endpoint");
             assert_eq!(s1.remote_endpoint,  s2.remote_endpoint, "remote_endpoint");
             assert_eq!(s1.local_seq_no,     s2.local_seq_no,    "local_seq_no");
             assert_eq!(s1.local_seq_no,     s2.local_seq_no,    "local_seq_no");
             assert_eq!(s1.remote_seq_no,    s2.remote_seq_no,   "remote_seq_no");
             assert_eq!(s1.remote_seq_no,    s2.remote_seq_no,   "remote_seq_no");
-            assert_eq!(s1.remote_next_seq,  s2.remote_next_seq, "remote_next_seq");
+            assert_eq!(s1.remote_last_seq,  s2.remote_last_seq, "remote_last_seq");
             assert_eq!(s1.remote_last_ack,  s2.remote_last_ack, "remote_last_ack");
             assert_eq!(s1.remote_last_ack,  s2.remote_last_ack, "remote_last_ack");
             assert_eq!(s1.remote_win_len,   s2.remote_win_len,  "remote_win_len");
             assert_eq!(s1.remote_win_len,   s2.remote_win_len,  "remote_win_len");
             assert_eq!(s1.timer,            s2.timer,           "timer");
             assert_eq!(s1.timer,            s2.timer,           "timer");
@@ -1599,7 +1604,7 @@ mod test {
         s.remote_endpoint = REMOTE_END;
         s.remote_endpoint = REMOTE_END;
         s.local_seq_no    = LOCAL_SEQ;
         s.local_seq_no    = LOCAL_SEQ;
         s.remote_seq_no   = REMOTE_SEQ + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1;
-        s.remote_next_seq = LOCAL_SEQ;
+        s.remote_last_seq = LOCAL_SEQ;
         s.remote_win_len  = 256;
         s.remote_win_len  = 256;
         s
         s
     }
     }
@@ -1689,7 +1694,7 @@ mod test {
         s.local_endpoint  = IpEndpoint::new(IpAddress::v4(0, 0, 0, 0), LOCAL_PORT);
         s.local_endpoint  = IpEndpoint::new(IpAddress::v4(0, 0, 0, 0), LOCAL_PORT);
         s.remote_endpoint = REMOTE_END;
         s.remote_endpoint = REMOTE_END;
         s.local_seq_no    = LOCAL_SEQ;
         s.local_seq_no    = LOCAL_SEQ;
-        s.remote_next_seq = LOCAL_SEQ;
+        s.remote_last_seq = LOCAL_SEQ;
         s
         s
     }
     }
 
 
@@ -1841,7 +1846,7 @@ mod test {
         let mut s = socket_syn_received();
         let mut s = socket_syn_received();
         s.state           = State::Established;
         s.state           = State::Established;
         s.local_seq_no    = LOCAL_SEQ + 1;
         s.local_seq_no    = LOCAL_SEQ + 1;
-        s.remote_next_seq = LOCAL_SEQ + 1;
+        s.remote_last_seq = LOCAL_SEQ + 1;
         s.remote_last_ack = REMOTE_SEQ + 1;
         s.remote_last_ack = REMOTE_SEQ + 1;
         s
         s
     }
     }
@@ -2146,7 +2151,7 @@ mod test {
         let mut s = socket_fin_wait_1();
         let mut s = socket_fin_wait_1();
         s.state           = State::FinWait2;
         s.state           = State::FinWait2;
         s.local_seq_no    = LOCAL_SEQ + 1 + 1;
         s.local_seq_no    = LOCAL_SEQ + 1 + 1;
-        s.remote_next_seq = LOCAL_SEQ + 1 + 1;
+        s.remote_last_seq = LOCAL_SEQ + 1 + 1;
         s
         s
     }
     }
 
 
@@ -2176,7 +2181,7 @@ mod test {
     fn socket_closing() -> TcpSocket<'static> {
     fn socket_closing() -> TcpSocket<'static> {
         let mut s = socket_fin_wait_1();
         let mut s = socket_fin_wait_1();
         s.state           = State::Closing;
         s.state           = State::Closing;
-        s.remote_next_seq = LOCAL_SEQ + 1 + 1;
+        s.remote_last_seq = LOCAL_SEQ + 1 + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
         s
         s
     }
     }
@@ -2655,6 +2660,51 @@ mod test {
         }));
         }));
     }
     }
 
 
+    #[test]
+    fn test_data_retransmit_bursts() {
+        let mut s = socket_established();
+        s.remote_win_len = 6;
+        s.send_slice(b"abcdef012345").unwrap();
+
+        recv!(s, time 0, Ok(TcpRepr {
+            control:    TcpControl::None,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..RECV_TEMPL
+        }), exact);
+        s.remote_win_len = 6;
+        recv!(s, time 0, Ok(TcpRepr {
+            control:    TcpControl::Psh,
+            seq_number: LOCAL_SEQ + 1 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"012345"[..],
+            ..RECV_TEMPL
+        }), exact);
+        s.remote_win_len = 6;
+        recv!(s, time 0, Err(Error::Exhausted));
+
+        recv!(s, time 50, Err(Error::Exhausted));
+
+        recv!(s, time 100, Ok(TcpRepr {
+            control:    TcpControl::None,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..RECV_TEMPL
+        }), exact);
+        s.remote_win_len = 6;
+        recv!(s, time 150, Ok(TcpRepr {
+            control:    TcpControl::Psh,
+            seq_number: LOCAL_SEQ + 1 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"012345"[..],
+            ..RECV_TEMPL
+        }), exact);
+        s.remote_win_len = 6;
+        recv!(s, time 200, Err(Error::Exhausted));
+    }
+
     #[test]
     #[test]
     fn test_send_data_after_syn_ack_retransmit() {
     fn test_send_data_after_syn_ack_retransmit() {
         let mut s = socket_syn_received();
         let mut s = socket_syn_received();
@@ -2806,6 +2856,10 @@ mod test {
         }));
         }));
     }
     }
 
 
+    // =========================================================================================//
+    // Tests for window management.
+    // =========================================================================================//
+
     #[test]
     #[test]
     fn test_maximum_segment_size() {
     fn test_maximum_segment_size() {
         let mut s = socket_listen();
         let mut s = socket_listen();
@@ -2839,10 +2893,6 @@ mod test {
         }));
         }));
     }
     }
 
 
-    // =========================================================================================//
-    // Tests for window management.
-    // =========================================================================================//
-
     #[test]
     #[test]
     fn test_window_size_clamp() {
     fn test_window_size_clamp() {
         let mut s = socket_established();
         let mut s = socket_established();