Explorar o código

tcp: fix retransmit exponential backoff, align to rfc6298.

Dario Nieuwenhuis hai 5 meses
pai
achega
9b9f60b5ba
Modificouse 1 ficheiro con 112 adicións e 104 borrados
  1. 112 104
      src/socket/tcp.rs

+ 112 - 104
src/socket/tcp.rs

@@ -141,23 +141,38 @@ impl fmt::Display for State {
     }
 }
 
-// Conservative initial RTT estimate.
-const RTTE_INITIAL_RTT: u32 = 300;
-const RTTE_INITIAL_DEV: u32 = 100;
+/// RFC 6298: (2.1) Until a round-trip time (RTT) measurement has been made for a
+/// segment sent between the sender and receiver, the sender SHOULD
+/// set RTO <- 1 second.
+const RTTE_INITIAL_RTO: u32 = 1000;
 
 // Minimum "safety margin" for the RTO that kicks in when the
 // variance gets very low.
 const RTTE_MIN_MARGIN: u32 = 5;
 
-const RTTE_MIN_RTO: u32 = 10;
-const RTTE_MAX_RTO: u32 = 10000;
+/// K, the RTTVAR multiplier in the RTO formula of RFC 6298 (2.2), (2.3).
+const RTTE_K: u32 = 4;
+
+// RFC 6298 (2.4): Whenever RTO is computed, if it is less than 1 second, then the
+// RTO SHOULD be rounded up to 1 second.
+const RTTE_MIN_RTO: u32 = 1000;
+
+// RFC 6298 (2.5): A maximum value MAY be placed on RTO provided it is at least 60
+// seconds.
+const RTTE_MAX_RTO: u32 = 60_000;
 
 #[derive(Debug, Clone, Copy)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 struct RttEstimator {
+    /// True once at least one RTT measurement has been made.
+    have_measurement: bool,
     // Using u32 instead of Duration to save space (Duration is i64)
-    rtt: u32,
-    deviation: u32,
+    /// Smoothed RTT
+    srtt: u32,
+    /// RTT variance.
+    rttvar: u32,
+    /// Retransmission Time-Out
+    rto: u32,
     timestamp: Option<(Instant, TcpSeqNumber)>,
     max_seq_sent: Option<TcpSeqNumber>,
     rto_count: u8,
