Ver código fonte

Simplify TCP ACK handling.

whitequark 8 anos atrás
pai
commit
86bd034ec0
1 arquivos alterados com 32 adições e 33 exclusões
  1. 32 33
      src/socket/tcp.rs

+ 32 - 33
src/socket/tcp.rs

@@ -380,8 +380,15 @@ impl<'a> TcpSocket<'a> {
             }
             // Every acknowledgement must be for transmitted but unacknowledged data.
             (state, TcpRepr { ack_number: Some(ack_number), .. }) => {
-                let control_len =
-                    if state == State::SynReceived { 1 } else { 0 };
+                let control_len = match state {
+                    // In SYN_SENT or SYN_RECEIVED, we've just sent a SYN.
+                    State::SynSent | State::SynReceived => 1,
+                    // In FIN_WAIT_1 or LAST_ACK, we've just sent a FIN.
+                    State::FinWait1 | State::LastAck => 1,
+                    // In all other states we've already got acknowledgemetns for
+                    // all of the control flags we sent.
+                    _ => 0
+                };
                 let unacknowledged = self.tx_buffer.len() as i32 + control_len;
                 if !(ack_number - self.local_seq_no >= 0 &&
                      ack_number - (self.local_seq_no + unacknowledged) <= 0) {
@@ -416,7 +423,6 @@ impl<'a> TcpSocket<'a> {
         }
 
         // Validate and update the state.
-        let old_state = self.state;
         match (self.state, repr) {
             // RSTs are ignored in the LISTEN state.
             (State::Listen, TcpRepr { control: TcpControl::Rst, .. }) =>
@@ -445,14 +451,15 @@ impl<'a> TcpSocket<'a> {
                 self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), dst_port);
                 self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port);
                 self.local_seq_no    = -seq_number; // FIXME: use something more secure
+                self.remote_last_seq = self.local_seq_no + 1;
                 self.remote_seq_no   = seq_number + 1;
                 self.set_state(State::SynReceived);
                 self.retransmit.reset()
             }
 
-            // SYN|ACK packets in the SYN_RECEIVED state change it to ESTABLISHED.
+            // ACK packets in the SYN_RECEIVED state change it to ESTABLISHED.
             (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => {
-                self.remote_last_seq = self.local_seq_no + 1;
+                self.local_seq_no   += 1;
                 self.set_state(State::Established);
                 self.retransmit.reset()
             }
@@ -462,6 +469,7 @@ impl<'a> TcpSocket<'a> {
 
             // FIN packets in ESTABLISHED state indicate the remote side has closed.
             (State::Established, TcpRepr { control: TcpControl::Fin, .. }) => {
+                self.remote_seq_no  += 1;
                 self.set_state(State::CloseWait);
                 self.retransmit.reset()
             }
@@ -475,9 +483,7 @@ impl<'a> TcpSocket<'a> {
 
         // Dequeue acknowledged octets.
         if let Some(ack_number) = repr.ack_number {
-            let control_len =
-                if old_state == State::SynReceived { 1 } else { 0 };
-            let ack_length = ack_number - self.local_seq_no - control_len;
+            let ack_length = ack_number - self.local_seq_no;
             if ack_length > 0 {
                 net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets",
                            self.local_endpoint, self.remote_endpoint,
@@ -518,14 +524,7 @@ impl<'a> TcpSocket<'a> {
             payload:    &[]
         };
 
-        let mut ack_number = self.remote_seq_no + self.rx_buffer.len() as i32;
-        match self.state {
-            // In CLOSE_WAIT or CLOSING, we have received a FIN and must acknowledge it.
-            State::CloseWait | State::Closing =>
-                ack_number += 1,
-            _ => ()
-        }
-
+        let mut should_send = false;
         match self.state {
             State::Closed | State::Listen => return Err(Error::Exhausted),
 
@@ -535,6 +534,7 @@ impl<'a> TcpSocket<'a> {
                 repr.control = TcpControl::Syn;
                 net_trace!("tcp:{}:{}: sending SYN|ACK",
                            self.local_endpoint, self.remote_endpoint);
+                should_send = true;
             }
 
             State::Established |
@@ -566,32 +566,29 @@ impl<'a> TcpSocket<'a> {
                     self.remote_win_len -= data.len();
                     // Advance the in-flight sequence number.
                     self.remote_last_seq += data.len() as i32;
-                } else if self.remote_last_ack != ack_number {
-                    // We don't have anything to send, or can't because the remote end does not
-                    // have any space to accept it, but we haven't yet acknowledged everything
-                    // we have received. So, do it.
-                    net_trace!("tcp:{}:{}: sending ACK",
-                               self.local_endpoint, self.remote_endpoint);
-                } else {
-                    // We don't have anything to send and we've already acknowledged everything.
-                    return Err(Error::Exhausted)
+                    should_send = true;
                 }
             }
 
             _ => unreachable!()
         }
 
-        match self.state {
-            // We don't have anything to acknowledge yet.
-            State::Closed | State::Listen | State::SynSent => (),
+        let ack_number = self.remote_seq_no + self.rx_buffer.len() as i32;
+        if !should_send && self.remote_last_ack != ack_number {
             // Acknowledge all data we have received, since it is all in order.
-            _ => {
-                self.remote_last_ack = ack_number;
-                repr.ack_number = Some(ack_number);
-            }
+            net_trace!("tcp:{}:{}: sending ACK",
+                       self.local_endpoint, self.remote_endpoint);
+            should_send = true;
         }
 
-        emit(&ip_repr, &repr)
+        if should_send {
+            repr.ack_number = Some(ack_number);
+            self.remote_last_ack = ack_number;
+
+            emit(&ip_repr, &repr)
+        } else {
+            Err(Error::Exhausted)
+        }
     }
 }
 
@@ -654,6 +651,7 @@ mod test {
     };
 
     fn send(socket: &mut TcpSocket, repr: &TcpRepr) -> Result<(), Error> {
+        trace!("send: {}", repr);
         let mut buffer = vec![0; repr.buffer_len()];
         let mut packet = TcpPacket::new(&mut buffer).unwrap();
         repr.emit(&mut packet, &REMOTE_IP, &LOCAL_IP);
@@ -677,6 +675,7 @@ mod test {
             payload.emit(&ip_repr, &mut buffer[..]);
             let packet = TcpPacket::new(&buffer[..]).unwrap();
             let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr()));
+            trace!("recv: {}", repr);
             Ok(f(Ok(repr)))
         });
         // Appease borrow checker.