Browse Source

Validate TCP ACKs.

whitequark 8 years ago
parent
commit
a8309b7dff
3 changed files with 116 additions and 60 deletions
  1. 1 0
      README.md
  2. 16 2
      src/iface/ethernet.rs
  3. 99 58
      src/socket/tcp.rs

+ 1 - 0
README.md

@@ -47,6 +47,7 @@ The TCP protocol is supported over IPv4.
 
   * TCP header checksum is supported.
   * TCP options are **not** supported.
+  * TCP SYN packets with a payload are **not** supported.
 
 Installation
 ------------

+ 16 - 2
src/iface/ethernet.rs

@@ -215,8 +215,22 @@ impl<'a, 'b: 'a,
                         for socket in self.sockets.borrow_mut() {
                             match socket.collect(&src_addr.into(), &dst_addr.into(),
                                                  protocol, ip_packet.payload()) {
-                                Ok(()) => { handled = true; break }
-                                Err(Error::Rejected) => continue,
+                                Ok(()) => {
+                                    // The packet was valid and handled by socket.
+                                    handled = true;
+                                    break
+                                }
+                                Err(Error::Rejected) => {
+                                    // The packet wasn't addressed to the socket.
+                                    // For TCP, send RST only if no other socket accepts
+                                    // the packet.
+                                    continue
+                                }
+                                Err(Error::Malformed) => {
+                                    // The packet was addressed to the socket but is malformed.
+                                    // For TCP, send RST immediately.
+                                    break
+                                }
                                 Err(e) => return Err(e)
                             }
                         }

+ 99 - 58
src/socket/tcp.rs

@@ -25,14 +25,19 @@ impl<'a> SocketBuffer<'a> {
         }
     }
 
-    /// Return the amount of octets enqueued in the buffer.
+    /// Return the maximum amount of octets that can be enqueued in the buffer.
+    pub fn capacity(&self) -> usize {
+        self.storage.len()
+    }
+
+    /// Return the amount of octets already enqueued in the buffer.
     pub fn len(&self) -> usize {
         self.length
     }
 
-    /// Return the maximum amount of octets that can be enqueued in the buffer.
-    pub fn capacity(&self) -> usize {
-        self.storage.len()
+    /// Return the amount of octets that remain to be enqueued in the buffer.
+    pub fn window(&self) -> usize {
+        self.capacity() - self.len()
     }
 
     /// Enqueue a slice of octets up to the given size into the buffer, and return a pointer
@@ -135,14 +140,14 @@ impl Retransmit {
 /// A Transmission Control Protocol data stream.
 #[derive(Debug)]
 pub struct TcpSocket<'a> {
-    state:         State,
-    local_end:     IpEndpoint,
-    remote_end:    IpEndpoint,
-    local_seq_no:  i32,
-    remote_seq_no: i32,
-    retransmit:    Retransmit,
-    rx_buffer:     SocketBuffer<'a>,
-    tx_buffer:     SocketBuffer<'a>
+    state:           State,
+    local_endpoint:  IpEndpoint,
+    remote_endpoint: IpEndpoint,
+    local_seq_no:    i32,
+    remote_seq_no:   i32,
+    retransmit:      Retransmit,
+    rx_buffer:       SocketBuffer<'a>,
+    tx_buffer:       SocketBuffer<'a>
 }
 
 impl<'a> TcpSocket<'a> {
@@ -156,14 +161,14 @@ impl<'a> TcpSocket<'a> {
         }
 
         Socket::Tcp(TcpSocket {
-            state:         State::Closed,
-            local_end:     IpEndpoint::default(),
-            remote_end:    IpEndpoint::default(),
-            local_seq_no:  0,
-            remote_seq_no: 0,
-            retransmit:    Retransmit::new(),
-            tx_buffer:     tx_buffer.into(),
-            rx_buffer:     rx_buffer.into()
+            state:           State::Closed,
+            local_endpoint:  IpEndpoint::default(),
+            remote_endpoint: IpEndpoint::default(),
+            local_seq_no:    0,
+            remote_seq_no:   0,
+            retransmit:      Retransmit::new(),
+            tx_buffer:       tx_buffer.into(),
+            rx_buffer:       rx_buffer.into()
         })
     }
 
@@ -176,23 +181,23 @@ impl<'a> TcpSocket<'a> {
     /// Return the local endpoint.
     #[inline(always)]
     pub fn local_endpoint(&self) -> IpEndpoint {
-        self.local_end
+        self.local_endpoint
     }
 
     /// Return the remote endpoint.
     #[inline(always)]
     pub fn remote_endpoint(&self) -> IpEndpoint {
-        self.remote_end
+        self.remote_endpoint
     }
 
     fn set_state(&mut self, state: State) {
         if self.state != state {
-            if self.remote_end.addr.is_unspecified() {
+            if self.remote_endpoint.addr.is_unspecified() {
                 net_trace!("tcp:{}: state={}→{}",
-                           self.local_end, self.state, state);
+                           self.local_endpoint, self.state, state);
             } else {
                 net_trace!("tcp:{}:{}: state={}→{}",
-                           self.local_end, self.remote_end, self.state, state);
+                           self.local_endpoint, self.remote_endpoint, self.state, state);
             }
         }
         self.state = state
@@ -205,8 +210,8 @@ impl<'a> TcpSocket<'a> {
     pub fn listen(&mut self, endpoint: IpEndpoint) {
         assert!(self.state == State::Closed);
 
-        self.local_end  = endpoint;
-        self.remote_end = IpEndpoint::default();
+        self.local_endpoint  = endpoint;
+        self.remote_endpoint = IpEndpoint::default();
         self.set_state(State::Listen);
     }
 
@@ -219,29 +224,67 @@ impl<'a> TcpSocket<'a> {
         let packet = try!(TcpPacket::new(payload));
         let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr));
 
-        if self.local_end.port != repr.dst_port { return Err(Error::Rejected) }
-        if !self.local_end.addr.is_unspecified() &&
-           self.local_end.addr != *dst_addr { return Err(Error::Rejected) }
+        // Reject packets with a wrong destination.
+        if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) }
+        if !self.local_endpoint.addr.is_unspecified() &&
+           self.local_endpoint.addr != *dst_addr { return Err(Error::Rejected) }
 
