Browse Source

Implement the TCP SYN-SENT state.

whitequark 8 years ago
parent
commit
1a32b98b48
2 changed files with 150 additions and 7 deletions
  1. 144 7
      src/socket/tcp.rs
  2. 6 0
      src/wire/ip.rs

+ 144 - 7
src/socket/tcp.rs

@@ -25,6 +25,11 @@ impl<'a> SocketBuffer<'a> {
         }
     }
 
+    fn clear(&mut self) {
+        self.read_at = 0;
+        self.length = 0;
+    }
+
     fn capacity(&self) -> usize {
         self.storage.len()
     }
@@ -253,6 +258,8 @@ pub struct TcpSocket<'a> {
     debug_id:        usize
 }
 
+const DEFAULT_MSS: usize = 536;
+
 impl<'a> TcpSocket<'a> {
     /// Create a socket using the given buffers.
     pub fn new<T>(rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
@@ -273,7 +280,7 @@ impl<'a> TcpSocket<'a> {
             remote_last_seq: TcpSeqNumber(0),
             remote_last_ack: TcpSeqNumber(0),
             remote_win_len:  0,
-            remote_mss:      536,
+            remote_mss:      DEFAULT_MSS,
             retransmit:      Retransmit::new(),
             tx_buffer:       tx_buffer.into(),
             rx_buffer:       rx_buffer.into(),
@@ -311,20 +318,77 @@ impl<'a> TcpSocket<'a> {
         self.state
     }
 
+    fn reset(&mut self) {
+        self.listen_address  = IpAddress::default();
+        self.local_endpoint  = IpEndpoint::default();
+        self.remote_endpoint = IpEndpoint::default();
+        self.local_seq_no    = TcpSeqNumber(0);
+        self.remote_seq_no   = TcpSeqNumber(0);
+        self.remote_last_seq = TcpSeqNumber(0);
+        self.remote_last_ack = TcpSeqNumber(0);
+        self.remote_win_len  = 0;
+        self.remote_win_len  = 0;
+        self.remote_mss      = DEFAULT_MSS;
+        self.retransmit.reset();
+        self.tx_buffer.clear();
+        self.rx_buffer.clear();
+    }
+
     /// Start listening on the given endpoint.
     ///
     /// This function returns an error if the socket was open; see [is_open](#method.is_open).
-    pub fn listen<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<(), ()> {
+    /// It also returns an error if the specified port is zero.
+    pub fn listen<T>(&mut self, local_endpoint: T) -> Result<(), ()>
+            where T: Into<IpEndpoint> {
+        let local_endpoint = local_endpoint.into();
+
         if self.is_open() { return Err(()) }
+        if local_endpoint.port == 0 { return Err(()) }
 
-        let endpoint = endpoint.into();
-        self.listen_address  = endpoint.addr;
-        self.local_endpoint  = endpoint;
+        self.reset();
+        self.listen_address  = local_endpoint.addr;
+        self.local_endpoint  = local_endpoint;
         self.remote_endpoint = IpEndpoint::default();
         self.set_state(State::Listen);
         Ok(())
     }
 
+    /// Connect to a given endpoint.
+    ///
+    /// The local port must be provided explicitly. Assuming `fn get_ephemeral_port() -> u16`
+    /// allocates a port from the 49152 to 65535 range, a connection may be established as follows:
+    ///
+    /// ```rust,ignore
+    /// socket.connect((IpAddress::v4(10, 0, 0, 1), 80), get_ephemeral_port())
+    /// ```
+    ///
+    /// The local address may optionally be provided.
+    ///
+    /// This function returns an error if the socket was open; see [is_open](#method.is_open).
+    /// It also returns an error if the local or remote port is zero, or if
+    /// the local or remote address is unspecified.
+    pub fn connect<T, U>(&mut self, remote_endpoint: T, local_endpoint: U) -> Result<(), ()>
+            where T: Into<IpEndpoint>, U: Into<IpEndpoint> {
+        let remote_endpoint = remote_endpoint.into();
+        let local_endpoint  = local_endpoint.into();
+
+        if self.is_open() { return Err(()) }
+        if remote_endpoint.port == 0 { return Err(()) }
+        if remote_endpoint.addr.is_unspecified() { return Err(()) }
+        if local_endpoint.port == 0 { return Err(()) }
+        if local_endpoint.addr.is_unspecified() { return Err(()) }
+
+        // Carry over the local sequence number.
+        let local_seq_no = self.local_seq_no;
+
+        self.reset();
+        self.local_endpoint  = local_endpoint;
+        self.remote_endpoint = remote_endpoint;
+        self.local_seq_no    = local_seq_no;
+        self.set_state(State::SynSent);
+        Ok(())
+    }
+
     /// Close the transmit half of the full-duplex connection.
     ///
     /// Note that there is no corresponding function for the receive half of the full-duplex
@@ -715,6 +779,23 @@ impl<'a> TcpSocket<'a> {
                 self.retransmit.reset();
             }
 
+            // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED.
+            (State::SynSent, TcpRepr {
+                control: TcpControl::Syn, seq_number, ack_number: Some(_),
+                max_seg_size, ..
+            }) => {
+                net_trace!("[{}]{}:{}: received SYN|ACK",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
+                self.remote_last_seq = self.local_seq_no + 1;
+                self.remote_seq_no   = seq_number + 1;
+                self.remote_last_ack = seq_number;
+                if let Some(max_seg_size) = max_seg_size {
+                    self.remote_mss = max_seg_size as usize;
+                }
+                self.set_state(State::Established);
+                self.retransmit.reset();
+            }
+
             // ACK packets in ESTABLISHED state reset the retransmit timer.
             (State::Established, TcpRepr { control: TcpControl::None, .. }) => {
                 self.retransmit.reset()
@@ -962,8 +1043,10 @@ impl<'a> TcpSocket<'a> {
                            self.retransmit.delay);
             }
 
-            repr.ack_number = Some(ack_number);
-            self.remote_last_ack = ack_number;
+            if self.state != State::SynSent {
+                repr.ack_number = Some(ack_number);
+                self.remote_last_ack = ack_number;
+            }
 
             // Remember the header length before enabling the MSS option, since that option
             // only affects SYN packets.
@@ -1249,6 +1332,12 @@ mod test {
         sanity!(s, socket_listen());
     }
 
+    #[test]
+    fn test_listen_validation() {
+        let mut s = socket();
+        assert_eq!(s.listen(0), Err(()));
+    }
+
     #[test]
     fn test_listen_syn() {
         let mut s = socket_listen();
@@ -1358,6 +1447,54 @@ mod test {
         s
     }
 
+    #[test]
+    fn test_connect_validation() {
+        let mut s = socket();
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), Err(()));
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(0, 0, 0, 0), 80)), Err(()));
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 0), 0)), Err(()));
+        assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), Err(()));
+        assert_eq!(s.connect((IpAddress::v4(0, 0, 0, 0), 80), LOCAL_END), Err(()));
+        assert_eq!(s.connect((IpAddress::v4(10, 0, 0, 0), 0), LOCAL_END), Err(()));
+    }
+
+    #[test]
+    fn test_syn_sent_sanity() {
+        let mut s = socket();
+        s.local_seq_no    = LOCAL_SEQ;
+        s.connect(REMOTE_END, LOCAL_END).unwrap();
+        sanity!(s, socket_syn_sent());
+    }
+
+    #[test]
+    fn test_syn_sent_syn_ack() {
+        let mut s = socket_syn_sent();
+        recv!(s, [TcpRepr {
+            control:    TcpControl::Syn,
+            seq_number: LOCAL_SEQ,
+            ack_number: None,
+            max_seg_size: Some(1480),
+            ..RECV_TEMPL
+        }]);
+        send!(s, TcpRepr {
+            control:    TcpControl::Syn,
+            seq_number: REMOTE_SEQ,
+            ack_number: Some(LOCAL_SEQ + 1),
+            max_seg_size: Some(1400),
+            ..SEND_TEMPL
+        });
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            ..RECV_TEMPL
+        }]);
+        assert_eq!(s.state, State::Established);
+        sanity!(s, TcpSocket {
+            retransmit: Retransmit { resend_at: 100, delay: 100 },
+            ..socket_established()
+        });
+    }
+
     #[test]
     fn test_syn_sent_rst() {
         let mut s = socket_syn_sent();

+ 6 - 0
src/wire/ip.rs

@@ -112,6 +112,12 @@ impl From<u16> for Endpoint {
     }
 }
 
+impl<T: Into<Address>> From<(T, u16)> for Endpoint {
+    fn from((addr, port): (T, u16)) -> Endpoint {
+        Endpoint { addr: addr.into(), port: port }
+    }
+}
+
 /// An IP packet representation.
 ///
 /// This enum abstracts the various versions of IP packets. It either contains a concrete