Browse Source

Correctly treat TCP ACKs that acknowledge both data and a FIN.

whitequark 8 years ago
parent
commit
b90495fe8a
1 changed files with 48 additions and 18 deletions
  1. 48 18
      src/socket/tcp.rs

+ 48 - 18
src/socket/tcp.rs

@@ -575,6 +575,18 @@ impl<'a> TcpSocket<'a> {
         if !self.remote_endpoint.addr.is_unspecified() &&
            self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) }
 
+        // Consider how much the sequence number space differs from the transmit buffer space.
+        let (sent_syn, sent_fin) = match self.state {
+            // In SYN-SENT or SYN-RECEIVED, we've just sent a SYN.
+            State::SynSent | State::SynReceived => (true, false),
+            // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN.
+            State::FinWait1 | State::LastAck | State::Closing => (false, true),
+            // In all other states we've already got acknowledgemetns for
+            // all of the control flags we sent.
+            _ => (false, false)
+        };
+        let control_len = (sent_syn as usize) + (sent_fin as usize);
+
         // Reject unacceptable acknowledgements.
         match (self.state, repr) {
             // The initial SYN (or whatever) cannot contain an acknowledgement.
@@ -609,16 +621,7 @@ impl<'a> TcpSocket<'a> {
                 return Err(Error::Malformed)
             }
             // Every acknowledgement must be for transmitted but unacknowledged data.
-            (state, TcpRepr { ack_number: Some(ack_number), .. }) => {
-                let control_len = match state {
-                    // In SYN-SENT or SYN-RECEIVED, we've just sent a SYN.
-                    State::SynSent | State::SynReceived => 1,
-                    // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN.
-                    State::FinWait1 | State::LastAck | State::Closing => 1,
-                    // In all other states we've already got acknowledgemetns for
-                    // all of the control flags we sent.
-                    _ => 0
-                };
+            (_, TcpRepr { ack_number: Some(ack_number), .. }) => {
                 let unacknowledged = self.tx_buffer.len() + control_len;
                 if !(ack_number >= self.local_seq_no &&
                      ack_number <= (self.local_seq_no + unacknowledged)) {
@@ -708,7 +711,6 @@ impl<'a> TcpSocket<'a> {
 
             // ACK packets in the SYN-RECEIVED state change it to ESTABLISHED.
             (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => {
-                self.local_seq_no   += 1;
                 self.set_state(State::Established);
                 self.retransmit.reset();
             }
@@ -725,7 +727,6 @@ impl<'a> TcpSocket<'a> {
 
             // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2.
             (State::FinWait1, TcpRepr { control: TcpControl::None, .. }) => {
-                self.local_seq_no   += 1;
                 self.set_state(State::FinWait2);
             }
 
@@ -745,7 +746,6 @@ impl<'a> TcpSocket<'a> {
 
             // ACK packets in CLOSING state change it to TIME-WAIT.
             (State::Closing, TcpRepr { control: TcpControl::None, .. }) => {
-                self.local_seq_no   += 1;
                 self.set_state(State::TimeWait);
                 self.retransmit.reset();
             }
@@ -758,7 +758,6 @@ impl<'a> TcpSocket<'a> {
                 // Clear the remote endpoint, or we'll send an RST there.
                 self.set_state(State::Closed);
                 self.remote_endpoint = IpEndpoint::default();
-                self.local_seq_no   += 1;
             }
 
             _ => {
@@ -770,13 +769,23 @@ impl<'a> TcpSocket<'a> {
 
         // Dequeue acknowledged octets.
         if let Some(ack_number) = repr.ack_number {
-            let ack_length = ack_number - self.local_seq_no;
-            if ack_length > 0 {
+            let mut ack_len = ack_number - self.local_seq_no;
+            // There could have been no data sent before the SYN, so we always remove it
+            // from the sequence space.
+            if sent_syn {
+                ack_len -= 1
+            }
+            // We could've sent data before the FIN, so only remove FIN from the sequence
+            // space if all of that data is acknowledged.
+            if sent_fin && self.tx_buffer.len() + 1 == ack_len {
+                ack_len -= 1
+            }
+            if ack_len > 0 {
                 net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})",
                            self.debug_id, self.local_endpoint, self.remote_endpoint,
-                           ack_length, self.tx_buffer.len() - ack_length);
+                           ack_len, self.tx_buffer.len() - ack_len);
             }
-            self.tx_buffer.advance(ack_length);
+            self.tx_buffer.advance(ack_len);
             self.local_seq_no = ack_number;
         }
 
@@ -1962,6 +1971,27 @@ mod test {
         }])
     }
 
+    #[test]
+    fn test_mutual_close_with_data() {
+        let mut s = socket_established();
+        s.send_slice(b"abcdef").unwrap();
+        s.close();
+        assert_eq!(s.state, State::FinWait1);
+        recv!(s, [TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..RECV_TEMPL
+        }]);
+        send!(s, TcpRepr {
+            control: TcpControl::Fin,
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 6 + 1),
+            ..SEND_TEMPL
+        });
+    }
+
     // =========================================================================================//
     // Tests for retransmission on packet loss.
     // =========================================================================================//