浏览代码

Merge #538

538: TCP fuzz fixes r=Dirbaio a=Dirbaio

Fixes panics and hangs found by whole-stack fuzzing. See individual commit messages.

Will post the whole-stack fuzz target when it's fully clean.

Co-authored-by: Dario Nieuwenhuis <dirbaio@dirbaio.net>
bors[bot] 3 年之前
父节点
当前提交
29751ae403
共有 1 个文件被更改,包括 183 次插入55 次删除
  1. 183 55
      src/socket/tcp.rs

+ 183 - 55
src/socket/tcp.rs

@@ -110,7 +110,7 @@ impl RttEstimator {
 
         self.rto_count = 0;
 
-        let rto = self.retransmission_timeout().millis();
+        let rto = self.retransmission_timeout().total_millis();
         net_trace!(
             "rtte: sample={:?} rtt={:?} dev={:?} rto={:?}",
             new_rtt,
@@ -137,7 +137,7 @@ impl RttEstimator {
     fn on_ack(&mut self, timestamp: Instant, seq: TcpSeqNumber) {
         if let Some((sent_timestamp, sent_seq)) = self.timestamp {
             if seq >= sent_seq {
-                self.sample((timestamp - sent_timestamp).millis() as u32);
+                self.sample((timestamp - sent_timestamp).total_millis() as u32);
                 self.timestamp = None;
             }
         }
@@ -158,7 +158,7 @@ impl RttEstimator {
             // increase if we see 3 consecutive retransmissions without any successful sample.
             self.rto_count = 0;
             self.rtt = RTTE_MAX_RTO.min(self.rtt * 2);
-            let rto = self.retransmission_timeout().millis();
+            let rto = self.retransmission_timeout().total_millis();
             net_trace!(
                 "rtte: too many retransmissions, increasing: rtt={:?} dev={:?} rto={:?}",
                 self.rtt,
@@ -1355,24 +1355,6 @@ impl<'a> TcpSocket<'a> {
                 );
                 return Err(Error::Dropped);
             }
-            // Any ACK in the SYN-SENT state must have the SYN flag set.
-            (
-                State::SynSent,
-                &TcpRepr {
-                    control: TcpControl::None,
-                    ack_number: Some(_),
-                    ..
-                },
-            ) => {
-                net_debug!(
-                    "{}:{}:{}: expecting a SYN|ACK",
-                    self.meta.handle,
-                    self.local_endpoint,
-                    self.remote_endpoint
-                );
-                self.abort();
-                return Err(Error::Dropped);
-            }
             // SYN|ACK in the SYN-SENT state must have the exact ACK number.
             (
                 State::SynSent,
@@ -1392,6 +1374,17 @@ impl<'a> TcpSocket<'a> {
                     return Err(Error::Dropped);
                 }
             }
+            // Anything else in the SYN-SENT state is invalid.
+            (State::SynSent, _) => {
+                net_debug!(
+                    "{}:{}:{}: expecting a SYN|ACK",
+                    self.meta.handle,
+                    self.local_endpoint,
+                    self.remote_endpoint
+                );
+                self.abort();
+                return Err(Error::Dropped);
+            }
             // Every acknowledgement must be for transmitted but unacknowledged data.
             (
                 _,
@@ -1403,9 +1396,14 @@ impl<'a> TcpSocket<'a> {
                 let unacknowledged = self.tx_buffer.len() + control_len;
 
                 // Acceptable ACK range (both inclusive)
-                let ack_min = self.local_seq_no;
+                let mut ack_min = self.local_seq_no;
                 let ack_max = self.local_seq_no + unacknowledged;
 
+                // If we have sent a SYN, it MUST be acknowledged.
+                if sent_syn {
+                    ack_min += 1;
+                }
+
                 if ack_number < ack_min {
                     net_debug!(
                         "{}:{}:{}: duplicate ACK ({} not in {}...{})",
@@ -1506,23 +1504,26 @@ impl<'a> TcpSocket<'a> {
         let mut ack_of_fin = false;
         if repr.control != TcpControl::Rst {
             if let Some(ack_number) = repr.ack_number {
-                ack_len = ack_number - self.local_seq_no;
-                // There could have been no data sent before the SYN, so we always remove it
-                // from the sequence space.
-                if sent_syn {
-                    ack_len -= 1
-                }
-                // We could've sent data before the FIN, so only remove FIN from the sequence
-                // space if all of that data is acknowledged.
-                if sent_fin && self.tx_buffer.len() + 1 == ack_len {
-                    ack_len -= 1;
-                    net_trace!(
-                        "{}:{}:{}: received ACK of FIN",
-                        self.meta.handle,
-                        self.local_endpoint,
-                        self.remote_endpoint
-                    );
-                    ack_of_fin = true;
+                // Sequence number corresponding to the first byte in `tx_buffer`.
+                // This normally equals `local_seq_no`, but is 1 higher if we ahve sent a SYN,
+                // as the SYN occupies 1 sequence number "before" the data.
+                let tx_buffer_start_seq = self.local_seq_no + (sent_syn as usize);
+
+                if ack_number >= tx_buffer_start_seq {
+                    ack_len = ack_number - tx_buffer_start_seq;
+
+                    // We could've sent data before the FIN, so only remove FIN from the sequence
+                    // space if all of that data is acknowledged.
+                    if sent_fin && self.tx_buffer.len() + 1 == ack_len {
+                        ack_len -= 1;
+                        net_trace!(
+                            "{}:{}:{}: received ACK of FIN",
+                            self.meta.handle,
+                            self.local_endpoint,
+                            self.remote_endpoint
+                        );
+                        ack_of_fin = true;
+                    }
                 }
 
                 self.rtte.on_ack(cx.now, ack_number);
@@ -1575,16 +1576,26 @@ impl<'a> TcpSocket<'a> {
             // SYN packets in the LISTEN state change it to SYN-RECEIVED.
             (State::Listen, TcpControl::Syn) => {
                 net_trace!("{}:{}: received SYN", self.meta.handle, self.local_endpoint);
+                if let Some(max_seg_size) = repr.max_seg_size {
+                    if max_seg_size == 0 {
+                        net_trace!(
+                            "{}:{}:{}: received SYNACK with zero MSS, ignoring",
+                            self.meta.handle,
+                            self.local_endpoint,
+                            self.remote_endpoint
+                        );
+                        return Ok(None);
+                    }
+                    self.remote_mss = max_seg_size as usize
+                }
+
                 self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
                 self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), repr.src_port);
                 // FIXME: use something more secure here
-                self.local_seq_no = TcpSeqNumber(-repr.seq_number.0);
+                self.local_seq_no = TcpSeqNumber(!repr.seq_number.0);
                 self.remote_seq_no = repr.seq_number + 1;
                 self.remote_last_seq = self.local_seq_no;
                 self.remote_has_sack = repr.sack_permitted;
-                if let Some(max_seg_size) = repr.max_seg_size {
-                    self.remote_mss = max_seg_size as usize
-                }
                 self.remote_win_scale = repr.window_scale;
                 // Remote doesn't support window scaling, don't do it.
                 if self.remote_win_scale.is_none() {
@@ -1618,6 +1629,19 @@ impl<'a> TcpSocket<'a> {
                     self.local_endpoint,
                     self.remote_endpoint
                 );
+                if let Some(max_seg_size) = repr.max_seg_size {
+                    if max_seg_size == 0 {
+                        net_trace!(
+                            "{}:{}:{}: received SYNACK with zero MSS, ignoring",
+                            self.meta.handle,
+                            self.local_endpoint,
+                            self.remote_endpoint
+                        );
+                        return Ok(None);
+                    }
+                    self.remote_mss = max_seg_size as usize;
+                }
+
                 self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
                 self.remote_seq_no = repr.seq_number + 1;
                 self.remote_last_seq = self.local_seq_no + 1;
@@ -1628,9 +1652,6 @@ impl<'a> TcpSocket<'a> {
                     self.remote_win_shift = 0;
                 }
 
-                if let Some(max_seg_size) = repr.max_seg_size {
-                    self.remote_mss = max_seg_size as usize;
-                }
                 self.set_state(State::Established);
                 self.timer.set_for_idle(cx.now, self.keep_alive);
             }
@@ -1974,6 +1995,11 @@ impl<'a> TcpSocket<'a> {
         // Have we sent data that hasn't been ACKed yet?
         let data_in_flight = self.remote_last_seq != self.local_seq_no;
 
+        // If we want to send a SYN and we haven't done so, do it!
+        if matches!(self.state, State::SynSent | State::SynReceived) && !data_in_flight {
+            return true;
+        }
+
         // max sequence number we can send.
         let max_send_seq =
             self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len());
@@ -2077,7 +2103,19 @@ impl<'a> TcpSocket<'a> {
                     self.remote_endpoint,
                     retransmit_delta
                 );
+
+                // Rewind "last sequence number sent", as if we never
+                // had sent them. This will cause all data in the queue
+                // to be sent again.
                 self.remote_last_seq = self.local_seq_no;
+
+                // Clear the `should_retransmit` state. If we can't retransmit right
+                // now for whatever reason (like zero window), this avoids an
+                // infinite polling loop where `poll_at` returns `Now` but `dispatch`
+                // can't actually do anything.
+                self.timer.set_for_idle(cx.now, self.keep_alive);
+
+                // Inform RTTE, so that it can avoid bogus measurements.
                 self.rtte.on_retransmit();
             }
         }
@@ -2115,14 +2153,6 @@ impl<'a> TcpSocket<'a> {
                 self.local_endpoint,
                 self.remote_endpoint
             );
-        } else if self.timer.should_retransmit(cx.now).is_some() {
-            // If we have packets to retransmit, do it.
-            net_trace!(
-                "{}:{}:{}: retransmit timer expired",
-                self.meta.handle,
-                self.local_endpoint,
-                self.remote_endpoint
-            );
         } else if self.timer.should_keep_alive(cx.now) {
             // If we need to transmit a keep-alive packet, do it.
             net_trace!(
@@ -2459,7 +2489,7 @@ mod test {
         port: REMOTE_PORT,
     };
     const LOCAL_SEQ: TcpSeqNumber = TcpSeqNumber(10000);
-    const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10000);
+    const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10001);
 
     const SEND_IP_TEMPL: IpRepr = IpRepr::Unspecified {
         src_addr: MOCK_IP_ADDR_1,
@@ -3017,6 +3047,62 @@ mod test {
         sanity!(s, socket_established());
     }
 
+    #[test]
+    fn test_syn_received_ack_too_low() {
+        let mut s = socket_syn_received();
+        recv!(
+            s,
+            [TcpRepr {
+                control: TcpControl::Syn,
+                seq_number: LOCAL_SEQ,
+                ack_number: Some(REMOTE_SEQ + 1),
+                max_seg_size: Some(BASE_MSS),
+                ..RECV_TEMPL
+            }]
+        );
+        send!(
+            s,
+            TcpRepr {
+                seq_number: REMOTE_SEQ + 1,
+                ack_number: Some(LOCAL_SEQ), // wrong
+                ..SEND_TEMPL
+            },
+            Err(Error::Dropped)
+        );
+        assert_eq!(s.state, State::SynReceived);
+    }
+
+    #[test]
+    fn test_syn_received_ack_too_high() {
+        let mut s = socket_syn_received();
+        recv!(
+            s,
+            [TcpRepr {
+                control: TcpControl::Syn,
+                seq_number: LOCAL_SEQ,
+                ack_number: Some(REMOTE_SEQ + 1),
+                max_seg_size: Some(BASE_MSS),
+                ..RECV_TEMPL
+            }]
+        );
+        send!(
+            s,
+            TcpRepr {
+                seq_number: REMOTE_SEQ + 1,
+                ack_number: Some(LOCAL_SEQ + 2), // wrong
+                ..SEND_TEMPL
+            },
+            // TODO is this correct? probably not
+            Ok(Some(TcpRepr {
+                control: TcpControl::None,
+                seq_number: LOCAL_SEQ + 1,
+                ack_number: Some(REMOTE_SEQ + 1),
+                ..RECV_TEMPL
+            }))
+        );
+        assert_eq!(s.state, State::SynReceived);
+    }
+
     #[test]
     fn test_syn_received_fin() {
         let mut s = socket_syn_received();
@@ -5486,6 +5572,48 @@ mod test {
         );
     }
 
+    #[test]
+    fn test_fast_retransmit_zero_window() {
+        let mut s = socket_established();
+
+        send!(s, time 1000, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+
+        s.send_slice(b"abc").unwrap();
+
+        recv!(s, time 0, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..RECV_TEMPL
+        }));
+
+        // 3 dup acks
+        send!(s, time 1050, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        send!(s, time 1050, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        send!(s, time 1050, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            window_len: 0, // boom
+            ..SEND_TEMPL
+        });
+
+        // even though we're in "fast retransmit", we shouldn't
+        // force-send anything because the remote's window is full.
+        recv!(s, Err(Error::Exhausted));
+    }
+
     // =========================================================================================//
     // Tests for window management.
     // =========================================================================================//