浏览代码

Implement TCP timeouts.

whitequark 7 年之前
父节点
当前提交
64a82709d4
共有 3 个文件被更改,包括 194 次插入19 次删除
  1. 7 8
      README.md
  2. 178 11
      src/socket/tcp.rs
  3. 9 0
      src/wire/tcp.rs

+ 7 - 8
README.md

@@ -46,7 +46,7 @@ The only supported internetworking protocol is IPv4.
 
 The UDP protocol is supported over IPv4.
 
-  * UDP header checksum is always generated and validated.
+  * Header checksum is always generated and validated.
   * In response to a packet arriving at a port without a listening socket,
     an ICMP destination unreachable message is generated.
 
@@ -54,15 +54,14 @@ The UDP protocol is supported over IPv4.
 
 The TCP protocol is supported over IPv4. Server and client sockets are supported.
 
-  * TCP header checksum is generated and validated.
+  * Header checksum is generated and validated.
   * Maximum segment size is negotiated.
-  * Multiple packets will be transmitted without waiting for an acknowledgement.
-  * Lost packets will be retransmitted with exponential backoff, starting at
-    a fixed delay of 100 ms.
+  * Multiple packets are transmitted without waiting for an acknowledgement.
+  * Lost packets are retransmitted with exponential backoff, starting at a fixed delay of 100 ms.
   * Sending keep-alive packets is supported, with a configurable interval.
-  * After arriving at the TIME-WAIT state, sockets will close after a fixed delay of 10 s.
-  * TCP urgent pointer is **not** supported; any urgent octets will be received alongside
-    data octets.
+  * Connection, retransmission and keep-alive timeouts are supported, with a configurable duration.
+  * After arriving at the TIME-WAIT state, sockets close after a fixed delay of 10 s.
+  * Urgent pointer is **not** supported; any urgent octets will be received alongside data octets.
   * Reassembly of out-of-order segments is **not** supported.
   * Silly window syndrome avoidance is **not** supported for either transmission or reception.
   * Congestion control is **not** implemented.

+ 178 - 11
src/socket/tcp.rs

@@ -173,6 +173,8 @@ pub struct TcpSocket<'a> {
     timer:           Timer,
     rx_buffer:       SocketBuffer<'a>,
     tx_buffer:       SocketBuffer<'a>,
+    /// Interval after which, if no inbound packets are received, the connection is aborted.
+    timeout:         Option<u64>,
     /// Interval at which keep-alive packets will be sent.
     keep_alive:      Option<u64>,
     /// Address passed to listen(). Listen address is set when listen() is called and
@@ -207,6 +209,8 @@ pub struct TcpSocket<'a> {
     remote_win_len:  usize,
     /// The maximum number of data octets that the remote side may receive.
     remote_mss:      usize,
+    /// The timestamp of the last packet received.
+    remote_last_ts:  Option<u64>,
 }
 
 const DEFAULT_MSS: usize = 536;
@@ -227,6 +231,7 @@ impl<'a> TcpSocket<'a> {
             timer:           Timer::default(),
             tx_buffer:       tx_buffer.into(),
             rx_buffer:       rx_buffer.into(),
+            timeout:         None,
             keep_alive:      None,
             listen_address:  IpAddress::default(),
             local_endpoint:  IpEndpoint::default(),
@@ -238,6 +243,7 @@ impl<'a> TcpSocket<'a> {
             remote_last_win: 0,
             remote_win_len:  0,
             remote_mss:      DEFAULT_MSS,
+            remote_last_ts:  None,
         })
     }
 
@@ -255,6 +261,29 @@ impl<'a> TcpSocket<'a> {
         self.debug_id = id
     }
 
+    /// Return the timeout duration.
+    ///
+    /// See also the [set_timeout](#method.set_timeout) method.
+    pub fn timeout(&self) -> Option<u64> {
+        self.timeout
+    }
+
+    /// Set the timeout duration.
+    ///
+    /// A socket with a timeout duration set will abort the connection if either of the following
+    /// occurs:
+    ///
+    ///   * After a [connect](#method.connect) call, the remote endpoint does not respond within
+    ///     the specified duration;
+    ///   * After establishing a connection, there is data in the transmit buffer and the remote
+    ///     endpoint exceeds the specified duration between any two packets it sends;
+    ///   * After enabling [keep-alive](#method.set_keep_alive), the remote endpoint exceeds
+    ///     the specified duration between any two packets it sends.
+    pub fn set_timeout(&mut self, duration: Option<u64>) {
+        self.timeout = duration;
+        self.remote_last_ts = None;
+    }
+
     /// Return the keep-alive interval.
     ///
     /// See also the [set_keep_alive](#method.set_keep_alive) method.
