Преглед изворни кода

Add support for TCP FIN in ESTABLISHED state.

whitequark пре 8 година
родитељ
комит
256211843f
2 измењених фајлова са 61 додато и 11 уклоњено
  1. 6 6
      examples/smoltcpserver.rs
  2. 55 5
      src/socket/tcp.rs

+ 6 - 6
examples/smoltcpserver.rs

@@ -51,11 +51,11 @@ fn main() {
     let endpoint = IpEndpoint::new(IpAddress::default(), 6969);
 
     let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 64])]);
-    let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 64])]);
+    let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 128])]);
     let udp_socket = UdpSocket::new(endpoint, udp_rx_buffer, udp_tx_buffer);
 
     let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 64]);
-    let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 64]);
+    let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 128]);
     let mut tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
     (tcp_socket.as_socket() : &mut TcpSocket).listen(endpoint);
 
@@ -75,7 +75,7 @@ fn main() {
             let socket: &mut UdpSocket = iface.sockets()[0].as_socket();
             let client = match socket.recv() {
                 Ok((endpoint, data)) => {
-                    debug!("udp recv data: {} from {}",
+                    debug!("udp recv data: {:?} from {}",
                            str::from_utf8(data.as_ref()).unwrap(), endpoint);
                     Some(endpoint)
                 }
@@ -89,7 +89,7 @@ fn main() {
             };
             if let Some(endpoint) = client {
                 let data = b"yo dawg";
-                debug!("udp send data: {}",
+                debug!("udp send data: {:?}",
                        str::from_utf8(data.as_ref()).unwrap());
                 socket.send_slice(endpoint, data).unwrap()
             }
@@ -100,7 +100,7 @@ fn main() {
             let data = {
                 let mut data = socket.recv(128).to_owned();
                 if data.len() > 0 {
-                    debug!("tcp recv data: {}",
+                    debug!("tcp recv data: {:?}",
                            str::from_utf8(data.as_ref()).unwrap());
                     data = data.split(|&b| b == b'\n').next().unwrap().to_owned();
                     data.reverse();
@@ -109,7 +109,7 @@ fn main() {
                 data
             };
             if data.len() > 0 {
-                debug!("tcp send data: {}",
+                debug!("tcp send data: {:?}",
                        str::from_utf8(data.as_ref()).unwrap());
                 socket.send_slice(&data[..]);
             }

+ 55 - 5
src/socket/tcp.rs

@@ -460,6 +460,12 @@ impl<'a> TcpSocket<'a> {
             // ACK packets in ESTABLISHED state do nothing.
             (State::Established, TcpRepr { control: TcpControl::None, .. }) => (),
 
+            // FIN packets in ESTABLISHED state indicate the remote side has closed.
+            (State::Established, TcpRepr { control: TcpControl::Fin, .. }) => {
+                self.set_state(State::CloseWait);
+                self.retransmit.reset()
+            }
+
             _ => {
                 net_trace!("tcp:{}:{}: unexpected packet {}",
                            self.local_endpoint, self.remote_endpoint, repr);
@@ -471,12 +477,13 @@ impl<'a> TcpSocket<'a> {
         if let Some(ack_number) = repr.ack_number {
             let control_len =
                 if old_state == State::SynReceived { 1 } else { 0 };
-            if ack_number - self.local_seq_no - control_len > 0 {
+            let ack_length = ack_number - self.local_seq_no - control_len;
+            if ack_length > 0 {
                 net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets",
                            self.local_endpoint, self.remote_endpoint,
-                           ack_number - self.local_seq_no - control_len);
+                           ack_length);
             }
-            self.tx_buffer.advance((ack_number - self.local_seq_no - control_len) as usize);
+            self.tx_buffer.advance(ack_length as usize);
             self.local_seq_no = ack_number;
         }
 
@@ -511,7 +518,13 @@ impl<'a> TcpSocket<'a> {
             payload:    &[]
         };
 
-        let ack_number = self.remote_seq_no + self.rx_buffer.len() as i32;
+        let mut ack_number = self.remote_seq_no + self.rx_buffer.len() as i32;
+        match self.state {
+            // In CLOSE_WAIT or CLOSING, we have received a FIN and must acknowledge it.
+            State::CloseWait | State::Closing =>
+                ack_number += 1,
+            _ => ()
+        }
 
         match self.state {
             State::Closed | State::Listen => return Err(Error::Exhausted),
@@ -524,7 +537,8 @@ impl<'a> TcpSocket<'a> {
                            self.local_endpoint, self.remote_endpoint);
             }
 
-            State::Established => {
+            State::Established |
+            State::CloseWait => {
                 // See if we should send data to the remote end because:
                 //   1. the retransmit timer has expired, or...
                 let mut may_send = self.retransmit.check();
@@ -955,6 +969,42 @@ mod test {
         assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
     }
 
+    #[test]
+    fn test_established_fin() {
+        let mut s = socket_established();
+        send!(s, [TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::CloseWait);
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 1),
+            ..RECV_TEMPL
+        }]);
+    }
+
+    #[test]
+    fn test_established_send_fin() {
+        let mut s = socket_established();
+        s.tx_buffer.enqueue_slice(b"abcdef");
+        send!(s, [TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        }]);
+        assert_eq!(s.state, State::CloseWait);
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 1),
+            payload: &b"abcdef"[..],
+            ..RECV_TEMPL
+        }]);
+    }
+
     #[test]
     fn test_established_rst() {
         let mut s = socket_established();