Browse Source

Implement TCP RST handling.

whitequark 8 years ago
parent
commit
a81dd55cca
1 changed files with 200 additions and 99 deletions
  1. 200 99
      src/socket/tcp.rs

+ 200 - 99
src/socket/tcp.rs

@@ -2,7 +2,7 @@ use core::fmt;
 
 use Error;
 use Managed;
-use wire::{IpProtocol, IpEndpoint};
+use wire::{IpProtocol, IpAddress, IpEndpoint};
 use wire::{TcpPacket, TcpRepr, TcpControl};
 use socket::{Socket, IpRepr, IpPayload};
 
@@ -162,6 +162,7 @@ impl Retransmit {
 #[derive(Debug)]
 pub struct TcpSocket<'a> {
     state:           State,
+    listen_address:  IpAddress,
     local_endpoint:  IpEndpoint,
     remote_endpoint: IpEndpoint,
     local_seq_no:    i32,
@@ -185,6 +186,7 @@ impl<'a> TcpSocket<'a> {
 
         Socket::Tcp(TcpSocket {
             state:           State::Closed,
+            listen_address:  IpAddress::default(),
             local_endpoint:  IpEndpoint::default(),
             remote_endpoint: IpEndpoint::default(),
             local_seq_no:    0,
@@ -235,6 +237,7 @@ impl<'a> TcpSocket<'a> {
     pub fn listen(&mut self, endpoint: IpEndpoint) {
         assert!(self.state == State::Closed);
 
+        self.listen_address  = endpoint.addr;
         self.local_endpoint  = endpoint;
         self.remote_endpoint = IpEndpoint::default();
         self.set_state(State::Listen);
@@ -260,16 +263,16 @@ impl<'a> TcpSocket<'a> {
 
         // Reject packets addressed to a closed socket.
         if self.state == State::Closed {
-            net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
+            net_trace!("tcp:{}:{}:{}: packet received by a closed socket",
                        self.local_endpoint, ip_repr.src_addr(), repr.src_port);
             return Err(Error::Malformed)
         }
 
         // Reject unacceptable acknowledgements.
         match (self.state, repr) {
-            // The initial SYN cannot contain an acknowledgement.
+            // The initial SYN (or whatever) cannot contain an acknowledgement.
             (State::Listen, TcpRepr { ack_number: Some(_), .. }) => {
-                net_trace!("tcp:{}:{}: ACK in initial SYN",
+                net_trace!("tcp:{}:{}: ACK received by a socket in LISTEN state",
                            self.local_endpoint, self.remote_endpoint);
                 return Err(Error::Malformed)
             }
@@ -312,27 +315,53 @@ impl<'a> TcpSocket<'a> {
             }
         }
 
-        // Reject segments not occupying a valid portion of the receive window.
-        // For now, do not try to reassemble out-of-order segments.
-        if self.state != State::Listen {
-            let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32 +
-                                  repr.control.len();
-            if repr.seq_number - next_remote_seq > 0 {
-                net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
-                           self.local_endpoint, self.remote_endpoint,
-                           repr.seq_number, next_remote_seq);
-                return Err(Error::Malformed)
-            } else if repr.seq_number - next_remote_seq != 0 {
-                net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
-                           self.local_endpoint, self.remote_endpoint,
-                           repr.seq_number, next_remote_seq);
-                return Ok(())
+        match (self.state, repr) {
+            // In LISTEN and SYN_SENT states, we have not yet synchronized with the remote end.
+            (State::Listen, _)  => (),
+            (State::SynSent, _) => (),
+            // In all other states, segments must occupy a valid portion of the receive window.
+            // For now, do not try to reassemble out-of-order segments.
+            (_, TcpRepr { control, seq_number, .. }) => {
+                let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32 +
+                                      control.len();
+                if seq_number - next_remote_seq > 0 {
+                    net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
+                               self.local_endpoint, self.remote_endpoint,
+                               seq_number, next_remote_seq);
+                    return Err(Error::Malformed)
+                } else if seq_number - next_remote_seq != 0 {
+                    net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
+                               self.local_endpoint, self.remote_endpoint,
+                               seq_number, next_remote_seq);
+                    return Ok(())
+                }
             }
         }
 
         // Validate and update the state.
         let old_state = self.state;
         match (self.state, repr) {
+            // RSTs are ignored in the LISTEN state.
+            (State::Listen, TcpRepr { control: TcpControl::Rst, .. }) =>
+                return Ok(()),
+
+            // RSTs in SYN_RECEIVED flip the socket back to the LISTEN state.
+            (State::SynReceived, TcpRepr { control: TcpControl::Rst, .. }) => {
+                self.local_endpoint.addr = self.listen_address;
+                self.remote_endpoint     = IpEndpoint::default();
+                self.set_state(State::Listen);
+                return Ok(())
+            }
+
+            // RSTs in any other state close the socket.
+            (_, TcpRepr { control: TcpControl::Rst, .. }) => {
+                self.local_endpoint  = IpEndpoint::default();
+                self.remote_endpoint = IpEndpoint::default();
+                self.set_state(State::Closed);
+                return Ok(())
+            }
+
+            // SYN packets in the LISTEN state change it to SYN_RECEIVED.
             (State::Listen, TcpRepr {
                 src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
             }) => {
@@ -344,11 +373,13 @@ impl<'a> TcpSocket<'a> {
                 self.retransmit.reset()
             }
 
+            // SYN|ACK packets in the SYN_RECEIVED state change it to ESTABLISHED.
             (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => {
                 self.set_state(State::Established);
                 self.retransmit.reset()
             }
 
+            // ACK packets in ESTABLISHED state do nothing.
             (State::Established, TcpRepr { control: TcpControl::None, .. }) => (),
 
             _ => {
@@ -586,6 +617,9 @@ mod test {
         }
     }
 
+    // =========================================================================================//
+    // Tests for the CLOSED state.
+    // =========================================================================================//
     #[test]
     fn test_closed() {
         let mut s = socket();
@@ -597,169 +631,236 @@ mod test {
         }, Err(Error::Rejected));
     }
 
-    #[test]
-    fn test_listen() {
-        let mut s = socket();
-        s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT));
-        assert_eq!(s.state(), State::Listen);
-    }
-
-    #[test]
-    fn test_handshake() {
+    // =========================================================================================//
+    // Tests for the LISTEN state.
+    // =========================================================================================//
+    fn socket_listen() -> TcpSocket<'static> {
         let mut s = socket();
         s.state           = State::Listen;
         s.local_endpoint  = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);
+        s
+    }
 
-        send!(s, [TcpRepr {
+    #[test]
+    fn test_listen_syn_no_ack() {
+        let mut s = socket_listen();
+        send!(s, TcpRepr {
             control: TcpControl::Syn,
             seq_number: REMOTE_SEQ,
-            ack_number: None,
+            ack_number: Some(LOCAL_SEQ),
             ..SEND_TEMPL
-        }]);
-        assert_eq!(s.state(), State::SynReceived);
-        assert_eq!(s.local_endpoint(), LOCAL_END);
-        assert_eq!(s.remote_endpoint(), REMOTE_END);
-        recv!(s, [TcpRepr {
-            control: TcpControl::Syn,
-            seq_number: LOCAL_SEQ,
-            ack_number: Some(REMOTE_SEQ + 1),
-            ..RECV_TEMPL
-        }]);
+        }, Err(Error::Malformed));
+        assert_eq!(s.state, State::Listen);
+    }
+
+    #[test]
+    fn test_listen_rst() {
+        let mut s = socket_listen();
         send!(s, [TcpRepr {
-            seq_number: REMOTE_SEQ + 1,
-            ack_number: Some(LOCAL_SEQ + 1),
+            control: TcpControl::Rst,
+            seq_number: REMOTE_SEQ,
+            ack_number: None,
             ..SEND_TEMPL
         }]);
-        assert_eq!(s.state(), State::Established);
-        assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
-        assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
     }
 
-    #[test]
-    fn test_no_ack() {
+    // =========================================================================================//
+    // Tests for the SYN_RECEIVED state.
+    // =========================================================================================//
+    fn socket_syn_received() -> TcpSocket<'static> {
         let mut s = socket();
-        s.state           = State::Established;
+        s.state           = State::SynReceived;
         s.local_endpoint  = LOCAL_END;
         s.remote_endpoint = REMOTE_END;
-        s.local_seq_no    = LOCAL_SEQ + 1;
-        s.remote_seq_no   = REMOTE_SEQ + 1;
-
-        send!(s, TcpRepr {
-            seq_number: REMOTE_SEQ + 1,
-            ack_number: None,
-            ..SEND_TEMPL
-        }, Err(Error::Malformed));
+        s.local_seq_no    = LOCAL_SEQ;
+        s.remote_seq_no   = REMOTE_SEQ;
+        s
     }
 
     #[test]
-    fn test_bad_ack_listen() {
-        let mut s = socket();
-        s.state           = State::Listen;
-        s.local_endpoint  = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);
-
-        send!(s, TcpRepr {
-            control: TcpControl::Syn,
+    fn test_syn_received_rst() {
+        let mut s = socket_syn_received();
+        send!(s, [TcpRepr {
+            control: TcpControl::Rst,
             seq_number: REMOTE_SEQ,
             ack_number: Some(LOCAL_SEQ),
             ..SEND_TEMPL
-        }, Err(Error::Malformed));
+        }]);
+        assert_eq!(s.state, State::Listen);
+        assert_eq!(s.local_endpoint, IpEndpoint::new(IpAddress::Unspecified, LOCAL_END.port));
+        assert_eq!(s.remote_endpoint, IpEndpoint::default());
     }
 
-    #[test]
-    fn test_no_ack_syn_sent_rst() {
+    // =========================================================================================//
+    // Tests for the SYN_SENT state.
+    // =========================================================================================//
+    fn socket_syn_sent() -> TcpSocket<'static> {
         let mut s = socket();
         s.state           = State::SynSent;
         s.local_endpoint  = LOCAL_END;
         s.remote_endpoint = REMOTE_END;
         s.local_seq_no    = LOCAL_SEQ;
+        s
+    }
 
+    #[test]
+    fn test_syn_sent_rst() {
+        let mut s = socket_syn_sent();
+        send!(s, [TcpRepr {
+            control: TcpControl::Rst,
+            seq_number: REMOTE_SEQ,
+            ack_number: Some(LOCAL_SEQ),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::Closed);
+    }
+
+    #[test]
+    fn test_syn_sent_rst_no_ack() {
+        let mut s = socket_syn_sent();
         send!(s, TcpRepr {
             control: TcpControl::Rst,
             seq_number: REMOTE_SEQ,
             ack_number: None,
             ..SEND_TEMPL
         }, Err(Error::Malformed));
+        assert_eq!(s.state, State::SynSent);
     }
 
     #[test]
-    fn test_bad_ack_syn_sent_rst() {
-        let mut s = socket();
-        s.state           = State::SynSent;
-        s.local_endpoint  = LOCAL_END;
-        s.remote_endpoint = REMOTE_END;
-        s.local_seq_no    = LOCAL_SEQ;
-
+    fn test_syn_sent_rst_bad_ack() {
+        let mut s = socket_syn_sent();
         send!(s, TcpRepr {
             control: TcpControl::Rst,
             seq_number: REMOTE_SEQ,
             ack_number: Some(1234),
             ..SEND_TEMPL
         }, Err(Error::Malformed));
+        assert_eq!(s.state, State::SynSent);
     }
 
-    #[test]
-    fn test_bad_ack_established() {
+    // =========================================================================================//
+    // Tests for the ESTABLISHED state.
+    // =========================================================================================//
+    fn socket_established() -> TcpSocket<'static> {
         let mut s = socket();
-        s.state           = State::Established;
+        s.state          = State::Established;
         s.local_endpoint  = LOCAL_END;
         s.remote_endpoint = REMOTE_END;
         s.local_seq_no    = LOCAL_SEQ + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1;
-        s.tx_buffer.enqueue_slice(b"abcdef");
+        s
+    }
+
+    #[test]
+    fn test_established_data() {
+        let mut s = socket_established();
+        send!(s, [TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload: &b"abcdef"[..],
+            ..SEND_TEMPL
+        }]);
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 6),
+            window_len: 122,
+            ..RECV_TEMPL
+        }]);
+        assert_eq!(s.rx_buffer.dequeue(6), &b"abcdef"[..]);
+    }
 
+    #[test]
+    fn test_established_no_ack() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: None,
+            ..SEND_TEMPL
+        }, Err(Error::Malformed));
+    }
+
+    #[test]
+    fn test_established_bad_ack() {
+        let mut s = socket_established();
         // Already acknowledged data.
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ - 1),
             ..SEND_TEMPL
         }, Err(Error::Malformed));