@@ -264,7 +293,7 @@ impl<'a> TcpSocket<'a> {
 
     /// Set the keep-alive interval.
     ///
-    /// An idle socket with a set keep-alive interval will transmit a "challenge ACK" packet
+    /// An idle socket with a keep-alive interval set will transmit a "challenge ACK" packet
     /// every time it receives no communication during that interval. As a result, three things
     /// may happen:
     ///
@@ -305,6 +334,7 @@ impl<'a> TcpSocket<'a> {
         self.state           = State::Closed;
         self.timer           = Timer::default();
         self.keep_alive      = None;
+        self.timeout         = None;
         self.listen_address  = IpAddress::default();
         self.local_endpoint  = IpEndpoint::default();
         self.remote_endpoint = IpEndpoint::default();
@@ -315,6 +345,7 @@ impl<'a> TcpSocket<'a> {
         self.remote_last_win = 0;
         self.remote_win_len  = 0;
         self.remote_mss      = DEFAULT_MSS;
+        self.remote_last_ts  = None;
         self.tx_buffer.clear();
         self.rx_buffer.clear();
     }
@@ -541,6 +572,11 @@ impl<'a> TcpSocket<'a> {
     pub fn send(&mut self, size: usize) -> Result<&mut [u8]> {
         if !self.may_send() { return Err(Error::Illegal) }
 
+        // The connection might have been idle for a long time, and so remote_last_ts
+        // would be far in the past. Unless we clear it here, we'll abort the connection
+        // down over in dispatch() by erroneously detecting it as timed out.
+        if self.tx_buffer.is_empty() { self.remote_last_ts = None }
+
         let _old_length = self.tx_buffer.len();
         let buffer = self.tx_buffer.enqueue_many(size);
         if buffer.len() > 0 {
@@ -561,6 +597,9 @@ impl<'a> TcpSocket<'a> {
     pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
         if !self.may_send() { return Err(Error::Illegal) }
 
+        // See above.
+        if self.tx_buffer.is_empty() { self.remote_last_ts = None }
+
         let _old_length = self.tx_buffer.len();
         let enqueued = self.tx_buffer.enqueue_slice(data);
         if enqueued != 0 {
@@ -1051,7 +1090,8 @@ impl<'a> TcpSocket<'a> {
             }
         }
 
-        // Update window length.
+        // Update remote state.
+        self.remote_last_ts = Some(timestamp);
         self.remote_win_len = repr.window_len as usize;
 
         if ack_len > 0 {
@@ -1080,6 +1120,15 @@ impl<'a> TcpSocket<'a> {
         Ok(None)
     }
 
+    fn timed_out(&self, timestamp: u64) -> bool {
+        match (self.remote_last_ts, self.timeout) {
+            (Some(remote_last_ts), Some(timeout)) =>
+                timestamp >= remote_last_ts + timeout,
+            (_, _) =>
+                false
+        }
+    }
+
     fn seq_to_transmit(&self, control: TcpControl) -> bool {
         self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len()
     }
@@ -1097,7 +1146,23 @@ impl<'a> TcpSocket<'a> {
             where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> {
         if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) }
 
-        if !self.seq_to_transmit(TcpControl::None) {
+        if self.remote_last_ts.is_none() {
+            // We get here in exactly two cases:
+            //  1) This socket just transitioned into SYN-SENT.
+            //  2) This socket had an empty transmit buffer and some data was added there.
+            // Both are similar in that the socket has been quiet for an indefinite
+            // period of time, it isn't anymore, and the local endpoint is talking.
+            // So, we start counting the timeout not from the last received packet
+            // but from the first transmitted one.
+            self.remote_last_ts = Some(timestamp);
+        }
+
+        if self.timed_out(timestamp) {
+            // If a timeout expires, we should abort the connection.
+            net_debug!("[{}]{}:{}: timeout exceeded",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint);
+            self.set_state(State::Closed);
+        } else if !self.seq_to_transmit(TcpControl::None) {
             if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) {
                 // If a retransmit timer expired, we should resend data starting at the last ACK.
                 net_debug!("[{}]{}:{}: retransmitting at t+{}ms",
@@ -1215,7 +1280,7 @@ impl<'a> TcpSocket<'a> {
         }
 
         let is_keep_alive;
-        if self.timer.should_keep_alive(timestamp) {
+        if self.timer.should_keep_alive(timestamp) && repr.is_empty() {
             net_trace!("[{}]{}:{}: sending a keep-alive",
                        self.debug_id, self.local_endpoint, self.remote_endpoint);
             repr.seq_number = repr.seq_number - 1;
@@ -1286,13 +1351,18 @@ impl<'a> TcpSocket<'a> {
     }
 
     pub(crate) fn poll_at(&self) -> Option<u64> {
-        self.timer.poll_at().or_else(|| {
-            if self.tx_buffer.is_empty() {
-                None
-            } else {
-                Some(0)
-            }
-        })
+        self.timer.poll_at()
+            .or_else(|| {
+                match (self.remote_last_ts, self.timeout) {
+                    (Some(remote_last_ts), Some(timeout))
+                            if !self.tx_buffer.is_empty() =>
+                        Some(remote_last_ts + timeout),
+                    (None, Some(timeout)) =>
+                        Some(0),
+                    (_, _) =>
+                        None
+                }
+            })
     }
 }
 
@@ -3037,6 +3107,103 @@ mod test {
         }]);
     }
 
+    // =========================================================================================//
+    // Tests for timeouts
+    // =========================================================================================//
+
+    #[test]
+    fn test_connect_timeout() {
+        let mut s = socket();
+        s.local_seq_no = LOCAL_SEQ;
+        s.connect(REMOTE_END, LOCAL_END.port).unwrap();
+        s.set_timeout(Some(100));
+        recv!(s, time 150, Ok(TcpRepr {
+            control:    TcpControl::Syn,
+            seq_number: LOCAL_SEQ,
+            ack_number: None,
+            max_seg_size: Some(1480),
+            ..RECV_TEMPL
+        }));
+        assert_eq!(s.state, State::SynSent);
+        assert_eq!(s.poll_at(), Some(250));
+        recv!(s, time 250, Ok(TcpRepr {
+            control:    TcpControl::Rst,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(TcpSeqNumber(0)),
+            ..RECV_TEMPL
+        }));
+        assert_eq!(s.state, State::Closed);
+    }
+
+    #[test]
+    fn test_established_timeout() {
+        let mut s = socket_established();
+        s.set_timeout(Some(200));
+        recv!(s, time 250, Err(Error::Exhausted));
+        assert_eq!(s.poll_at(), None);
+        s.send_slice(b"abcdef").unwrap();
+        assert_eq!(s.poll_at(), Some(0));
+        recv!(s, time 255, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..RECV_TEMPL
+        }));
+        assert_eq!(s.poll_at(), Some(355));
+        recv!(s, time 355, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..RECV_TEMPL
+        }));
+        assert_eq!(s.poll_at(), Some(455));
+        recv!(s, time 500, Ok(TcpRepr {
+            control:    TcpControl::Rst,
+            seq_number: LOCAL_SEQ + 1 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        }));
+        assert_eq!(s.state, State::Closed);
+    }
+
+    #[test]
+    fn test_established_keep_alive_timeout() {
+        let mut s = socket_established();
+        s.set_keep_alive(Some(50));
+        s.set_timeout(Some(100));
+        recv!(s, time 100, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &[0],
+            ..RECV_TEMPL
+        }));
+        recv!(s, time 100, Err(Error::Exhausted));
+        assert_eq!(s.poll_at(), Some(150));
+        send!(s, time 105, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.poll_at(), Some(155));
+        recv!(s, time 155, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &[0],
+            ..RECV_TEMPL
+        }));
+        recv!(s, time 155, Err(Error::Exhausted));
+        assert_eq!(s.poll_at(), Some(205));
+        recv!(s, time 200, Err(Error::Exhausted));
+        recv!(s, time 205, Ok(TcpRepr {
+            control:    TcpControl::Rst,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        }));
+        recv!(s, time 205, Err(Error::Exhausted));
+        assert_eq!(s.state, State::Closed);
+    }
+
     // =========================================================================================//
     // Tests for keep-alive
     // =========================================================================================//

+ 9 - 0
src/wire/tcp.rs

@@ -740,6 +740,15 @@ impl<'a> Repr<'a> {
     pub fn segment_len(&self) -> usize {
         self.payload.len() + self.control.len()
     }
+
+    /// Return whether the segment has no flags set (except PSH) and no data.
+    pub fn is_empty(&self) -> bool {
+        match self.control {
+            _ if self.payload.len() != 0 => false,
+            Control::Syn  | Control::Fin | Control::Rst => false,
+            Control::None | Control::Psh => true
+        }
+    }
 }
 
 impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {