Browse Source

Return `Error::Finished` in `recv()` on graceful close.

This allows applications to distinguish whether the remote end has
gracefully closed the connection with a FIN (in which case it has
received all the data intact), or the connection has failed due to e.g. a
RST or a timeout (in which case the received data may be truncated).
Dario Nieuwenhuis 4 years ago
parent
commit
8d96626fc8
1 changed files with 176 additions and 5 deletions
  1. 176 5
      src/socket/tcp.rs

+ 176 - 5
src/socket/tcp.rs

@@ -192,6 +192,7 @@ pub struct TcpSocket<'a> {
     timer:           Timer,
     assembler:       Assembler,
     rx_buffer:       SocketBuffer<'a>,
+    rx_fin_received: bool,
     tx_buffer:       SocketBuffer<'a>,
     /// Interval after which, if no inbound packets are received, the connection is aborted.
     timeout:         Option<Duration>,
@@ -276,6 +277,7 @@ impl<'a> TcpSocket<'a> {
             assembler:       Assembler::new(rx_buffer.capacity()),
             tx_buffer:       tx_buffer,
             rx_buffer:       rx_buffer,
+            rx_fin_received: false,
             timeout:         None,
             keep_alive:      None,
             hop_limit:       None,
@@ -419,6 +421,7 @@ impl<'a> TcpSocket<'a> {
         self.assembler       = Assembler::new(self.rx_buffer.capacity());
         self.tx_buffer.clear();
         self.rx_buffer.clear();
+        self.rx_fin_received = false;
         self.keep_alive      = None;
         self.timeout         = None;
         self.hop_limit       = None;
@@ -706,12 +709,23 @@ impl<'a> TcpSocket<'a> {
         })
     }
 
-    fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
-            where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
+    fn recv_error_check(&mut self) -> Result<()> {
         // We may have received some data inside the initial SYN, but until the connection
         // is fully open we must not dequeue any data, as it may be overwritten by e.g.
         // another (stale) SYN. (We do not support TCP Fast Open.)
-        if !self.may_recv() { return Err(Error::Illegal) }
+        if !self.may_recv() {
+            if self.rx_fin_received {
+                return Err(Error::Finished)
+            }
+            return Err(Error::Illegal)
+        }
+
+        Ok(())
+    }
+
+    fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
+            where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
+        self.recv_error_check()?;
 
         let _old_length = self.rx_buffer.len();
         let (size, result) = f(&mut self.rx_buffer);
