Переглянути джерело

Use the correct wrapping operations on TCP sequence numbers.

whitequark 8 роки тому
батько
коміт
2c0c1ea76a
3 змінених файлів з 98 додано та 52 видалено
  1. 34 33
      src/socket/tcp.rs
  2. 1 0
      src/wire/mod.rs
  3. 63 19
      src/wire/tcp.rs

+ 34 - 33
src/socket/tcp.rs

@@ -3,7 +3,7 @@ use core::fmt;
 use Error;
 use Managed;
 use wire::{IpProtocol, IpAddress, IpEndpoint};
-use wire::{TcpPacket, TcpRepr, TcpControl};
+use wire::{TcpSeqNumber, TcpPacket, TcpRepr, TcpControl};
 use socket::{Socket, IpRepr, IpPayload};
 
 /// A TCP stream ring buffer.
@@ -185,16 +185,16 @@ pub struct TcpSocket<'a> {
     remote_endpoint: IpEndpoint,
     /// The sequence number corresponding to the beginning of the transmit buffer.
     /// I.e. an ACK(local_seq_no+n) packet removes n bytes from the transmit buffer.
-    local_seq_no:    i32,
+    local_seq_no:    TcpSeqNumber,
     /// The sequence number corresponding to the beginning of the receive buffer.
     /// I.e. userspace reading n bytes adds n to remote_seq_no.
-    remote_seq_no:   i32,
+    remote_seq_no:   TcpSeqNumber,
     /// The last sequence number sent.
     /// I.e. in an idle socket, local_seq_no+tx_buffer.len().
-    remote_last_seq: i32,
+    remote_last_seq: TcpSeqNumber,
     /// The last acknowledgement number sent.
     /// I.e. in an idle socket, remote_seq_no+rx_buffer.len().
-    remote_last_ack: i32,
+    remote_last_ack: TcpSeqNumber,
     /// The speculative remote window size.
     /// I.e. the actual remote window size minus the count of in-flight octets.
     remote_win_len:  usize,
@@ -218,10 +218,10 @@ impl<'a> TcpSocket<'a> {
             listen_address:  IpAddress::default(),
             local_endpoint:  IpEndpoint::default(),
             remote_endpoint: IpEndpoint::default(),
-            local_seq_no:    0,
-            remote_seq_no:   0,
-            remote_last_seq: 0,
-            remote_last_ack: 0,
+            local_seq_no:    TcpSeqNumber(0),
+            remote_seq_no:   TcpSeqNumber(0),
+            remote_last_seq: TcpSeqNumber(0),
+            remote_last_ack: TcpSeqNumber(0),
             remote_win_len:  0,
             retransmit:      Retransmit::new(),
             tx_buffer:       tx_buffer.into(),
@@ -341,7 +341,7 @@ impl<'a> TcpSocket<'a> {
         if !self.can_recv() { return Err(()) }
 
         let buffer = self.rx_buffer.dequeue(size);
-        self.remote_seq_no += buffer.len() as i32;
+        self.remote_seq_no += buffer.len();
         if buffer.len() > 0 {
             net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets",
                        self.local_endpoint, self.remote_endpoint, buffer.len());
@@ -450,9 +450,9 @@ impl<'a> TcpSocket<'a> {
                     // all of the control flags we sent.
                     _ => 0
                 };
-                let unacknowledged = self.tx_buffer.len() as i32 + control_len;
-                if !(ack_number - self.local_seq_no >= 0 &&
-                     ack_number - (self.local_seq_no + unacknowledged) <= 0) {
+                let unacknowledged = self.tx_buffer.len() + control_len;
+                if !(ack_number >= self.local_seq_no &&
+                     ack_number <= (self.local_seq_no + unacknowledged)) {
                     net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}..{})",
                                self.local_endpoint, self.remote_endpoint,
                                ack_number, self.local_seq_no, self.local_seq_no + unacknowledged);
@@ -468,13 +468,13 @@ impl<'a> TcpSocket<'a> {
             // In all other states, segments must occupy a valid portion of the receive window.
             // For now, do not try to reassemble out-of-order segments.
             (_, TcpRepr { seq_number, .. }) => {
-                let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32;
-                if seq_number - next_remote_seq > 0 {
+                let next_remote_seq = self.remote_seq_no + self.rx_buffer.len();
+                if seq_number > next_remote_seq {
                     net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
                                self.local_endpoint, self.remote_endpoint,
                                seq_number, next_remote_seq);
                     return Err(Error::Malformed)
-                } else if seq_number - next_remote_seq != 0 {
+                } else if seq_number != next_remote_seq {
                     net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
                                self.local_endpoint, self.remote_endpoint,
                                seq_number, next_remote_seq);
@@ -511,7 +511,8 @@ impl<'a> TcpSocket<'a> {
             }) => {
                 self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), dst_port);
                 self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port);
