Browse Source

Implement TCP data reception.

whitequark 8 years ago
parent
commit
843b79bff2
3 changed files with 185 additions and 76 deletions
  1. 1 1
      README.md
  2. 156 59
      src/socket/tcp.rs
  3. 28 16
      src/wire/tcp.rs

+ 1 - 1
README.md

@@ -47,7 +47,7 @@ The TCP protocol is supported over IPv4.
 
   * TCP header checksum is supported.
   * TCP options are **not** supported.
-  * TCP SYN packets with a payload are **not** supported.
+  * Reassembly of out-of-order segments is **not** supported.
 
 Installation
 ------------

+ 156 - 59
src/socket/tcp.rs

@@ -25,27 +25,19 @@ impl<'a> SocketBuffer<'a> {
         }
     }
 
-    /// Return the maximum amount of octets that can be enqueued in the buffer.
-    pub fn capacity(&self) -> usize {
+    fn capacity(&self) -> usize {
         self.storage.len()
     }
 
-    /// Return the amount of octets already enqueued in the buffer.
-    pub fn len(&self) -> usize {
+    fn len(&self) -> usize {
         self.length
     }
 
-    /// Return the amount of octets that remain to be enqueued in the buffer.
-    pub fn window(&self) -> usize {
+    fn window(&self) -> usize {
         self.capacity() - self.len()
     }
 
-    /// Enqueue a slice of octets up to the given size into the buffer, and return a pointer
-    /// to the slice.
-    ///
-    /// The returned slice may be shorter than requested, as short as an empty slice,
-    /// if there is not enough contiguous free space in the buffer.
-    pub fn enqueue(&mut self, mut size: usize) -> &mut [u8] {
+    fn clamp_writer(&self, mut size: usize) -> (usize, usize) {
         let write_at = (self.read_at + self.length) % self.storage.len();
         // We can't enqueue more than there is free space.
         let free = self.storage.len() - self.length;
@@ -54,23 +46,52 @@ impl<'a> SocketBuffer<'a> {
         let until_end = self.storage.len() - write_at;
         if size > until_end { size = until_end }
 
+        (write_at, size)
+    }
+
+    fn enqueue(&mut self, size: usize) -> &mut [u8] {
+        let (write_at, size) = self.clamp_writer(size);
         self.length += size;
         &mut self.storage[write_at..write_at + size]
     }
 
-    /// Dequeue a slice of octets up to the given size from the buffer, and return a pointer
-    /// to the slice.
-    ///
-    /// The returned slice may be shorter than requested, as short as an empty slice,
-    /// if there is not enough contiguous filled space in the buffer.
-    pub fn dequeue(&mut self, mut size: usize) -> &[u8] {
+    fn enqueue_slice(&mut self, data: &[u8]) {
+        let data = {
+            let mut dest = self.enqueue(data.len());
+            let (data, rest) = data.split_at(dest.len());
+            dest.copy_from_slice(data);
+            rest
+        };
+        // Retry, in case we had a wraparound.
+        let mut dest = self.enqueue(data.len());
+        let (data, _) = data.split_at(dest.len());
+        dest.copy_from_slice(data);
+    }
+
+    fn clamp_reader(&self, mut size: usize) -> (usize, usize) {
         let read_at = self.read_at;
         // We can't dequeue more than was queued.
         if size > self.length { size = self.length }
         // We can't contiguously dequeue past the end of the storage.
-        let until_end = self.storage.len() - self.read_at;
+        let until_end = self.storage.len() - read_at;
         if size > until_end { size = until_end }
 
+        (read_at, size)
+    }
+
+    fn peek(&self, size: usize) -> &[u8] {
+        let (read_at, size) = self.clamp_reader(size);
+        &self.storage[read_at..read_at + size]
+    }
+
+    fn advance(&mut self, size: usize) {
+        let (read_at, size) = self.clamp_reader(size);
+        self.read_at = (read_at + size) % self.storage.len();
+        self.length -= size;
+    }
+
+    fn dequeue(&mut self, size: usize) -> &[u8] {
+        let (read_at, size) = self.clamp_reader(size);
         self.read_at = (self.read_at + size) % self.storage.len();
         self.length -= size;
         &self.storage[read_at..read_at + size]
@@ -145,6 +166,7 @@ pub struct TcpSocket<'a> {
     remote_endpoint: IpEndpoint,
     local_seq_no:    i32,
     remote_seq_no:   i32,
+    remote_win_len:  usize,
     retransmit:      Retransmit,
     rx_buffer:       SocketBuffer<'a>,
     tx_buffer:       SocketBuffer<'a>
@@ -166,6 +188,7 @@ impl<'a> TcpSocket<'a> {
             remote_endpoint: IpEndpoint::default(),
             local_seq_no:    0,
             remote_seq_no:   0,
+            remote_win_len:  0,
             retransmit:      Retransmit::new(),
             tx_buffer:       tx_buffer.into(),
             rx_buffer:       rx_buffer.into()
@@ -235,27 +258,30 @@ impl<'a> TcpSocket<'a> {
         if !self.remote_endpoint.addr.is_unspecified() &&
            self.remote_endpoint.addr != *src_addr { return Err(Error::Rejected) }
 
+        // Reject packets addressed to a closed socket.
+        if self.state == State::Closed {
+            net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
+                       self.local_endpoint, src_addr, repr.src_port);
+            return Err(Error::Malformed)
+        }
+
+        // Reject unacceptable acknowledgements.
         match (self.state, repr) {
-            // Reject packets addressed to a closed socket.
-            (State::Closed, TcpRepr { src_port, .. }) => {
-                net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
-                           self.local_endpoint, src_addr, src_port);
-                return Err(Error::Malformed)
-            }
             // Don't care about ACKs when performing the handshake.
             (State::Listen, _) => (),
             (State::SynSent, _) => (),
             // Every packet after the initial SYN must be an acknowledgement.
             (_, TcpRepr { ack_number: None, .. }) => {
-                net_trace!("tcp:{}:{}: expecting an ACK packet",
+                net_trace!("tcp:{}:{}: expecting an ACK",
                            self.local_endpoint, self.remote_endpoint);
                 return Err(Error::Malformed)
             }
-            // Reject unacceptable acknowledgements.
+            // Every acknowledgement must be for transmitted but unacknowledged data.
             (state, TcpRepr { ack_number: Some(ack_number), .. }) => {
-                let unacknowledged =
-                    if state != State::SynReceived { self.rx_buffer.len() as i32 } else { 1 };
-                if !(ack_number - self.local_seq_no > 0 &&
+                let control_len =
+                    if state == State::SynReceived { 1 } else { 0 };
+                let unacknowledged = self.tx_buffer.len() as i32 + control_len;
+                if !(ack_number - self.local_seq_no >= 0 &&
                      ack_number - (self.local_seq_no + unacknowledged) <= 0) {
                     net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}..{})",
                                self.local_endpoint, self.remote_endpoint,
@@ -265,43 +291,80 @@ impl<'a> TcpSocket<'a> {
             }
         }
 
-        // Handle the incoming packet.
+        // Reject segments not occupying a valid portion of the receive window.
+        // For now, do not try to reassemble out-of-order segments.
+        if self.state != State::Listen {
+            let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32 +
+                                  repr.control.len();
+            if repr.seq_number - next_remote_seq > 0 {
+                net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
+                           self.local_endpoint, self.remote_endpoint,
+                           repr.seq_number, next_remote_seq);
+                return Err(Error::Malformed)
+            } else if repr.seq_number - next_remote_seq != 0 {
+                net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
+                           self.local_endpoint, self.remote_endpoint,
+                           repr.seq_number, next_remote_seq);
+                return Ok(())
+            }
+        }
+
+        // Validate and update the state.
+        let old_state = self.state;
         match (self.state, repr) {
             (State::Listen, TcpRepr {
-                src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None,
-                payload, ..
+                src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
             }) => {
-                // FIXME: don't do this, just enqueue the payload
-                if payload.len() > 0 {
-                    net_trace!("tcp:{}:{}: SYN with payload rejected",
-                               IpEndpoint::new(*dst_addr, dst_port),
-                               IpEndpoint::new(*src_addr, src_port));
-                    return Err(Error::Malformed)
-                }
-
                 self.local_endpoint  = IpEndpoint::new(*dst_addr, dst_port);
                 self.remote_endpoint = IpEndpoint::new(*src_addr, src_port);
-                self.remote_seq_no   = seq_number + 1;
                 self.local_seq_no    = -seq_number; // FIXME: use something more secure
+                self.remote_seq_no   = seq_number + 1;
                 self.set_state(State::SynReceived);
-
-                self.retransmit.reset();
-                Ok(())
+                self.retransmit.reset()
             }
 
-            (State::SynReceived, TcpRepr {
-                control: TcpControl::None, ack_number: Some(ack_number), ..
-            }) => {
-                self.local_seq_no    = ack_number;
+            (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => {
                 self.set_state(State::Established);
+                self.retransmit.reset()
+            }
 
-                // FIXME: queue data from ACK
-                self.retransmit.reset();
-                Ok(())
+            (State::Established, TcpRepr { control: TcpControl::None, .. }) => (),
+
+            _ => {
+                net_trace!("tcp:{}:{}: unexpected packet {}",
+                           self.local_endpoint, self.remote_endpoint, repr);
+                return Err(Error::Malformed)
             }
+        }
 
-            _ => Err(Error::Malformed)
+        // Dequeue acknowledged octets.
+        if let Some(ack_number) = repr.ack_number {
+            let control_len =
+                if old_state == State::SynReceived { 1 } else { 0 };
+            if control_len > 0 {
+                net_trace!("tcp:{}:{}: ACK for a control flag",
+                           self.local_endpoint, self.remote_endpoint);
+            }
+            if ack_number - self.local_seq_no - control_len > 0 {
+                net_trace!("tcp:{}:{}: ACK for {} octets",
+                           self.local_endpoint, self.remote_endpoint,
+                           ack_number - self.local_seq_no - control_len);
+            }
+            self.tx_buffer.advance((ack_number - self.local_seq_no - control_len) as usize);
+            self.local_seq_no = ack_number;
+        }
+
+        // Enqueue payload octets, which is guaranteed to be in order, unless we already did.
+        if repr.payload.len() > 0 {
+            net_trace!("tcp:{}:{}: receiving {} octets",
+                       self.local_endpoint, self.remote_endpoint, repr.payload.len());
+            self.rx_buffer.enqueue_slice(repr.payload)
         }
+
+        // Update window length.
+        self.remote_win_len = repr.window_len as usize;
+
+        Ok(())
     }
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
@@ -374,14 +437,22 @@ mod test {
         buffer.enqueue(8).copy_from_slice(b"gefug");    // ...gefug
     }
 
+    #[test]
+    fn test_buffer_wraparound() {
+        let mut buffer = SocketBuffer::new(vec![0; 8]); // ........
+        buffer.enqueue_slice(&b"foobar"[..]);           // foobar..
+        assert_eq!(buffer.dequeue(3), b"foo");          // ...bar..
+        buffer.enqueue_slice(&b"bazhoge"[..]);          // zhobarba
+    }
+
     const LOCAL_IP:     IpAddress  = IpAddress::v4(10, 0, 0, 1);
     const REMOTE_IP:    IpAddress  = IpAddress::v4(10, 0, 0, 2);
     const LOCAL_PORT:   u16        = 80;
     const REMOTE_PORT:  u16        = 49500;
     const LOCAL_END:    IpEndpoint = IpEndpoint::new(LOCAL_IP, LOCAL_PORT);
     const REMOTE_END:   IpEndpoint = IpEndpoint::new(REMOTE_IP, REMOTE_PORT);
-    const LOCAL_SEQ:    i32        = 100;
-    const REMOTE_SEQ:   i32        = -100;
+    const LOCAL_SEQ:    i32        = 10000;
+    const REMOTE_SEQ:   i32        = -10000;
 
     const SEND_TEMPL: TcpRepr<'static> = TcpRepr {
         src_port: REMOTE_PORT, dst_port: LOCAL_PORT,
@@ -473,7 +544,8 @@ mod test {
 
         send!(s, TcpRepr {
             control: TcpControl::Syn,
-            seq_number: REMOTE_SEQ, ack_number: None,
+            seq_number: REMOTE_SEQ,
+            ack_number: None,
             ..SEND_TEMPL
         });
         assert_eq!(s.state(), State::SynReceived);
@@ -481,16 +553,41 @@ mod test {
         assert_eq!(s.remote_endpoint(), REMOTE_END);
         recv!(s, TcpRepr {
             control: TcpControl::Syn,
-            seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1),
+            seq_number: LOCAL_SEQ,
+            ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
         });
         send!(s, TcpRepr {
-            control: TcpControl::None,
-            seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1),
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
             ..SEND_TEMPL
         });
         assert_eq!(s.state(), State::Established);
         assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
         assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
     }
+
+    #[test]
+    fn test_recv_data() {
+        let mut s = socket();
+        s.state = State::Established;
+        s.local_endpoint  = LOCAL_END;
+        s.remote_endpoint = REMOTE_END;
+        s.local_seq_no    = LOCAL_SEQ + 1;
+        s.remote_seq_no   = REMOTE_SEQ + 1;
+
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload: &b"abcdef"[..],
+            ..SEND_TEMPL
+        });
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1 + 6,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload: &b"foo"[..],
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.rx_buffer.dequeue(9), &b"abcdeffoo"[..]);
+    }
 }

+ 28 - 16
src/wire/tcp.rs

@@ -187,8 +187,8 @@ impl<T: AsRef<[u8]>> Packet<T> {
     pub fn segment_len(&self) -> i32 {
         let data = self.buffer.as_ref();
         let mut length = data.len() - self.header_len() as usize;
-        if self.syn() { length += 1}
-        if self.fin() { length += 1}
+        if self.syn() { length += 1 }
+        if self.fin() { length += 1 }
         length as i32
     }
 
@@ -395,6 +395,27 @@ impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> {
     }
 }
 
+/// The control flags of a Transmission Control Protocol packet.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum Control {
+    None,
+    Syn,
+    Fin,
+    Rst
+}
+
+impl Control {
+    /// Return the length of the control flag, in terms of sequence space.
+    pub fn len(self) -> i32 {
+        match self {
+            Control::None => 0,
+            Control::Syn  => 1,
+            Control::Fin  => 1,
+            Control::Rst  => 0
+        }
+    }
+}
+
 /// A high-level representation of a Transmission Control Protocol packet.
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 pub struct Repr<'a> {
@@ -407,15 +428,6 @@ pub struct Repr<'a> {
     pub payload:    &'a [u8]
 }
 
-/// The control flags of a Transmission Control Protocol packet.
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
-pub enum Control {
-    None,
-    Syn,
-    Fin,
-    Rst
-}
-
 impl<'a> Repr<'a> {
     /// Parse a Transmission Control Protocol packet and return a high-level representation.
     pub fn parse<T: ?Sized>(packet: &Packet<&'a T>,
@@ -498,9 +510,9 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
         if self.ece() { try!(write!(f, " ece")) }
         if self.cwr() { try!(write!(f, " cwr")) }
         if self.ns()  { try!(write!(f, " ns" )) }
-        try!(write!(f, " seq={}", self.seq_number() as u32));
+        try!(write!(f, " seq={}", self.seq_number()));
         if self.ack() {
-            try!(write!(f, " ack={}", self.ack_number() as u32));
+            try!(write!(f, " ack={}", self.ack_number()));
         }
         try!(write!(f, " win={}", self.window_len()));
         if self.urg() {
@@ -513,7 +525,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
 
 impl<'a> fmt::Display for Repr<'a> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        try!(write!(f, "TCP src={} dst={} ",
+        try!(write!(f, "TCP src={} dst={}",
                     self.src_port, self.dst_port));
         match self.control {
             Control::Syn => try!(write!(f, " syn")),
@@ -521,9 +533,9 @@ impl<'a> fmt::Display for Repr<'a> {
             Control::Rst => try!(write!(f, " rst")),
             Control::None => ()
         }
-        try!(write!(f, " seq={}", self.seq_number as u32));
+        try!(write!(f, " seq={}", self.seq_number));
         if let Some(ack_number) = self.ack_number {
-            try!(write!(f, " ack={}", ack_number as u32));
+            try!(write!(f, " ack={}", ack_number));
         }
         try!(write!(f, " win={}", self.window_len));
         try!(write!(f, " len={}", self.payload.len()));