@@ -755,8 +769,7 @@ impl<'a> TcpSocket<'a> {
     ///
     /// This function otherwise behaves identically to [recv](#method.recv).
     pub fn peek(&mut self, size: usize) -> Result<&[u8]> {
-        // See recv() above.
-        if !self.may_recv() { return Err(Error::Illegal) }
+        self.recv_error_check()?;
 
         let buffer = self.rx_buffer.get_allocated(0, size);
         if buffer.len() > 0 {
@@ -1140,6 +1153,7 @@ impl<'a> TcpSocket<'a> {
             // 7th and 8th steps in the "SEGMENT ARRIVES" event describe this behavior.
             (State::SynReceived, TcpControl::Fin) => {
                 self.remote_seq_no  += 1;
+                self.rx_fin_received = true;
                 self.set_state(State::CloseWait);
                 self.timer.set_for_idle(timestamp, self.keep_alive);
             }
@@ -1170,6 +1184,7 @@ impl<'a> TcpSocket<'a> {
             // FIN packets in ESTABLISHED state indicate the remote side has closed.
             (State::Established, TcpControl::Fin) => {
                 self.remote_seq_no  += 1;
+                self.rx_fin_received = true;
                 self.set_state(State::CloseWait);
                 self.timer.set_for_idle(timestamp, self.keep_alive);
             }
@@ -1187,6 +1202,7 @@ impl<'a> TcpSocket<'a> {
             // if they also acknowledge our FIN.
             (State::FinWait1, TcpControl::Fin) => {
                 self.remote_seq_no  += 1;
+                self.rx_fin_received = true;
                 if ack_of_fin {
                     self.set_state(State::TimeWait);
                     self.timer.set_for_close(timestamp);
@@ -1204,6 +1220,7 @@ impl<'a> TcpSocket<'a> {
             // FIN packets in FIN-WAIT-2 state change it to TIME-WAIT.
             (State::FinWait2, TcpControl::Fin) => {
                 self.remote_seq_no  += 1;
+                self.rx_fin_received = true;
                 self.set_state(State::TimeWait);
                 self.timer.set_for_close(timestamp);
             }
@@ -4600,6 +4617,160 @@ mod test {
         }));
     }
 
+    // =========================================================================================//
+    // Tests for graceful vs ungraceful rx close
+    // =========================================================================================//
+
+    #[test]
+    fn test_rx_close_fin() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            control:    TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..SEND_TEMPL
+        });
+        s.recv(|data| {
+            assert_eq!(data, b"abc");
+            (3, ())
+        }).unwrap();
+        assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished));
+    }
+
+    #[test]
+    fn test_rx_close_fin_in_fin_wait_1() {
+        let mut s = socket_fin_wait_1();
+        send!(s, TcpRepr {
+            control:    TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.state, State::Closing);
+        s.recv(|data| {
+            assert_eq!(data, b"abc");
+            (3, ())
+        }).unwrap();
+        assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished));
+    }
+
+    #[test]
+    fn test_rx_close_fin_in_fin_wait_2() {
+        let mut s = socket_fin_wait_2();
+        send!(s, TcpRepr {
+            control:    TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 1),
+            payload:    &b"abc"[..],
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.state, State::TimeWait);
+        s.recv(|data| {
+            assert_eq!(data, b"abc");
+            (3, ())
+        }).unwrap();
+        assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished));
+    }
+
+
+
+    #[test]
+    fn test_rx_close_fin_with_hole() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..SEND_TEMPL
+        });
+        send!(s, TcpRepr {
+            control:    TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1 + 6,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"ghi"[..],
+            ..SEND_TEMPL
+        }, Ok(Some(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 3),
+            window_len: 61,
+            ..RECV_TEMPL
+        })));
+        s.recv(|data| {
+            assert_eq!(data, b"abc");
+            (3, ())
+        }).unwrap();
+        s.recv(|data| {
+            assert_eq!(data, b"");
+            (0, ())
+        }).unwrap();
+        send!(s, TcpRepr {
+            control:    TcpControl::Rst,
+            seq_number: REMOTE_SEQ + 1 + 9,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        // Error must be `Illegal` even if we've received a FIN,
+        // because we are missing data.
+        assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal));
+    }
+
+    #[test]
+    fn test_rx_close_rst() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..SEND_TEMPL
+        });
+        send!(s, TcpRepr {
+            control:    TcpControl::Rst,
+            seq_number: REMOTE_SEQ + 1 + 3,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        s.recv(|data| {
+            assert_eq!(data, b"abc");
+            (3, ())
+        }).unwrap();
+        assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal));
+    }
+
+    #[test]
+    fn test_rx_close_rst_with_hole() {
+        let mut s = socket_established();
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..SEND_TEMPL
+        });
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1 + 6,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"ghi"[..],
+            ..SEND_TEMPL
+        }, Ok(Some(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1 + 3),
+            window_len: 61,
+            ..RECV_TEMPL
+        })));
+        send!(s, TcpRepr {
+            control:    TcpControl::Rst,
+            seq_number: REMOTE_SEQ + 1 + 9,
+            ack_number: Some(LOCAL_SEQ + 1),
+            ..SEND_TEMPL
+        });
+        s.recv(|data| {
+            assert_eq!(data, b"abc");
+            (3, ())
+        }).unwrap();
+        assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal));
+    }
+
     // =========================================================================================//
     // Tests for packet filtering.
     // =========================================================================================//