-                self.local_seq_no    = -seq_number; // FIXME: use something more secure
+                // FIXME: use something more secure here
+                self.local_seq_no    = TcpSeqNumber(-seq_number.0);
                 self.remote_last_seq = self.local_seq_no + 1;
                 self.remote_seq_no   = seq_number + 1;
                 self.set_state(State::SynReceived);
@@ -607,7 +608,7 @@ impl<'a> TcpSocket<'a> {
                 //   1. the retransmit timer has expired, or...
                 let mut may_send = self.retransmit.check();
                 //   2. we've got new data in the transmit buffer.
-                let remote_next_seq = self.local_seq_no + self.tx_buffer.len() as i32;
+                let remote_next_seq = self.local_seq_no + self.tx_buffer.len();
                 if self.remote_last_seq != remote_next_seq {
                     may_send = true;
                 }
@@ -627,9 +628,9 @@ impl<'a> TcpSocket<'a> {
                     repr.payload = data;
                     // Speculatively shrink the remote window. This will get updated the next
                     // time we receive a packet.
-                    self.remote_win_len -= data.len();
+                    self.remote_win_len  -= data.len();
                     // Advance the in-flight sequence number.
-                    self.remote_last_seq += data.len() as i32;
+                    self.remote_last_seq += data.len();
                     should_send = true;
                 }
             }
@@ -637,7 +638,7 @@ impl<'a> TcpSocket<'a> {
             _ => unreachable!()
         }
 
-        let ack_number = self.remote_seq_no + self.rx_buffer.len() as i32;
+        let ack_number = self.remote_seq_no + self.rx_buffer.len();
         if !should_send && self.remote_last_ack != ack_number {
             // Acknowledge all data we have received, since it is all in order.
             net_trace!("tcp:{}:{}: sending ACK",
@@ -692,25 +693,25 @@ mod test {
         buffer.enqueue_slice(&b"bazhoge"[..]);          // zhobarba
     }
 
-    const LOCAL_IP:     IpAddress  = IpAddress::v4(10, 0, 0, 1);
-    const REMOTE_IP:    IpAddress  = IpAddress::v4(10, 0, 0, 2);
-    const LOCAL_PORT:   u16        = 80;
-    const REMOTE_PORT:  u16        = 49500;
-    const LOCAL_END:    IpEndpoint = IpEndpoint::new(LOCAL_IP, LOCAL_PORT);
-    const REMOTE_END:   IpEndpoint = IpEndpoint::new(REMOTE_IP, REMOTE_PORT);
-    const LOCAL_SEQ:    i32        = 10000;
-    const REMOTE_SEQ:   i32        = -10000;
+    const LOCAL_IP:     IpAddress    = IpAddress::v4(10, 0, 0, 1);
+    const REMOTE_IP:    IpAddress    = IpAddress::v4(10, 0, 0, 2);
+    const LOCAL_PORT:   u16          = 80;
+    const REMOTE_PORT:  u16          = 49500;
+    const LOCAL_END:    IpEndpoint   = IpEndpoint::new(LOCAL_IP, LOCAL_PORT);
+    const REMOTE_END:   IpEndpoint   = IpEndpoint::new(REMOTE_IP, REMOTE_PORT);
+    const LOCAL_SEQ:    TcpSeqNumber = TcpSeqNumber(10000);
+    const REMOTE_SEQ:   TcpSeqNumber = TcpSeqNumber(-10000);
 
     const SEND_TEMPL: TcpRepr<'static> = TcpRepr {
         src_port: REMOTE_PORT, dst_port: LOCAL_PORT,
         control: TcpControl::None,
-        seq_number: 0, ack_number: Some(0),
+        seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)),
         window_len: 256, payload: &[]
     };
     const RECV_TEMPL:  TcpRepr<'static> = TcpRepr {
         src_port: LOCAL_PORT, dst_port: REMOTE_PORT,
         control: TcpControl::None,
