Browse Source

Implement the TCP TIME-WAIT state.

whitequark 8 years ago
parent
commit
553d640057
4 changed files with 150 additions and 34 deletions
  1. 6 3
      README.md
  2. 37 14
      examples/server.rs
  3. 99 17
      src/socket/tcp.rs
  4. 8 0
      src/wire/ip.rs

+ 6 - 3
README.md

@@ -52,11 +52,12 @@ The TCP protocol is supported over IPv4.
 
   * TCP header checksum is supported.
   * Multiple packets will be transmitted without waiting for an acknowledgement.
-  * TCP urgent pointer is **not** supported; any urgent octets will be received alongside data.
+  * TCP urgent pointer is **not** supported; any urgent octets will be received alongside
+    data octets.
   * Reassembly of out-of-order segments is **not** supported.
   * TCP options are **not** supported, in particular:
     * Maximum segment size is hardcoded at the default value, 536.
-    * Window scaling is **not** supported.
+    * Window scaling is **not** supported, and the maximum buffer size is 65536.
   * Keepalive is **not** supported.
 
 Installation
@@ -143,7 +144,9 @@ It responds to:
   * pings (`ping 192.168.69.1`);
   * UDP packets on port 6969 (`socat stdio udp4-connect:192.168.69.1:6969 <<<"abcdefg"`),
     where it will respond "yo dawg" to any incoming packet;
-  * TCP packets on port 6969 (`socat stdio tcp4-connect:192.168.69.1:6969 <<<"abcdefg"`),
+  * TCP packets on port 6969 (`socat stdio tcp4-connect:192.168.69.1:6969`),
+    where it will respond "yo dawg" to any incoming connection and immediately close it;
+  * TCP packets on port 6970 (`socat stdio tcp4-connect:192.168.69.1:6970 <<<"abcdefg"`),
     where it will respond with reversed chunks of the input indefinitely.
 
 The buffers are only 64 bytes long, for convenience of testing resource exhaustion conditions.

+ 37 - 14
examples/server.rs

@@ -53,23 +53,28 @@ fn main() {
     let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 128])]);
     let udp_socket = UdpSocket::new(endpoint, udp_rx_buffer, udp_tx_buffer);
 
-    let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 64]);
-    let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 128]);
-    let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
+    let tcp1_rx_buffer = TcpSocketBuffer::new(vec![0; 64]);
+    let tcp1_tx_buffer = TcpSocketBuffer::new(vec![0; 128]);
+    let tcp1_socket = TcpSocket::new(tcp1_rx_buffer, tcp1_tx_buffer);
+
+    let tcp2_rx_buffer = TcpSocketBuffer::new(vec![0; 64]);
+    let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 128]);
+    let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer);
 
     let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
     let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)];
