Browse Source

Implement the TCP close operation.

whitequark 8 years ago
parent
commit
1a618aac45
2 changed files with 158 additions and 1 deletions
  1. 2 0
      examples/smoltcpserver.rs
  2. 156 1
      src/socket/tcp.rs

+ 2 - 0
examples/smoltcpserver.rs

@@ -121,6 +121,8 @@ fn main() {
                            str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
                     socket.send_slice(&data[..]).unwrap();
                 }
+            } else if socket.can_send() {
+                socket.close()
             }
         }
 

+ 156 - 1
src/socket/tcp.rs

@@ -257,6 +257,34 @@ impl<'a> TcpSocket<'a> {
         Ok(())
     }
 
+    /// Close the transmit half of the full-duplex connection.
+    ///
+    /// Note that there is no corresponding function for the receive half of the full-duplex
+    /// connection; only the remote end can close it. If you no longer wish to receive any
+    /// data and would like to reuse the socket right away, use [abort](#method.abort).
+    pub fn close(&mut self) {
+        match self.state {
+            // In the LISTEN state there is no established connection.
+            State::Listen =>
+                self.set_state(State::Closed),
+            // In the SYN_SENT state the remote endpoint is not yet synchronized and, upon
+            // receiving an RST, will abort the connection.
+            State::SynSent =>
+                self.set_state(State::Closed),
+            // In the SYN_RECEIVED, ESTABLISHED and CLOSE_WAIT states the transmit half
+            // of the connection is open, and needs to be explicitly closed with a FIN.
+            State::SynReceived | State::Established =>
+                self.set_state(State::FinWait1),
+            State::CloseWait =>
+                self.set_state(State::LastAck),
+            // In the FIN_WAIT_1, FIN_WAIT_2, CLOSING, LAST_ACK, TIME_WAIT and CLOSED states,
+            // the transmit half of the connection is already closed, and no further
+            // action is needed.
+            State::FinWait1 | State::FinWait2 | State::Closing |
+            State::TimeWait | State::LastAck | State::Closed => ()
+        }
+    }
+
     /// Return whether the socket is open.
     ///
     /// This function returns true if the socket will process incoming or dispatch outgoing
@@ -854,7 +882,7 @@ mod test {
     #[test]
     fn test_closed() {
         let mut s = socket();
-        assert_eq!(s.state(), State::Closed);
+        assert_eq!(s.state, State::Closed);
 
         send!(s, TcpRepr {
             control: TcpControl::Syn,
@@ -862,6 +890,13 @@ mod test {
         }, Err(Error::Rejected));
     }
 
+    #[test]
+    fn test_closed_close() {
+        let mut s = socket();
+        s.close();
+        assert_eq!(s.state, State::Closed);
+    }
+
     // =========================================================================================//
     // Tests for the LISTEN state.
     // =========================================================================================//
@@ -895,6 +930,13 @@ mod test {
         }]);
     }
 
+    #[test]
+    fn test_listen_close() {
+        let mut s = socket_listen();
+        s.close();
+        assert_eq!(s.state, State::Closed);
+    }
+
     // =========================================================================================//
     // Tests for the SYN_RECEIVED state.
     // =========================================================================================//
@@ -922,6 +964,13 @@ mod test {
         assert_eq!(s.remote_endpoint, IpEndpoint::default());
     }
 
+    #[test]
+    fn test_syn_received_close() {
+        let mut s = socket_syn_received();
+        s.close();
+        assert_eq!(s.state, State::FinWait1);
+    }
+
     // =========================================================================================//
     // Tests for the SYN_SENT state.
     // =========================================================================================//
@@ -970,6 +1019,13 @@ mod test {
         assert_eq!(s.state, State::SynSent);
     }
 
+    #[test]
+    fn test_syn_sent_close() {
+        let mut s = socket();
+        s.close();
+        assert_eq!(s.state, State::Closed);
+    }
+
     // =========================================================================================//
     // Tests for the ESTABLISHED state.
     // =========================================================================================//
@@ -1170,6 +1226,64 @@ mod test {
         assert_eq!(s.state, State::Closed);
     }
 
+    #[test]
+    fn test_established_close() {
+        let mut s = socket_established();
+        s.close();
+        assert_eq!(s.state, State::FinWait1);
+    }
+
+    // =========================================================================================//
+    // Tests for the FIN_WAIT_1 state.
+    // =========================================================================================//
+    fn socket_fin_wait_1() -> TcpSocket<'static> {
+        let mut s = socket_established();
+        s.state           = State::FinWait1;
+        s
+    }
+
+    #[test]
+    fn test_fin_wait_1_close() {
+        let mut s = socket_fin_wait_1();
+        s.close();
+        assert_eq!(s.state, State::FinWait1);
+    }
+
+    // =========================================================================================//
+    // Tests for the FIN_WAIT_2 state.
+    // =========================================================================================//
+    fn socket_fin_wait_2() -> TcpSocket<'static> {
+        let mut s = socket_fin_wait_1();
+        s.state           = State::FinWait2;
+        s.local_seq_no    = LOCAL_SEQ + 1 + 1;
+        s
+    }
+
+    #[test]
+    fn test_fin_wait_2_close() {
+        let mut s = socket_fin_wait_2();
+        s.close();
+        assert_eq!(s.state, State::FinWait2);
+    }
+
+    // =========================================================================================//
+    // Tests for the CLOSING state.
+    // =========================================================================================//
+    fn socket_closing() -> TcpSocket<'static> {
+        let mut s = socket_fin_wait_1();
+        s.state           = State::Closing;
+        s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
+        s.remote_last_ack = REMOTE_SEQ + 1 + 1;
+        s
+    }
+
+    #[test]
+    fn test_closing_close() {
+        let mut s = socket_closing();
+        s.close();
+        assert_eq!(s.state, State::Closing);
+    }
+
     // =========================================================================================//
     // Tests for the CLOSE_WAIT state.
     // =========================================================================================//
@@ -1198,6 +1312,47 @@ mod test {
         }]);
     }
 
+    #[test]
+    fn test_close_wait_close() {
+        let mut s = socket_close_wait();
+        s.close();
+        assert_eq!(s.state, State::LastAck);
+    }
+
+    // =========================================================================================//
+    // Tests for the LAST_ACK state.
+    // =========================================================================================//
+    fn socket_last_ack() -> TcpSocket<'static> {
+        let mut s = socket_close_wait();
+        s.state           = State::LastAck;
+        s
+    }
+
+    #[test]
+    fn test_last_ack_close() {
+        let mut s = socket_last_ack();
+        s.close();
+        assert_eq!(s.state, State::LastAck);
+    }
+
+    // =========================================================================================//
+    // Tests for the TIME_WAIT state.
+    // =========================================================================================//
+    fn socket_time_wait() -> TcpSocket<'static> {
+        let mut s = socket_fin_wait_2();
+        s.state           = State::TimeWait;
+        s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
+        s.remote_last_ack = REMOTE_SEQ + 1 + 1;
+        s
+    }
+
+    #[test]
+    fn test_time_wait_close() {
+        let mut s = socket_time_wait();
+        s.close();
+        assert_eq!(s.state, State::TimeWait);
+    }
+
     // =========================================================================================//
     // Tests for transitioning through multiple states.
     // =========================================================================================//