-        if self.remote_end.port != 0 &&
-           self.remote_end.port != repr.src_port { return Err(Error::Rejected) }
-        if !self.remote_end.addr.is_unspecified() &&
-           self.remote_end.addr != *src_addr { return Err(Error::Rejected) }
+        // Reject packets from a source to which we aren't connected.
+        if self.remote_endpoint.port != 0 &&
+           self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) }
+        if !self.remote_endpoint.addr.is_unspecified() &&
+           self.remote_endpoint.addr != *src_addr { return Err(Error::Rejected) }
 
         match (self.state, repr) {
-            (State::Closed, _) => Err(Error::Rejected),
+            // Reject packets addressed to a closed socket.
+            (State::Closed, TcpRepr { src_port, .. }) => {
+                net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
+                           self.local_endpoint, src_addr, src_port);
+                return Err(Error::Malformed)
+            }
+            // Don't care about ACKs when performing the handshake.
+            (State::Listen, _) => (),
+            (State::SynSent, _) => (),
+            // Every packet after the initial SYN must be an acknowledgement.
+            (_, TcpRepr { ack_number: None, .. }) => {
+                net_trace!("tcp:{}:{}: expecting an ACK packet",
+                           self.local_endpoint, self.remote_endpoint);
+                return Err(Error::Malformed)
+            }
+            // Reject unacceptable acknowledgements.
+            (state, TcpRepr { ack_number: Some(ack_number), .. }) => {
+                let unacknowledged =
+                    if state != State::SynReceived { self.rx_buffer.len() as i32 } else { 1 };
+                if !(ack_number - self.local_seq_no > 0 &&
+                     ack_number - (self.local_seq_no + unacknowledged) <= 0) {
+                    net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}..{})",
+                               self.local_endpoint, self.remote_endpoint,
+                               ack_number, self.local_seq_no, self.local_seq_no + unacknowledged);
+                    return Err(Error::Malformed)
+                }
+            }
+        }
 