-
+        assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
         // Data not yet transmitted.
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 10),
             ..SEND_TEMPL
         }, Err(Error::Malformed));
+        assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
     }
 
     #[test]
-    fn test_unacceptable_seq() {
-        let mut s = socket();
-        s.state          = State::Established;
-        s.local_endpoint  = LOCAL_END;
-        s.remote_endpoint = REMOTE_END;
-        s.local_seq_no    = LOCAL_SEQ + 1;
-        s.remote_seq_no   = REMOTE_SEQ + 1;
-
+    fn test_established_bad_seq() {
+        let mut s = socket_established();
         // Data outside of receive window.
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1 + 256,
             ack_number: Some(LOCAL_SEQ + 1),
             ..SEND_TEMPL
         }, Err(Error::Malformed));
+        assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
     }
 
     #[test]
-    fn test_recv_data() {
-        let mut s = socket();
-        s.state          = State::Established;
-        s.local_endpoint  = LOCAL_END;
-        s.remote_endpoint = REMOTE_END;
-        s.local_seq_no    = LOCAL_SEQ + 1;
-        s.remote_seq_no   = REMOTE_SEQ + 1;
-
+    fn test_established_rst() {
+        let mut s = socket_established();
         send!(s, [TcpRepr {
+            control: TcpControl::Rst,
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1),
-            payload: &b"abcdef"[..],
             ..SEND_TEMPL
         }]);
+        assert_eq!(s.state, State::Closed);
+    }
+
+    // =========================================================================================//
+    // Tests for transitioning through multiple states.
+    // =========================================================================================//
+    #[test]
+    fn test_listen() {
+        let mut s = socket();
+        s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT));
+        assert_eq!(s.state, State::Listen);
+    }
+
+    #[test]
+    fn test_three_way_handshake() {
+        let mut s = socket();
+        s.state           = State::Listen;
+        s.local_endpoint  = IpEndpoint::new(IpAddress::default(), LOCAL_PORT);
+
+        send!(s, [TcpRepr {
+            control: TcpControl::Syn,
+            seq_number: REMOTE_SEQ,
+            ack_number: None,
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state(), State::SynReceived);
+        assert_eq!(s.local_endpoint(), LOCAL_END);
+        assert_eq!(s.remote_endpoint(), REMOTE_END);
         recv!(s, [TcpRepr {
-            seq_number: LOCAL_SEQ + 1,
-            ack_number: Some(REMOTE_SEQ + 1 + 6),
-            window_len: 122,
+            control: TcpControl::Syn,
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
         }]);
-        assert_eq!(s.rx_buffer.dequeue(6), &b"abcdef"[..]);
+        send!(s, [TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state(), State::Established);
+        assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
+        assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
     }
 }