@@ -166,8 +181,10 @@ struct RttEstimator {
 impl Default for RttEstimator {
     fn default() -> Self {
         Self {
-            rtt: RTTE_INITIAL_RTT,
-            deviation: RTTE_INITIAL_DEV,
+            have_measurement: false,
+            srtt: 0,   // ignored, will be overwritten on first measurement.
+            rttvar: 0, // ignored, will be overwritten on first measurement.
+            rto: RTTE_INITIAL_RTO,
             timestamp: None,
             max_seq_sent: None,
             rto_count: 0,
@@ -177,26 +194,34 @@ impl Default for RttEstimator {
 
 impl RttEstimator {
     fn retransmission_timeout(&self) -> Duration {
-        let margin = RTTE_MIN_MARGIN.max(self.deviation * 4);
-        let ms = (self.rtt + margin).clamp(RTTE_MIN_RTO, RTTE_MAX_RTO);
-        Duration::from_millis(ms as u64)
+        Duration::from_millis(self.rto as _)
     }
 
     fn sample(&mut self, new_rtt: u32) {
-        // "Congestion Avoidance and Control", Van Jacobson, Michael J. Karels, 1988
-        self.rtt = (self.rtt * 7 + new_rtt + 7) / 8;
-        let diff = (self.rtt as i32 - new_rtt as i32).unsigned_abs();
-        self.deviation = (self.deviation * 3 + diff + 3) / 4;
+        if self.have_measurement {
+            // RFC 6298 (2.3) When a subsequent RTT measurement R' is made, a host MUST set (...)
+            let diff = (self.srtt as i32 - new_rtt as i32).unsigned_abs();
+            self.rttvar = (self.rttvar * 3 + diff).div_ceil(4);
+            self.srtt = (self.srtt * 7 + new_rtt).div_ceil(8);
+        } else {
+            // RFC 6298 (2.2) When the first RTT measurement R is made, the host MUST set (...)
+            self.have_measurement = true;
+            self.srtt = new_rtt;
+            self.rttvar = new_rtt / 2;
+        }
+
+        // RFC 6298 (2.2), (2.3)
+        let margin = RTTE_MIN_MARGIN.max(self.rttvar * RTTE_K);
+        self.rto = (self.srtt + margin).clamp(RTTE_MIN_RTO, RTTE_MAX_RTO);
 
         self.rto_count = 0;
 
-        let rto = self.retransmission_timeout().total_millis();
         tcp_trace!(
-            "rtte: sample={:?} rtt={:?} dev={:?} rto={:?}",
+            "rtte: sample={:?} srtt={:?} rttvar={:?} rto={:?}",
             new_rtt,
-            self.rtt,
-            self.deviation,
-            rto
+            self.srtt,
+            self.rttvar,
+            self.rto
         );
     }
 
@@ -228,23 +253,23 @@ impl RttEstimator {
             tcp_trace!("rtte: abort sampling due to retransmit");
         }
         self.timestamp = None;
-        self.rto_count = self.rto_count.saturating_add(1);
+
+        // RFC 6298 (5.5) The host MUST set RTO <- RTO * 2 ("back off the timer").  The
+        // maximum value discussed in (2.5) above may be used to provide
+        // an upper bound to this doubling operation.
+        self.rto = (self.rto * 2).min(RTTE_MAX_RTO);
+        tcp_trace!("rtte: doubling rto to {:?}", self.rto);
+
+        // RFC 6298: a TCP implementation MAY clear SRTT and RTTVAR after
+        // backing off the timer multiple times as it is likely that the current
+        // SRTT and RTTVAR are bogus in this situation.  Once SRTT and RTTVAR
+        // are cleared, they should be initialized with the next RTT sample
+        // taken per (2.2) rather than using (2.3).
+        self.rto_count += 1;
         if self.rto_count >= 3 {
-            // This happens in 2 scenarios:
-            // - The RTT is higher than the initial estimate
-            // - The network conditions change, suddenly making the RTT much higher
-            // In these cases, the estimator can get stuck, because it can't sample because
-            // all packets sent would incur a retransmit. To avoid this, force an estimate
-            // increase if we see 3 consecutive retransmissions without any successful sample.
             self.rto_count = 0;
-            self.rtt = RTTE_MAX_RTO.min(self.rtt * 2);
-            let rto = self.retransmission_timeout().total_millis();
-            tcp_trace!(
-                "rtte: too many retransmissions, increasing: rtt={:?} dev={:?} rto={:?}",
-                self.rtt,
-                self.deviation,
-                rto
-            );
+            self.have_measurement = false;
+            tcp_trace!("rtte: too many retransmissions, clearing srtt, rttvar.");
         }
     }
 }
@@ -252,17 +277,10 @@ impl RttEstimator {
 #[derive(Debug, Clone, Copy, PartialEq)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 enum Timer {
-    Idle {
-        keep_alive_at: Option<Instant>,
-    },
-    Retransmit {
-        expires_at: Instant,
-        delay: Duration,
-    },
+    Idle { keep_alive_at: Option<Instant> },
+    Retransmit { expires_at: Instant },
     FastRetransmit,
-    Close {
-        expires_at: Instant,
-    },
+    Close { expires_at: Instant },
 }
 
 const ACK_DELAY_DEFAULT: Duration = Duration::from_millis(10);
@@ -284,13 +302,11 @@ impl Timer {
         }
     }
 
-    fn should_retransmit(&self, timestamp: Instant) -> Option<Duration> {
+    fn should_retransmit(&self, timestamp: Instant) -> bool {
         match *self {
-            Timer::Retransmit { expires_at, delay } if timestamp >= expires_at => {
-                Some(timestamp - expires_at + delay)
-            }
-            Timer::FastRetransmit => Some(Duration::from_millis(0)),
-            _ => None,
+            Timer::Retransmit { expires_at } if timestamp >= expires_at => true,
+            Timer::FastRetransmit => true,
+            _ => false,
         }
     }
 
@@ -340,7 +356,6 @@ impl Timer {
             Timer::Idle { .. } | Timer::FastRetransmit { .. } | Timer::Retransmit { .. } => {
                 *self = Timer::Retransmit {
                     expires_at: timestamp + delay,
-                    delay,
                 }
             }
             Timer::Close { .. } => (),
@@ -2271,30 +2286,28 @@ impl<'a> Socket<'a> {
             // If a timeout expires, we should abort the connection.
             net_debug!("timeout exceeded");
             self.set_state(State::Closed);
-        } else if !self.seq_to_transmit(cx) {
-            if let Some(retransmit_delta) = self.timer.should_retransmit(cx.now()) {
-                // If a retransmit timer expired, we should resend data starting at the last ACK.
-                net_debug!("retransmitting at t+{}", retransmit_delta);
-
-                // Rewind "last sequence number sent", as if we never
-                // had sent them. This will cause all data in the queue
-                // to be sent again.
-                self.remote_last_seq = self.local_seq_no;
+        } else if !self.seq_to_transmit(cx) && self.timer.should_retransmit(cx.now()) {
+            // If a retransmit timer expired, we should resend data starting at the last ACK.
+            net_debug!("retransmitting");
 
-                // Clear the `should_retransmit` state. If we can't retransmit right
-                // now for whatever reason (like zero window), this avoids an
-                // infinite polling loop where `poll_at` returns `Now` but `dispatch`
-                // can't actually do anything.
-                self.timer.set_for_idle(cx.now(), self.keep_alive);
+            // Rewind "last sequence number sent", as if we had never
+            // sent that data. This will cause all data in the queue
+            // to be sent again.
+            self.remote_last_seq = self.local_seq_no;
 
-                // Inform RTTE, so that it can avoid bogus measurements.
-                self.rtte.on_retransmit();
+            // Clear the `should_retransmit` state. If we can't retransmit right
+            // now for whatever reason (like zero window), this avoids an
+            // infinite polling loop where `poll_at` returns `Now` but `dispatch`
+            // can't actually do anything.
+            self.timer.set_for_idle(cx.now(), self.keep_alive);
 
-                // Inform the congestion controller that we're retransmitting.
-                self.congestion_controller
-                    .inner_mut()
-                    .on_retransmit(cx.now());
-            }
+            // Inform RTTE, so that it can avoid bogus measurements.
+            self.rtte.on_retransmit();
+
+            // Inform the congestion controller that we're retransmitting.
+            self.congestion_controller
+                .inner_mut()
+                .on_retransmit(cx.now());
         }
 
         // Decide whether we're sending a packet.
@@ -2735,6 +2748,7 @@ mod test {
         }
     }
 
+    #[track_caller]
     fn send(
         socket: &mut TestSocket,
         timestamp: Instant,
@@ -5709,9 +5723,9 @@ mod test {
             ..SEND_TEMPL
         });
         // The ACK of the first packet should restart the retransmit timer and delay a retransmission.
-        recv_nothing!(s, time 1500);
+        recv_nothing!(s, time 2399);
         // The second packet should be re-sent.
-        recv!(s, time 1600, Ok(TcpRepr {
+        recv!(s, time 2400, Ok(TcpRepr {
             control:    TcpControl::Psh,
             seq_number: LOCAL_SEQ + 1 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -5770,7 +5784,7 @@ mod test {
             max_seg_size: Some(BASE_MSS),
             ..RECV_TEMPL
         }));
-        recv!(s, time 750, Ok(TcpRepr { // retransmit
+        recv!(s, time 1050, Ok(TcpRepr { // retransmit
             control:    TcpControl::Syn,
             seq_number: LOCAL_SEQ,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -5891,18 +5905,18 @@ mod test {
             payload:    &b"ABCDEF"[..],
             ..RECV_TEMPL
         })); // also dropped
-        recv!(s, time 2000, Ok(TcpRepr {
+        recv!(s, time 3000, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"abcdef"[..],
             ..RECV_TEMPL
         })); // retransmission
-        send!(s, time 2005, TcpRepr {
+        send!(s, time 3005, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1 + 6 + 6),
             ..SEND_TEMPL
         }); // acknowledgement of both segments
-        recv!(s, time 2010, Ok(TcpRepr {
+        recv!(s, time 3010, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1 + 6 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"ABCDEF"[..],
@@ -6908,11 +6922,11 @@ mod test {
     #[test]
     fn test_established_timeout() {
         let mut s = socket_established();
-        s.set_timeout(Some(Duration::from_millis(1000)));
+        s.set_timeout(Some(Duration::from_millis(2000)));
         recv_nothing!(s, time 250);
         assert_eq!(
             s.socket.poll_at(&mut s.cx),
-            PollAt::Time(Instant::from_millis(1250))
+            PollAt::Time(Instant::from_millis(2250))
         );
         s.send_slice(b"abcdef").unwrap();
         assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now);
@@ -6924,9 +6938,9 @@ mod test {
         }));
         assert_eq!(
             s.socket.poll_at(&mut s.cx),
-            PollAt::Time(Instant::from_millis(955))
+            PollAt::Time(Instant::from_millis(1255))
         );
-        recv!(s, time 955, Ok(TcpRepr {
+        recv!(s, time 1255, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"abcdef"[..],
@@ -6934,9 +6948,9 @@ mod test {
         }));
         assert_eq!(
             s.socket.poll_at(&mut s.cx),
-            PollAt::Time(Instant::from_millis(1255))
+            PollAt::Time(Instant::from_millis(2255))
         );
-        recv!(s, time 1255, Ok(TcpRepr {
+        recv!(s, time 2255, Ok(TcpRepr {
             control:    TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -7822,24 +7836,18 @@ mod test {
     fn test_timer_retransmit() {
         const RTO: Duration = Duration::from_millis(100);
         let mut r = Timer::new();
-        assert_eq!(r.should_retransmit(Instant::from_secs(1)), None);
+        assert!(!r.should_retransmit(Instant::from_secs(1)));
         r.set_for_retransmit(Instant::from_millis(1000), RTO);
-        assert_eq!(r.should_retransmit(Instant::from_millis(1000)), None);
-        assert_eq!(r.should_retransmit(Instant::from_millis(1050)), None);
-        assert_eq!(
-            r.should_retransmit(Instant::from_millis(1101)),
-            Some(Duration::from_millis(101))
-        );
+        assert!(!r.should_retransmit(Instant::from_millis(1000)));
+        assert!(!r.should_retransmit(Instant::from_millis(1050)));
+        assert!(r.should_retransmit(Instant::from_millis(1101)));
         r.set_for_retransmit(Instant::from_millis(1101), RTO);
-        assert_eq!(r.should_retransmit(Instant::from_millis(1101)), None);
-        assert_eq!(r.should_retransmit(Instant::from_millis(1150)), None);
-        assert_eq!(r.should_retransmit(Instant::from_millis(1200)), None);
-        assert_eq!(
-            r.should_retransmit(Instant::from_millis(1301)),
-            Some(Duration::from_millis(200))
-        );
+        assert!(!r.should_retransmit(Instant::from_millis(1101)));
+        assert!(!r.should_retransmit(Instant::from_millis(1150)));
+        assert!(!r.should_retransmit(Instant::from_millis(1200)));
+        assert!(r.should_retransmit(Instant::from_millis(1301)));
         r.set_for_idle(Instant::from_millis(1301), None);
-        assert_eq!(r.should_retransmit(Instant::from_millis(1350)), None);
+        assert!(!r.should_retransmit(Instant::from_millis(1350)));
     }
 
     #[test]
@@ -7847,12 +7855,12 @@ mod test {
         let mut r = RttEstimator::default();
 
         let rtos = &[
-            751, 766, 755, 731, 697, 656, 613, 567, 523, 484, 445, 411, 378, 350, 322, 299, 280,
-            261, 243, 229, 215, 206, 197, 188,
+            6000, 5000, 4252, 3692, 3272, 2956, 2720, 2540, 2408, 2308, 2232, 2176, 2132, 2100,
+            2076, 2060, 2048, 2036, 2028, 2024, 2020, 2016, 2012, 2012,
         ];
 
         for &rto in rtos {
-            r.sample(100);
+            r.sample(2000);
             assert_eq!(r.retransmission_timeout(), Duration::from_millis(rto));
         }
     }