+        // Handle the incoming packet.
+        match (self.state, repr) {
             (State::Listen, TcpRepr {
-                src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
+                src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None,
+                payload, ..
             }) => {
-                self.local_end     = IpEndpoint::new(*dst_addr, dst_port);
-                self.remote_end    = IpEndpoint::new(*src_addr, src_port);
-                self.remote_seq_no = seq_number;
-                // FIXME: use something more secure
-                self.local_seq_no  = !seq_number;
+                // FIXME: don't do this, just enqueue the payload
+                if payload.len() > 0 {
+                    net_trace!("tcp:{}:{}: SYN with payload rejected",
+                               IpEndpoint::new(*dst_addr, dst_port),
+                               IpEndpoint::new(*src_addr, src_port));
+                    return Err(Error::Malformed)
+                }
+
+                self.local_endpoint  = IpEndpoint::new(*dst_addr, dst_port);
+                self.remote_endpoint = IpEndpoint::new(*src_addr, src_port);
+                self.remote_seq_no   = seq_number + 1;
+                self.local_seq_no    = -seq_number; // FIXME: use something more secure
                 self.set_state(State::SynReceived);
 
-                // FIXME: queue data from SYN
                 self.retransmit.reset();
                 Ok(())
             }
@@ -249,19 +292,15 @@ impl<'a> TcpSocket<'a> {
             (State::SynReceived, TcpRepr {
                 control: TcpControl::None, ack_number: Some(ack_number), ..
             }) => {
-                if ack_number != self.local_seq_no + 1 { return Err(Error::Rejected) }
+                self.local_seq_no    = ack_number;
                 self.set_state(State::Established);
 
                 // FIXME: queue data from ACK
-                // FIXME: update sequence numbers
                 self.retransmit.reset();
                 Ok(())
             }
 
-            _ => {
-                // This will cause the interface to reply with an RST.
-                Err(Error::Rejected)
-            }
+            _ => Err(Error::Malformed)
         }
     }
 
@@ -270,12 +309,12 @@ impl<'a> TcpSocket<'a> {
                                              IpProtocol, &PacketRepr) -> Result<(), Error>)
             -> Result<(), Error> {
         let mut repr = TcpRepr {
-            src_port:   self.local_end.port,
-            dst_port:   self.remote_end.port,
+            src_port:   self.local_endpoint.port,
+            dst_port:   self.remote_endpoint.port,
             control:    TcpControl::None,
             seq_number: 0,
             ack_number: None,
-            window_len: (self.rx_buffer.capacity() - self.rx_buffer.len()) as u16,
+            window_len: self.rx_buffer.window() as u16,
             payload:    &[]
         };
 
@@ -291,9 +330,9 @@ impl<'a> TcpSocket<'a> {
                 if !self.retransmit.check() { return Err(Error::Exhausted) }
                 repr.control    = TcpControl::Syn;
                 repr.seq_number = self.local_seq_no;
-                repr.ack_number = Some(self.remote_seq_no + 1);
+                repr.ack_number = Some(self.remote_seq_no);
                 net_trace!("tcp:{}:{}: SYN sent",
-                           self.local_end, self.remote_end);
+                           self.local_endpoint, self.remote_endpoint);
             }
 
             State::Established => {
@@ -304,7 +343,7 @@ impl<'a> TcpSocket<'a> {
             _ => unreachable!()
         }
 
-        f(&self.local_end.addr, &self.remote_end.addr, IpProtocol::Tcp, &repr)
+        f(&self.local_endpoint.addr, &self.remote_endpoint.addr, IpProtocol::Tcp, &repr)
     }
 }
 
@@ -342,7 +381,7 @@ mod test {
     const LOCAL_END:    IpEndpoint = IpEndpoint::new(LOCAL_IP, LOCAL_PORT);
     const REMOTE_END:   IpEndpoint = IpEndpoint::new(REMOTE_IP, REMOTE_PORT);
     const LOCAL_SEQ:    i32        = 100;
-    const REMOTE_SEQ:   i32        = !100;
+    const REMOTE_SEQ:   i32        = -100;
 
     const SEND_TEMPL: TcpRepr<'static> = TcpRepr {
         src_port: REMOTE_PORT, dst_port: LOCAL_PORT,
@@ -434,7 +473,7 @@ mod test {
 
         send!(s, TcpRepr {
             control: TcpControl::Syn,
-            seq_number: LOCAL_SEQ, ack_number: None,
+            seq_number: REMOTE_SEQ, ack_number: None,
             ..SEND_TEMPL
         });
         assert_eq!(s.state(), State::SynReceived);
@@ -442,14 +481,16 @@ mod test {
         assert_eq!(s.remote_endpoint(), REMOTE_END);
         recv!(s, TcpRepr {
             control: TcpControl::Syn,
-            seq_number: REMOTE_SEQ, ack_number: Some(LOCAL_SEQ + 1),
+            seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
         });
         send!(s, TcpRepr {
             control: TcpControl::None,
-            seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1),
+            seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1),
             ..SEND_TEMPL
         });
         assert_eq!(s.state(), State::Established);
+        assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
+        assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
     }
 }