Browse Source

More rigorously treat the TcpSocket::remote_last_ack field.

Zero is a valid sequence number, treating it as an absence of value
isn't any good. This is unlikely to cause any real harm but it
just isn't good practice, nor does it make for understandable code.
whitequark 7 years ago
parent
commit
86a05f13f5
1 changed files with 16 additions and 13 deletions
  1. 16 13
      src/socket/tcp.rs

+ 16 - 13
src/socket/tcp.rs

@@ -272,7 +272,7 @@ pub struct TcpSocket<'a> {
     remote_last_seq: TcpSeqNumber,
     /// The last acknowledgement number sent.
     /// I.e. in an idle socket, remote_seq_no+rx_buffer.len().
-    remote_last_ack: TcpSeqNumber,
+    remote_last_ack: Option<TcpSeqNumber>,
     /// The speculative remote window size.
     /// I.e. the actual remote window size minus the count of in-flight octets.
     remote_win_len:  usize,
@@ -304,7 +304,7 @@ impl<'a> TcpSocket<'a> {
             local_seq_no:    TcpSeqNumber::default(),
             remote_seq_no:   TcpSeqNumber::default(),
             remote_last_seq: TcpSeqNumber::default(),
-            remote_last_ack: TcpSeqNumber::default(),
+            remote_last_ack: None,
             remote_win_len:  0,
             remote_mss:      DEFAULT_MSS,
         })
@@ -350,7 +350,7 @@ impl<'a> TcpSocket<'a> {
         self.local_seq_no    = TcpSeqNumber::default();
         self.remote_seq_no   = TcpSeqNumber::default();
         self.remote_last_seq = TcpSeqNumber::default();
-        self.remote_last_ack = TcpSeqNumber::default();
+        self.remote_last_ack = None;
         self.remote_win_len  = 0;
         self.remote_mss      = DEFAULT_MSS;
         self.timer.reset();
@@ -739,7 +739,7 @@ impl<'a> TcpSocket<'a> {
         // and an acknowledgment indicating the next sequence number expected
         // to be received.
         reply_repr.seq_number = self.remote_last_seq;
-        reply_repr.ack_number = Some(self.remote_last_ack);
+        reply_repr.ack_number = self.remote_last_ack;
         reply_repr.window_len = self.rx_buffer.window() as u16;
 
         (ip_reply_repr, reply_repr)
@@ -974,7 +974,6 @@ impl<'a> TcpSocket<'a> {
                 self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
                 self.remote_seq_no   = repr.seq_number + 1;
                 self.remote_last_seq = self.local_seq_no + 1;
-                self.remote_last_ack = repr.seq_number;
                 if let Some(max_seg_size) = repr.max_seg_size {
                     self.remote_mss = max_seg_size as usize;
                 }
@@ -1077,7 +1076,7 @@ impl<'a> TcpSocket<'a> {
             self.rx_buffer.enqueue_slice(repr.payload);
 
             // Send an acknowledgement.
-            self.remote_last_ack = self.remote_seq_no + self.rx_buffer.len();
+            self.remote_last_ack = Some(self.remote_seq_no + self.rx_buffer.len());
             Ok(Some(self.ack_reply(ip_repr, &repr)))
         } else {
             // No data to acknowledge; the logic to acknowledge SYN and FIN flags
@@ -1091,7 +1090,11 @@ impl<'a> TcpSocket<'a> {
     }
 
     fn ack_to_transmit(&self) -> bool {
-        self.remote_last_ack < self.remote_seq_no + self.rx_buffer.len()
+        if let Some(remote_last_ack) = self.remote_last_ack {
+            remote_last_ack < self.remote_seq_no + self.rx_buffer.len()
+        } else {
+            true
+        }
     }
 
     pub(crate) fn dispatch<F>(&mut self, timestamp: u64, limits: &DeviceLimits,
@@ -1182,7 +1185,7 @@ impl<'a> TcpSocket<'a> {
 
         if self.seq_to_transmit(repr.control) && repr.segment_len() > 0 {
             // If we have data to transmit and it fits into partner's window, do it.
-        } else if self.ack_to_transmit() {
+        } else if self.ack_to_transmit() && repr.ack_number.is_some() {
             // If we have data to acknowledge, do it.
         } else if self.timer.should_retransmit(timestamp).is_some() {
             // If we have packets to retransmit, do it.
@@ -1248,7 +1251,7 @@ impl<'a> TcpSocket<'a> {
 
         // We've sent a packet successfully, so we can update the internal state now.
         self.remote_last_seq = repr.seq_number + repr.segment_len();
-        self.remote_last_ack = repr.ack_number.unwrap_or_default();
+        self.remote_last_ack = repr.ack_number;
 
         if !self.seq_to_transmit(repr.control) && repr.segment_len() > 0 {
             // If we've transmitted all data could (and there was something at all,
@@ -1648,7 +1651,7 @@ mod test {
         })));
         assert_eq!(s.state, State::CloseWait);
         sanity!(s, TcpSocket {
-            remote_last_ack: REMOTE_SEQ + 1 + 6 + 1,
+            remote_last_ack: Some(REMOTE_SEQ + 1 + 6 + 1),
             ..socket_close_wait()
         });
     }
@@ -1843,7 +1846,7 @@ mod test {
         s.state           = State::Established;
         s.local_seq_no    = LOCAL_SEQ + 1;
         s.remote_last_seq = LOCAL_SEQ + 1;
-        s.remote_last_ack = REMOTE_SEQ + 1;
+        s.remote_last_ack = Some(REMOTE_SEQ + 1);
         s
     }
 
@@ -2214,7 +2217,7 @@ mod test {
         s.state           = State::TimeWait;
         s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
         if from_closing {
-            s.remote_last_ack = REMOTE_SEQ + 1 + 1;
+            s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1);
         }
         s.timer           = Timer::Close { expires_at: 1_000 + CLOSE_DELAY };
         s
@@ -2284,7 +2287,7 @@ mod test {
         let mut s = socket_established();
         s.state           = State::CloseWait;
         s.remote_seq_no   = REMOTE_SEQ + 1 + 1;
-        s.remote_last_ack = REMOTE_SEQ + 1 + 1;
+        s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1);
         s
     }