瀏覽代碼

Implement TCP keep-alive.

whitequark 7 年之前
父節點
當前提交
a04b32441b
共有 2 個文件被更改,包括 184 次插入34 次删除
  1. 1 1
      README.md
  2. 183 33
      src/socket/tcp.rs

+ 1 - 1
README.md

@@ -59,6 +59,7 @@ The TCP protocol is supported over IPv4. Server and client sockets are supported
   * Multiple packets will be transmitted without waiting for an acknowledgement.
   * Lost packets will be retransmitted with exponential backoff, starting at
     a fixed delay of 100 ms.
+  * Sending keep-alive packets is supported, with a configurable interval.
   * After arriving at the TIME-WAIT state, sockets will close after a fixed delay of 10 s.
   * TCP urgent pointer is **not** supported; any urgent octets will be received alongside
     data octets.
@@ -71,7 +72,6 @@ The TCP protocol is supported over IPv4. Server and client sockets are supported
   * Timestamping (used in round-trip time measurement and protection against wrapped sequences)
     is **not** supported.
   * Fast open is **not** supported when smoltcp initiates connection.
-  * Keepalive is **not** supported.
 
 ## Installation
 

+ 183 - 33
src/socket/tcp.rs

@@ -48,7 +48,9 @@ impl fmt::Display for State {
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 enum Timer {
-    Idle,
+    Idle {
+        keep_alive_at: Option<u64>,
+    },
     Retransmit {
         expires_at: u64,
         delay:      u64
@@ -61,7 +63,23 @@ enum Timer {
 const RETRANSMIT_DELAY: u64 = 100;
 const CLOSE_DELAY:      u64 = 10_000;
 
+impl Default for Timer {
+    fn default() -> Timer {
+        Timer::Idle { keep_alive_at: None }
+    }
+}
+
 impl Timer {
+    fn should_keep_alive(&self, timestamp: u64) -> bool {
+        match *self {
+            Timer::Idle { keep_alive_at: Some(keep_alive_at) }
+                    if timestamp >= keep_alive_at => {
+                true
+            }
+            _ => false
+        }
+    }
+
     fn should_retransmit(&self, timestamp: u64) -> Option<u64> {
         match *self {
             Timer::Retransmit { expires_at, delay }
@@ -84,19 +102,40 @@ impl Timer {
 
     fn poll_at(&self) -> Option<u64> {
         match *self {
-            Timer::Idle => None,
+            Timer::Idle { keep_alive_at } => keep_alive_at,
             Timer::Retransmit { expires_at, .. } => Some(expires_at),
-            Timer::Close { expires_at, .. } => Some(expires_at),
+            Timer::Close { expires_at } => Some(expires_at),
         }
     }
 
-    fn reset(&mut self) {
-        *self = Timer::Idle
+    fn set_for_idle(&mut self, timestamp: u64, interval: Option<u64>) {
+        *self = Timer::Idle {
+            keep_alive_at: interval.map(|interval| timestamp + interval)
+        }
+    }
+
+    fn set_keep_alive(&mut self) {
+        match *self {
+            Timer::Idle { ref mut keep_alive_at }
+                    if keep_alive_at.is_none() => {
+                *keep_alive_at = Some(0)
+            }
+            _ => ()
+        }
+    }
+
+    fn rewind_keep_alive(&mut self, timestamp: u64, interval: Option<u64>) {
+        match self {
+            &mut Timer::Idle { ref mut keep_alive_at } => {
+                *keep_alive_at = interval.map(|interval| timestamp + interval)
+            }
+            _ => ()
+        }
     }
 
     fn set_for_retransmit(&mut self, timestamp: u64) {
         match *self {
-            Timer::Idle => {
+            Timer::Idle { .. } => {
                 *self = Timer::Retransmit {
                     expires_at: timestamp + RETRANSMIT_DELAY,
                     delay:      RETRANSMIT_DELAY,
@@ -134,6 +173,8 @@ pub struct TcpSocket<'a> {
     timer:           Timer,
     rx_buffer:       SocketBuffer<'a>,
     tx_buffer:       SocketBuffer<'a>,
+    /// Interval at which keep-alive packets will be sent.
+    keep_alive:      Option<u64>,
     /// Address passed to listen(). Listen address is set when listen() is called and
     /// used every time the socket is reset back to the LISTEN state.
     listen_address:  IpAddress,
@@ -183,9 +224,10 @@ impl<'a> TcpSocket<'a> {
         Socket::Tcp(TcpSocket {
             debug_id:        0,
             state:           State::Closed,
-            timer:           Timer::Idle,
+            timer:           Timer::default(),
             tx_buffer:       tx_buffer.into(),
             rx_buffer:       rx_buffer.into(),
+            keep_alive:      None,
             listen_address:  IpAddress::default(),
             local_endpoint:  IpEndpoint::default(),
             remote_endpoint: IpEndpoint::default(),
@@ -213,6 +255,34 @@ impl<'a> TcpSocket<'a> {
         self.debug_id = id
     }
 
+    /// Return the keep-alive interval.
+    ///
+    /// See also the [set_keep_alive](#method.set_keep_alive) method.
+    pub fn keep_alive(&self) -> Option<u64> {
+        self.keep_alive
+    }
+
+    /// Set the keep-alive interval.
+    ///
+    /// An idle socket with a set keep-alive interval will transmit a "challenge ACK" packet
+    /// every time it receives no communication during that interval. As a result, three things
+    /// may happen:
+    ///
+    ///   * The remote endpoint is fine and answers with an ACK packet.
+    ///   * The remote endpoint has rebooted and answers with an RST packet.
+    ///   * The remote endpoint has crashed and does not answer.
+    ///
+    /// The keep-alive functionality together with the timeout functionality allows to react
+    /// to these error conditions.
+    pub fn set_keep_alive(&mut self, interval: Option<u64>) {
+        self.keep_alive = interval;
+        if self.keep_alive.is_some() {
+            // If the connection is idle and we've just set the option, it would not take effect
+            // until the next packet, unless we wind up the timer explicitly.
+            self.timer.set_keep_alive();
+        }
+    }
+
     /// Return the local endpoint.
     #[inline]
     pub fn local_endpoint(&self) -> IpEndpoint {
@@ -233,6 +303,8 @@ impl<'a> TcpSocket<'a> {
 
     fn reset(&mut self) {
         self.state           = State::Closed;
+        self.timer           = Timer::default();
+        self.keep_alive      = None;
         self.listen_address  = IpAddress::default();
         self.local_endpoint  = IpEndpoint::default();
         self.remote_endpoint = IpEndpoint::default();
@@ -243,7 +315,6 @@ impl<'a> TcpSocket<'a> {
         self.remote_last_win = 0;
         self.remote_win_len  = 0;
         self.remote_mss      = DEFAULT_MSS;
-        self.timer.reset();
         self.tx_buffer.clear();
         self.rx_buffer.clear();
     }
@@ -328,14 +399,10 @@ impl<'a> TcpSocket<'a> {
                 self.set_state(State::Closed),
             // In the SYN-RECEIVED, ESTABLISHED and CLOSE-WAIT states the transmit half
             // of the connection is open, and needs to be explicitly closed with a FIN.
-            State::SynReceived | State::Established => {
-                self.timer.reset();
-                self.set_state(State::FinWait1);
-            }
-            State::CloseWait => {
-                self.timer.reset();
-                self.set_state(State::LastAck);
-            }
+            State::SynReceived | State::Established =>
+                self.set_state(State::FinWait1),
+            State::CloseWait =>
+                self.set_state(State::LastAck),
             // In the FIN-WAIT-1, FIN-WAIT-2, CLOSING, LAST-ACK, TIME-WAIT and CLOSED states,
             // the transmit half of the connection is already closed, and no further
             // action is needed.
@@ -481,7 +548,6 @@ impl<'a> TcpSocket<'a> {
             net_trace!("[{}]{}:{}: tx buffer: enqueueing {} octets (now {})",
                        self.debug_id, self.local_endpoint, self.remote_endpoint,
                        buffer.len(), _old_length + buffer.len());
-            self.timer.reset();
         }
         Ok(buffer)
     }
@@ -495,14 +561,13 @@ impl<'a> TcpSocket<'a> {
     pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
         if !self.may_send() { return Err(Error::Illegal) }
 
-        let old_length = self.tx_buffer.len();
+        let _old_length = self.tx_buffer.len();
         let enqueued = self.tx_buffer.enqueue_slice(data);
         if enqueued != 0 {
             #[cfg(any(test, feature = "verbose"))]
             net_trace!("[{}]{}:{}: tx buffer: enqueueing {} octets (now {})",
                        self.debug_id, self.local_endpoint, self.remote_endpoint,
-                       enqueued, old_length + enqueued);
-            self.timer.reset();
+                       enqueued, _old_length + enqueued);
         }
         Ok(enqueued)
     }
@@ -884,13 +949,13 @@ impl<'a> TcpSocket<'a> {
                     self.remote_mss = max_seg_size as usize
                 }
                 self.set_state(State::SynReceived);
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // ACK packets in the SYN-RECEIVED state change it to ESTABLISHED.
             (State::SynReceived, TcpControl::None) => {
                 self.set_state(State::Established);
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // FIN packets in the SYN-RECEIVED state change it to CLOSE-WAIT.
@@ -899,7 +964,7 @@ impl<'a> TcpSocket<'a> {
             (State::SynReceived, TcpControl::Fin) => {
                 self.remote_seq_no  += 1;
                 self.set_state(State::CloseWait);
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED.
@@ -913,19 +978,19 @@ impl<'a> TcpSocket<'a> {
                     self.remote_mss = max_seg_size as usize;
                 }
                 self.set_state(State::Established);
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // ACK packets in ESTABLISHED state reset the retransmit timer.
             (State::Established, TcpControl::None) => {
-                self.timer.reset()
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             },
 
             // FIN packets in ESTABLISHED state indicate the remote side has closed.
             (State::Established, TcpControl::Fin) => {
                 self.remote_seq_no  += 1;
                 self.set_state(State::CloseWait);
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2, if we've already
@@ -934,7 +999,7 @@ impl<'a> TcpSocket<'a> {
                 if ack_of_fin {
                     self.set_state(State::FinWait2);
                 }
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // FIN packets in FIN-WAIT-1 state change it to CLOSING, or to TIME-WAIT
@@ -946,7 +1011,7 @@ impl<'a> TcpSocket<'a> {
                     self.timer.set_for_close(timestamp);
                 } else {
                     self.set_state(State::Closing);
-                    self.timer.reset();
+                    self.timer.set_for_idle(timestamp, self.keep_alive);
                 }
             }
 
@@ -963,13 +1028,13 @@ impl<'a> TcpSocket<'a> {
                     self.set_state(State::TimeWait);
                     self.timer.set_for_close(timestamp);
                 } else {
-                    self.timer.reset();
+                    self.timer.set_for_idle(timestamp, self.keep_alive);
                 }
             }
 
             // ACK packets in CLOSE-WAIT state reset the retransmit timer.
             (State::CloseWait, TcpControl::None) => {
-                self.timer.reset();
+                self.timer.set_for_idle(timestamp, self.keep_alive);
             }
 
             // ACK packets in LAST-ACK state change it to CLOSED.
@@ -1121,6 +1186,8 @@ impl<'a> TcpSocket<'a> {
             // If we have window length increase to advertise, do it.
         } else if self.timer.should_retransmit(timestamp).is_some() {
             // If we have packets to retransmit, do it.
+        } else if self.timer.should_keep_alive(timestamp) {
+            // If we need to transmit a keep-alive packet, do it.
         } else if repr.control == TcpControl::Rst {
             // If we need to abort the connection, do it.
         } else {
@@ -1147,6 +1214,17 @@ impl<'a> TcpSocket<'a> {
                        flags);
         }
 
+        let is_keep_alive;
+        if self.timer.should_keep_alive(timestamp) {
+            net_trace!("[{}]{}:{}: sending a keep-alive",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint);
+            repr.seq_number = repr.seq_number - 1;
+            repr.payload    = b"\x00"; // RFC 1122 says we should do this
+            is_keep_alive = true;
+        } else {
+            is_keep_alive = false;
+        }
+
         // Remember the header length before enabling the MSS option, since that option
         // only affects SYN packets.
         let header_len = repr.header_len();
@@ -1179,13 +1257,20 @@ impl<'a> TcpSocket<'a> {
 
         emit((ip_repr, repr))?;
 
+        // We've sent something, whether useful data or a keep-alive packet, so rewind
+        // the keep-alive timer.
+        self.timer.rewind_keep_alive(timestamp, self.keep_alive);
+
+        // Leave the rest of the state intact if sending a keep-alive packet.
+        if is_keep_alive { return Ok(()) }
+
         // We've sent a packet successfully, so we can update the internal state now.
         self.remote_last_seq = repr.seq_number + repr.segment_len();
         self.remote_last_ack = repr.ack_number;
         self.remote_last_win = repr.window_len;
 
         if !self.seq_to_transmit(repr.control) && repr.segment_len() > 0 {
-            // If we've transmitted all data could (and there was something at all,
+            // If we've transmitted all data we could (and there was something at all,
             // data or flag, to transmit, not just an ACK), wind up the retransmit timer.
             self.timer.set_for_retransmit(timestamp);
         }
@@ -1229,7 +1314,7 @@ mod test {
 
     #[test]
     fn test_timer_retransmit() {
-        let mut r = Timer::Idle;
+        let mut r = Timer::default();
         assert_eq!(r.should_retransmit(1000), None);
         r.set_for_retransmit(1000);
         assert_eq!(r.should_retransmit(1000), None);
@@ -1240,7 +1325,7 @@ mod test {
         assert_eq!(r.should_retransmit(1150), None);
         assert_eq!(r.should_retransmit(1200), None);
         assert_eq!(r.should_retransmit(1301), Some(300));
-        r.reset();
+        r.set_for_idle(1301, None);
         assert_eq!(r.should_retransmit(1350), None);
     }
 
@@ -2952,6 +3037,71 @@ mod test {
         }]);
     }
 
+    // =========================================================================================//
+    // Tests for keep-alive
+    // =========================================================================================//
+
+    #[test]
+    fn test_responds_to_keep_alive() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        }, Ok(Some(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        })));
+    }
+
+    #[test]
+    fn test_sends_keep_alive() {
+        let mut s = socket_established();
+        s.set_keep_alive(Some(100));
+
+        // drain the forced keep-alive packet
+        assert_eq!(s.poll_at(), Some(0));
+        recv!(s, time 0, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &[0],
+            ..RECV_TEMPL
+        }));
+
+        assert_eq!(s.poll_at(), Some(100));
+        recv!(s, time 95, Err(Error::Exhausted));
+        recv!(s, time 100, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &[0],
+            ..RECV_TEMPL
+        }));
+
+        assert_eq!(s.poll_at(), Some(200));
+        recv!(s, time 195, Err(Error::Exhausted));
+        recv!(s, time 200, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &[0],
+            ..RECV_TEMPL
+        }));
+
+        send!(s, time 250, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.poll_at(), Some(350));
+        recv!(s, time 345, Err(Error::Exhausted));
+        recv!(s, time 350, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"\x00"[..],
+            ..RECV_TEMPL
+        }));
+    }
+
     // =========================================================================================//
     // Tests for packet filtering
     // =========================================================================================//