-    let sockets = vec![udp_socket, tcp_socket];
+    let sockets = vec![udp_socket, tcp1_socket, tcp2_socket];
     let mut iface = EthernetInterface::new(device, arp_cache,
         hardware_addr, protocol_addrs, sockets);
 
     let mut tcp_6969_connected = false;
     loop {
+        // udp:6969: respond "yo dawg"
         {
             let socket: &mut UdpSocket = iface.sockets()[0].as_socket();
             let client = match socket.recv() {
                 Ok((endpoint, data)) => {
-                    debug!("udp recv data: {:?} from {}",
+                    debug!("udp:6969 recv data: {:?} from {}",
                            str::from_utf8(data.as_ref()).unwrap(), endpoint);
                     Some(endpoint)
                 }
@@ -77,28 +82,46 @@ fn main() {
                     None
                 }
                 Err(e) => {
-                    debug!("udp recv error: {}", e);
+                    debug!("udp:6969 recv error: {}", e);
                     None
                 }
             };
             if let Some(endpoint) = client {
-                let data = b"yo dawg";
-                debug!("udp send data: {:?}",
+                let data = b"yo dawg\n";
+                debug!("udp:6969 send data: {:?}",
                        str::from_utf8(data.as_ref()).unwrap());
                 socket.send_slice(endpoint, data).unwrap()
             }
         }
 
+        // tcp:6969: respond "yo dawg"
         {
             let socket: &mut TcpSocket = iface.sockets()[1].as_socket();
             if !socket.is_open() {
-                socket.listen(endpoint).unwrap()
+                socket.listen(6969).unwrap();
+            }
+
+            if socket.can_send() {
+                let data = b"yo dawg\n";
+                debug!("tcp:6969 send data: {:?}",
+                       str::from_utf8(data.as_ref()).unwrap());
+                socket.send_slice(data).unwrap();
+                debug!("tcp:6969 close");
+                socket.close();
+            }
+        }
+
+        // tcp:6970: echo with reverse
+        {
+            let socket: &mut TcpSocket = iface.sockets()[2].as_socket();
+            if !socket.is_open() {
+                socket.listen(6970).unwrap()
             }
 
             if socket.is_connected() && !tcp_6969_connected {
-                debug!("tcp connected");
+                debug!("tcp:6970 connected");
             } else if !socket.is_connected() && tcp_6969_connected {
-                debug!("tcp disconnected");
+                debug!("tcp:6970 disconnected");
             }
             tcp_6969_connected = socket.is_connected();
 
@@ -106,7 +129,7 @@ fn main() {
                 let data = {
                     let mut data = socket.recv(128).unwrap().to_owned();
                     if data.len() > 0 {
-                        debug!("tcp recv data: {:?}",
+                        debug!("tcp:6970 recv data: {:?}",
                                str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
                         data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
                         data.reverse();
@@ -115,13 +138,13 @@ fn main() {
                     data
                 };
                 if socket.can_send() && data.len() > 0 {
-                    debug!("tcp send data: {:?}",
+                    debug!("tcp:6970 send data: {:?}",
                            str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
                     socket.send_slice(&data[..]).unwrap();
                 }
             } else if socket.can_send() {
+                debug!("tcp:6970 close");
                 socket.close();
-                debug!("tcp closed")
             }
         }
 

+ 99 - 17
src/socket/tcp.rs

@@ -249,9 +249,10 @@ impl<'a> TcpSocket<'a> {
     /// Start listening on the given endpoint.
     ///
     /// This function returns an error if the socket was open; see [is_open](#method.is_open).
-    pub fn listen(&mut self, endpoint: IpEndpoint) -> Result<(), ()> {
+    pub fn listen<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<(), ()> {
         if self.is_open() { return Err(()) }
 
+        let endpoint = endpoint.into();
         self.listen_address  = endpoint.addr;
         self.local_endpoint  = endpoint;
         self.remote_endpoint = IpEndpoint::default();
@@ -507,8 +508,8 @@ impl<'a> TcpSocket<'a> {
                 let control_len = match state {
                     // In SYN-SENT or SYN-RECEIVED, we've just sent a SYN.
                     State::SynSent | State::SynReceived => 1,
-                    // In FIN-WAIT-1 or LAST-ACK, we've just sent a FIN.
-                    State::FinWait1 | State::LastAck => 1,
+                    // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN.
+                    State::FinWait1 | State::LastAck | State::Closing => 1,
                     // In all other states we've already got acknowledgemetns for
                     // all of the control flags we sent.
                     _ => 0
@@ -625,9 +626,9 @@ impl<'a> TcpSocket<'a> {
 
             // ACK packets in CLOSING state change it to TIME-WAIT.
             (State::Closing, TcpRepr { control: TcpControl::None, .. }) => {
-                // Clear the remote endpoint, or we'll send an ACK there.
-                self.remote_endpoint = IpEndpoint::default();
+                self.local_seq_no   += 1;
                 self.set_state(State::TimeWait);
+                self.retransmit.reset();
             }
 
             // ACK packets in CLOSE-WAIT state do nothing.
@@ -696,8 +697,8 @@ impl<'a> TcpSocket<'a> {
 
         let mut should_send = false;
         match self.state {
-            // We never transmit anything in the CLOSED, LISTEN, TIME-WAIT or FIN-WAIT-2 states.
-            State::Closed | State::Listen | State::TimeWait | State::FinWait2 => {
+            // We never transmit anything in the CLOSED, LISTEN, or FIN-WAIT-2 states.
+            State::Closed | State::Listen | State::FinWait2 => {
                 return Err(Error::Exhausted)
             }
 
@@ -723,11 +724,11 @@ impl<'a> TcpSocket<'a> {
             }
 
             // We transmit data in the ESTABLISHED state,
-            // ACK in CLOSE-WAIT and CLOSING states,
+            // ACK in CLOSE-WAIT, CLOSING, and TIME-WAIT states,
             // FIN in FIN-WAIT-1 and LAST-ACK states.
             State::Established |
-            State::CloseWait   | State::LastAck |
-            State::FinWait1    | State::Closing => {
+            State::CloseWait   | State::Closing | State::TimeWait |
+            State::FinWait1    | State::LastAck => {
                 // See if we should send data to the remote end because:
                 let mut may_send = false;
                 //   1. the retransmit timer has expired or was reset, or...
@@ -1375,7 +1376,6 @@ mod test {
             ..SEND_TEMPL
         }]);
         assert_eq!(s.state, State::TimeWait);
-        assert!(!s.remote_endpoint.is_unspecified());
     }
 
     #[test]
@@ -1391,7 +1391,7 @@ mod test {
     fn socket_closing() -> TcpSocket<'static> {
         let mut s = socket_fin_wait_1();
         s.state           = State::Closing;
-        s.local_seq_no    = LOCAL_SEQ + 1 + 1;
+        s.local_seq_no    = LOCAL_SEQ + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
         s
     }
@@ -1400,7 +1400,7 @@ mod test {
     fn test_closing_ack_fin() {
         let mut s = socket_closing();
         recv!(s, [TcpRepr {
-            seq_number: LOCAL_SEQ + 1 + 1,
+            seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 1),
             ..RECV_TEMPL
         }]);
@@ -1410,7 +1410,6 @@ mod test {
             ..SEND_TEMPL
         }]);
         assert_eq!(s.state, State::TimeWait);
-        assert!(s.remote_endpoint.is_unspecified());
     }
 
     #[test]
@@ -1423,17 +1422,35 @@ mod test {
     // =========================================================================================//
     // Tests for the TIME-WAIT state.
     // =========================================================================================//
-    fn socket_time_wait() -> TcpSocket<'static> {
+    fn socket_time_wait(from_closing: bool) -> TcpSocket<'static> {
         let mut s = socket_fin_wait_2();
         s.state           = State::TimeWait;
         s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
-        s.remote_last_ack = REMOTE_SEQ + 1 + 1;
+        if from_closing {
+            s.remote_last_ack = REMOTE_SEQ + 1 + 1;
+        }
         s
     }
 
+    #[test]
+    fn test_time_wait_from_fin_wait_2_ack() {
+        let mut s = socket_time_wait(false);
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 1),
+            ..RECV_TEMPL
+        }]);
+    }
+
+    #[test]
+    fn test_time_wait_from_closing_no_ack() {
+        let mut s = socket_time_wait(true);
+        recv!(s, []);
+    }
+
     #[test]
     fn test_time_wait_close() {
-        let mut s = socket_time_wait();
+        let mut s = socket_time_wait(false);
         s.close();
         assert_eq!(s.state, State::TimeWait);
     }
@@ -1573,5 +1590,70 @@ mod test {
             ack_number: Some(LOCAL_SEQ + 1 + 1),
             ..SEND_TEMPL
         }]);
+        assert_eq!(s.state, State::Closed);
+    }
+
+    #[test]
+    fn test_local_close() {
+        let mut s = socket_established();
+        s.close();
+        assert_eq!(s.state, State::FinWait1);
+        recv!(s, [TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        }]);
+        send!(s, [TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::FinWait2);
+        send!(s, [TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::TimeWait);
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 1),
+            ..RECV_TEMPL
+        }]);
+    }
+
+    #[test]
+    fn test_simultaneous_close() {
+        let mut s = socket_established();
+        s.close();
+        assert_eq!(s.state, State::FinWait1);
+        recv!(s, [TcpRepr { // this is logically located...
+            control: TcpControl::Fin,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        }]);
+        send!(s, [TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::Closing);
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 1),
+            ..RECV_TEMPL
+        }]);
+        // ... at this point
+        send!(s, [TcpRepr {
+            seq_number: REMOTE_SEQ + 1 + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::TimeWait);
+        recv!(s, []);
     }
 }

+ 8 - 0
src/wire/ip.rs

@@ -78,6 +78,8 @@ impl fmt::Display for Address {
 }
 
 /// An internet endpoint address.
+///
+/// An endpoint can be constructed from a port, in which case the address is unspecified.
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
 pub struct Endpoint {
     pub addr: Address,
@@ -104,6 +106,12 @@ impl fmt::Display for Endpoint {
     }
 }
 
+impl From<u16> for Endpoint {
+    fn from(port: u16) -> Endpoint {
+        Endpoint { addr: Address::Unspecified, port: port }
+    }
+}
+
 /// An IP packet representation.
 ///
 /// This enum abstracts the various versions of IP packets. It either contains a concrete