Browse Source

tcp: Add RTT estimation.

Dario Nieuwenhuis 4 years ago
parent
commit
5117af776a
1 changed files with 144 additions and 24 deletions
  1. 144 24
      src/socket/tcp.rs

+ 144 - 24
src/socket/tcp.rs

@@ -54,6 +54,99 @@ impl fmt::Display for State {
     }
 }
 
+// Conservative initial RTT estimate.
+const RTTE_INITIAL_RTT: u32 = 300;
+const RTTE_INITIAL_DEV: u32 = 100;
+
+// Minimum "safety margin" for the RTO that kicks in when the
+// variance gets very low.
+const RTTE_MIN_MARGIN: u32 = 5;
+
+const RTTE_MIN_RTO: u32 = 10;
+const RTTE_MAX_RTO: u32 = 10000;
+
+#[derive(Debug, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+struct RttEstimator {
+    // Using u32 instead of Duration to save space (Duration is i64)
+    rtt: u32,
+    deviation: u32,
+    timestamp: Option<(Instant, TcpSeqNumber)>,
+    max_seq_sent: Option<TcpSeqNumber>,
+    rto_count: u8,
+}
+
+impl Default for RttEstimator {
+    fn default() -> Self {
+        Self {
+            rtt: RTTE_INITIAL_RTT,
+            deviation: RTTE_INITIAL_DEV,
+            timestamp: None,
+            max_seq_sent: None,
+            rto_count: 0,
+        }
+    }
+}
+
+impl RttEstimator {
+    fn retransmission_timeout(&self) -> Duration {
+        let margin = RTTE_MIN_MARGIN.max(self.deviation * 4);
+        let ms = (self.rtt + margin).max(RTTE_MIN_RTO).min(RTTE_MAX_RTO);
+        Duration::from_millis(ms as u64)
+    }
+
+    fn sample(&mut self, new_rtt: u32) {
+        // "Congestion Avoidance and Control", Van Jacobson, Michael J. Karels, 1988
+        self.rtt = (self.rtt * 7 + new_rtt + 7) / 8;
+        let diff = (self.rtt as i32 - new_rtt as i32 ).abs() as u32;
+        self.deviation = (self.deviation * 3 + diff + 3) / 4;
+        
+        self.rto_count = 0;
+
+        let rto = self.retransmission_timeout().millis();
+        net_trace!("rtte: sample={:?} rtt={:?} dev={:?} rto={:?}", new_rtt, self.rtt, self.deviation, rto);
+    }
+
+    fn on_send(&mut self, timestamp: Instant, seq: TcpSeqNumber) {
+        if self.max_seq_sent.map(|max_seq_sent| seq > max_seq_sent).unwrap_or(true) {
+            self.max_seq_sent = Some(seq);
+            if self.timestamp.is_none() {
+                self.timestamp = Some((timestamp, seq));
+                net_trace!("rtte: sampling at seq={:?}", seq);
+            }
+        }
+    }
+
+    fn on_ack(&mut self, timestamp: Instant, seq: TcpSeqNumber) {
+        if let Some((sent_timestamp, sent_seq)) = self.timestamp {
+            if seq >= sent_seq {
+                self.sample((timestamp - sent_timestamp).millis() as u32);
+                self.timestamp = None;
+            }
+        }
+    }
+
+    fn on_retransmit(&mut self) {
+        if self.timestamp.is_some() {
+            net_trace!("rtte: abort sampling due to retransmit");
+        }
+        self.timestamp = None;
+        self.rto_count = self.rto_count.saturating_add(1);
+        if self.rto_count >= 3 {
+            // This happens in 2 scenarios:
+            // - The RTT is higher than the initial estimate
+            // - The network conditions change, suddenly making the RTT much higher
+            // In these cases, the estimator can get stuck, because it can't sample because
+            // all packets sent would incur a retransmit. To avoid this, force an estimate
+            // increase if we see 3 consecutive retransmissions without any successful sample.
+            self.rto_count = 0;
+            self.rtt *= 2;
+            let rto = self.retransmission_timeout().millis();
+            net_trace!("rtte: too many retransmissions, increasing: rtt={:?} dev={:?} rto={:?}", self.rtt, self.deviation, rto);
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq)]
 enum Timer {
     Idle {
@@ -69,7 +162,6 @@ enum Timer {
     }
 }
 
-const RETRANSMIT_DELAY: Duration = Duration { millis: 100 };
 const CLOSE_DELAY:      Duration = Duration { millis: 10_000 };
 
 impl Default for Timer {
@@ -140,12 +232,12 @@ impl Timer {
         }
     }
 
-    fn set_for_retransmit(&mut self, timestamp: Instant) {
+    fn set_for_retransmit(&mut self, timestamp: Instant, delay: Duration) {
         match *self {
             Timer::Idle { .. } | Timer::FastRetransmit { .. } => {
                 *self = Timer::Retransmit {
-                    expires_at: timestamp + RETRANSMIT_DELAY,
-                    delay:      RETRANSMIT_DELAY,
+                    expires_at: timestamp + delay,
+                    delay:      delay,
                 }
             }
             Timer::Retransmit { expires_at, delay }
@@ -189,6 +281,7 @@ pub struct TcpSocket<'a> {
     pub(crate) meta: SocketMeta,
     state:           State,
     timer:           Timer,
+    rtte:            RttEstimator,
     assembler:       Assembler,
     rx_buffer:       SocketBuffer<'a>,
     rx_fin_received: bool,
@@ -279,6 +372,7 @@ impl<'a> TcpSocket<'a> {
             meta:            SocketMeta::default(),
             state:           State::Closed,
             timer:           Timer::default(),
+            rtte:            RttEstimator::default(),
             assembler:       Assembler::new(rx_buffer.capacity()),
             tx_buffer:       tx_buffer,
             rx_buffer:       rx_buffer,
@@ -463,6 +557,7 @@ impl<'a> TcpSocket<'a> {
 
         self.state           = State::Closed;
         self.timer           = Timer::default();
+        self.rtte            = RttEstimator::default();
         self.assembler       = Assembler::new(self.rx_buffer.capacity());
         self.tx_buffer.clear();
         self.rx_buffer.clear();
@@ -1154,6 +1249,8 @@ impl<'a> TcpSocket<'a> {
                                self.meta.handle, self.local_endpoint, self.remote_endpoint);
                     ack_of_fin = true;
                 }
+
+                self.rtte.on_ack(timestamp, ack_number);
             }
         }
 
@@ -1538,6 +1635,7 @@ impl<'a> TcpSocket<'a> {
                            self.meta.handle, self.local_endpoint, self.remote_endpoint,
                            retransmit_delta);
                 self.remote_last_seq = self.local_seq_no;
+                self.rtte.on_retransmit();
             }
         }
 
@@ -1723,10 +1821,14 @@ impl<'a> TcpSocket<'a> {
         self.remote_last_ack = repr.ack_number;
         self.remote_last_win = repr.window_len;
 
+        if repr.segment_len() > 0 {
+            self.rtte.on_send(timestamp, repr.seq_number + repr.segment_len());
+        }
+
         if !self.seq_to_transmit() && repr.segment_len() > 0 {
             // If we've transmitted all data we could (and there was something at all,
             // data or flag, to transmit, not just an ACK), wind up the retransmit timer.
-            self.timer.set_for_retransmit(timestamp);
+            self.timer.set_for_retransmit(timestamp, self.rtte.retransmission_timeout());
         }
 
         if self.state == State::Closed {
@@ -3646,7 +3748,7 @@ mod test {
             ..RECV_TEMPL
         }));
         recv!(s, time 1050, Err(Error::Exhausted));
-        recv!(s, time 1100, Ok(TcpRepr {
+        recv!(s, time 2000, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"abcdef"[..],
@@ -3678,21 +3780,21 @@ mod test {
 
         recv!(s, time 50, Err(Error::Exhausted));
 
-        recv!(s, time 100, Ok(TcpRepr {
+        recv!(s, time 1000, Ok(TcpRepr {
             control:    TcpControl::None,
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"abcdef"[..],
             ..RECV_TEMPL
         }), exact);
-        recv!(s, time 150, Ok(TcpRepr {
+        recv!(s, time 1500, Ok(TcpRepr {
             control:    TcpControl::Psh,
             seq_number: LOCAL_SEQ + 1 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"012345"[..],
             ..RECV_TEMPL
         }), exact);
-        recv!(s, time 200, Err(Error::Exhausted));
+        recv!(s, time 1550, Err(Error::Exhausted));
     }
 
     #[test]
@@ -3705,7 +3807,7 @@ mod test {
             max_seg_size: Some(BASE_MSS),
             ..RECV_TEMPL
         }));
-        recv!(s, time 150, Ok(TcpRepr { // retransmit
+        recv!(s, time 750, Ok(TcpRepr { // retransmit
             control:    TcpControl::Syn,
             seq_number: LOCAL_SEQ,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -4527,9 +4629,9 @@ mod test {
     #[test]
     fn test_established_timeout() {
         let mut s = socket_established();
-        s.set_timeout(Some(Duration::from_millis(200)));
+        s.set_timeout(Some(Duration::from_millis(1000)));
         recv!(s, time 250, Err(Error::Exhausted));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(450)));
+        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(1250)));
         s.send_slice(b"abcdef").unwrap();
         assert_eq!(s.poll_at(), PollAt::Now);
         recv!(s, time 255, Ok(TcpRepr {
@@ -4538,15 +4640,15 @@ mod test {
             payload:    &b"abcdef"[..],
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(355)));
-        recv!(s, time 355, Ok(TcpRepr {
+        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(955)));
+        recv!(s, time 955, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload:    &b"abcdef"[..],
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(455)));
-        recv!(s, time 500, Ok(TcpRepr {
+        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(1255)));
+        recv!(s, time 1255, Ok(TcpRepr {
             control:    TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -4596,15 +4698,14 @@ mod test {
     #[test]
     fn test_fin_wait_1_timeout() {
         let mut s = socket_fin_wait_1();
-        s.set_timeout(Some(Duration::from_millis(200)));
+        s.set_timeout(Some(Duration::from_millis(1000)));
         recv!(s, time 100, Ok(TcpRepr {
             control:    TcpControl::Fin,
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200)));
-        recv!(s, time 400, Ok(TcpRepr {
+        recv!(s, time 1100, Ok(TcpRepr {
             control:    TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 1,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -4616,15 +4717,14 @@ mod test {
     #[test]
     fn test_last_ack_timeout() {
         let mut s = socket_last_ack();
-        s.set_timeout(Some(Duration::from_millis(200)));
+        s.set_timeout(Some(Duration::from_millis(1000)));
         recv!(s, time 100, Ok(TcpRepr {
             control:    TcpControl::Fin,
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 1),
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200)));
-        recv!(s, time 400, Ok(TcpRepr {
+        recv!(s, time 1100, Ok(TcpRepr {
             control:    TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 1),
@@ -5052,13 +5152,14 @@ mod test {
 
     #[test]
     fn test_timer_retransmit() {
+        const RTO: Duration = Duration::from_millis(100);
         let mut r = Timer::default();
         assert_eq!(r.should_retransmit(Instant::from_secs(1)), None);
-        r.set_for_retransmit(Instant::from_millis(1000));
+        r.set_for_retransmit(Instant::from_millis(1000), RTO);
         assert_eq!(r.should_retransmit(Instant::from_millis(1000)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1050)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1101)), Some(Duration::from_millis(101)));
-        r.set_for_retransmit(Instant::from_millis(1101));
+        r.set_for_retransmit(Instant::from_millis(1101), RTO);
         assert_eq!(r.should_retransmit(Instant::from_millis(1101)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1150)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1200)), None);
@@ -5067,4 +5168,23 @@ mod test {
         assert_eq!(r.should_retransmit(Instant::from_millis(1350)), None);
     }
 
+    #[test]
+    fn test_rtt_estimator() {
+        #[cfg(feature = "log")]
+        init_logger();
+
+        let mut r = RttEstimator::default();
+
+        let rtos = &[
+            751, 766, 755, 731, 697, 656, 613, 567,
+            523, 484, 445, 411, 378, 350, 322, 299,
+            280, 261, 243, 229, 215, 206, 197, 188
+        ];
+
+        for &rto in rtos {
+            r.sample(100);
+            assert_eq!(r.retransmission_timeout(), Duration::from_millis(rto));
+        }
+    }
+
 }