-        seq_number: 0, ack_number: Some(0),
+        seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)),
         window_len: 64, payload: &[]
     };
 
@@ -917,7 +918,7 @@ mod test {
         send!(s, TcpRepr {
             control: TcpControl::Rst,
             seq_number: REMOTE_SEQ,
-            ack_number: Some(1234),
+            ack_number: Some(TcpSeqNumber(1234)),
             ..SEND_TEMPL
         }, Err(Error::Malformed));
         assert_eq!(s.state, State::SynSent);
@@ -1005,7 +1006,7 @@ mod test {
         // Already acknowledged data.
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
-            ack_number: Some(LOCAL_SEQ - 1),
+            ack_number: Some(TcpSeqNumber(LOCAL_SEQ.0 - 1)),
             ..SEND_TEMPL
         }, Err(Error::Malformed));
         assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);

+ 1 - 0
src/wire/mod.rs

@@ -119,6 +119,7 @@ pub use self::icmpv4::Repr as Icmpv4Repr;
 pub use self::udp::Packet as UdpPacket;
 pub use self::udp::Repr as UdpRepr;
 
+pub use self::tcp::SeqNumber as TcpSeqNumber;
 pub use self::tcp::Packet as TcpPacket;
 pub use self::tcp::Repr as TcpRepr;
 pub use self::tcp::Control as TcpControl;

+ 63 - 19
src/wire/tcp.rs

@@ -1,10 +1,54 @@
-use core::fmt;
+use core::{i32, ops, cmp, fmt};
 use byteorder::{ByteOrder, NetworkEndian};
 
 use Error;
 use super::{IpProtocol, IpAddress};
 use super::ip::checksum;
 
+/// A TCP sequence number.
+///
+/// A sequence number is a monotonically advancing integer modulo 2<sup>32</sup>.
+/// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub struct SeqNumber(pub i32);
+
+impl fmt::Display for SeqNumber {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", self.0 as u32)
+    }
+}
+
+impl ops::Add<usize> for SeqNumber {
+    type Output = SeqNumber;
+
+    fn add(self, rhs: usize) -> SeqNumber {
+        if rhs > i32::MAX as usize {
+            panic!("attempt to add to sequence number with unsigned overflow")
+        }
+        SeqNumber(self.0.wrapping_add(rhs as i32))
+    }
+}
+
+impl ops::AddAssign<usize> for SeqNumber {
+    fn add_assign(&mut self, rhs: usize) {
+        *self = *self + rhs;
+    }
+}
+
+impl ops::Sub for SeqNumber {
+    type Output = usize;
+
+    fn sub(self, rhs: SeqNumber) -> usize {
+        (self.0 - rhs.0) as usize
+    }
+}
+
+impl cmp::PartialOrd for SeqNumber {
+    fn partial_cmp(&self, other: &SeqNumber) -> Option<cmp::Ordering> {
+        (self.0 - other.0).partial_cmp(&0)
+    }
+}
+
 /// A read/write wrapper around an Transmission Control Protocol packet buffer.
 #[derive(Debug)]
 pub struct Packet<T: AsRef<[u8]>> {
@@ -69,16 +113,16 @@ impl<T: AsRef<[u8]>> Packet<T> {
 
     /// Return the sequence number field.
     #[inline(always)]
-    pub fn seq_number(&self) -> i32 {
+    pub fn seq_number(&self) -> SeqNumber {
         let data = self.buffer.as_ref();
-        NetworkEndian::read_i32(&data[field::SEQ_NUM])
+        SeqNumber(NetworkEndian::read_i32(&data[field::SEQ_NUM]))
     }
 
     /// Return the acknowledgement number field.
     #[inline(always)]
-    pub fn ack_number(&self) -> i32 {
+    pub fn ack_number(&self) -> SeqNumber {
         let data = self.buffer.as_ref();
-        NetworkEndian::read_i32(&data[field::ACK_NUM])
+        SeqNumber(NetworkEndian::read_i32(&data[field::ACK_NUM]))
     }
 
     /// Return the FIN flag.
@@ -184,12 +228,12 @@ impl<T: AsRef<[u8]>> Packet<T> {
 
     /// Return the length of the segment, in terms of sequence space.
     #[inline(always)]
-    pub fn segment_len(&self) -> i32 {
+    pub fn segment_len(&self) -> usize {
         let data = self.buffer.as_ref();
         let mut length = data.len() - self.header_len() as usize;
         if self.syn() { length += 1 }
         if self.fin() { length += 1 }
-        length as i32
+        length
     }
 
     /// Validate the packet checksum.
@@ -234,16 +278,16 @@ impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
 
     /// Set the sequence number field.
     #[inline(always)]
-    pub fn set_seq_number(&mut self, value: i32) {
+    pub fn set_seq_number(&mut self, value: SeqNumber) {
         let mut data = self.buffer.as_mut();
-        NetworkEndian::write_i32(&mut data[field::SEQ_NUM], value)
+        NetworkEndian::write_i32(&mut data[field::SEQ_NUM], value.0)
     }
 
     /// Set the acknowledgement number field.
     #[inline(always)]
-    pub fn set_ack_number(&mut self, value: i32) {
+    pub fn set_ack_number(&mut self, value: SeqNumber) {
         let mut data = self.buffer.as_mut();
-        NetworkEndian::write_i32(&mut data[field::ACK_NUM], value)
+        NetworkEndian::write_i32(&mut data[field::ACK_NUM], value.0)
     }
 
     /// Clear the entire flags field.
@@ -422,8 +466,8 @@ pub struct Repr<'a> {
     pub src_port:   u16,
     pub dst_port:   u16,
     pub control:    Control,
-    pub seq_number: i32,
-    pub ack_number: Option<i32>,
+    pub seq_number: SeqNumber,
+    pub ack_number: Option<SeqNumber>,
     pub window_len: u16,
     pub payload:    &'a [u8]
 }
@@ -482,7 +526,7 @@ impl<'a> Repr<'a> {
         packet.set_src_port(self.src_port);
         packet.set_dst_port(self.dst_port);
         packet.set_seq_number(self.seq_number);
-        packet.set_ack_number(self.ack_number.unwrap_or(0));
+        packet.set_ack_number(self.ack_number.unwrap_or(SeqNumber(0)));
         packet.set_window_len(self.window_len);
         packet.set_header_len(field::URGENT.end as u8);
         packet.clear_flags();
@@ -579,8 +623,8 @@ mod test {
         let packet = Packet::new(&PACKET_BYTES[..]).unwrap();
         assert_eq!(packet.src_port(), 48896);
         assert_eq!(packet.dst_port(), 80);
-        assert_eq!(packet.seq_number(), 0x01234567);
-        assert_eq!(packet.ack_number(), 0x89abcdefu32 as i32);
+        assert_eq!(packet.seq_number(), SeqNumber(0x01234567));
+        assert_eq!(packet.ack_number(), SeqNumber(0x89abcdefu32 as i32));
         assert_eq!(packet.header_len(), 20);
         assert_eq!(packet.fin(), true);
         assert_eq!(packet.syn(), false);
@@ -601,8 +645,8 @@ mod test {
         let mut packet = Packet::new(&mut bytes).unwrap();
         packet.set_src_port(48896);
         packet.set_dst_port(80);
-        packet.set_seq_number(0x01234567);
-        packet.set_ack_number(0x89abcdefu32 as i32);
+        packet.set_seq_number(SeqNumber(0x01234567));
+        packet.set_ack_number(SeqNumber(0x89abcdefu32 as i32));
         packet.set_header_len(20);
         packet.set_fin(true);
         packet.set_syn(false);
@@ -630,7 +674,7 @@ mod test {
         Repr {
             src_port:   48896,
             dst_port:   80,
-            seq_number: 0x01234567,
+            seq_number: SeqNumber(0x01234567),
             ack_number: None,
             window_len: 0x0123,
             control:    Control::Syn,