Explorar o código

Factor out TcpSocket::accepts.

Egor Karavaev %!s(int64=7) %!d(string=hai) anos
pai
achega
a2c66fdd88
Modificáronse 2 ficheiros con 86 adicións e 12 borrados
  1. 4 4
      src/iface/ethernet.rs
  2. 82 8
      src/socket/tcp.rs

+ 4 - 4
src/iface/ethernet.rs

@@ -412,13 +412,13 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
         for tcp_socket in sockets.iter_mut().filter_map(
                 <Socket as AsSocket<TcpSocket>>::try_as_socket) {
+            if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue }
+
             match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) {
                 // The packet is valid and handled by socket.
                 Ok(reply) => return Ok(reply.map_or(Packet::None, Packet::Tcp)),
-                // The packet isn't addressed to the socket.
-                // Send RST only if no other socket accepts the packet.
-                Err(Error::Rejected) => continue,
-                // The packet is malformed, or addressed to the socket but cannot be accepted.
+                // The packet is malformed, or doesn't match the socket state,
+                // or the socket buffer is full.
                 Err(e) => return Err(e)
             }
         }

+ 82 - 8
src/socket/tcp.rs

@@ -666,25 +666,31 @@ impl<'a> TcpSocket<'a> {
         (ip_reply_repr, reply_repr)
     }
 
-    pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) ->
-                         Result<Option<(IpRepr, TcpRepr<'static>)>> {
-        if self.state == State::Closed { return Err(Error::Rejected) }
+    pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &TcpRepr) -> bool {
+        if self.state == State::Closed { return false }
 
         // If we're still listening for SYNs and the packet has an ACK, it cannot
         // be destined to this socket, but another one may well listen on the same
         // local endpoint.
-        if self.state == State::Listen && repr.ack_number.is_some() { return Err(Error::Rejected) }
+        if self.state == State::Listen && repr.ack_number.is_some() { return false }
 
         // Reject packets with a wrong destination.
-        if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) }
+        if self.local_endpoint.port != repr.dst_port { return false }
         if !self.local_endpoint.addr.is_unspecified() &&
-           self.local_endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
+            self.local_endpoint.addr != ip_repr.dst_addr() { return false }
 
         // Reject packets from a source to which we aren't connected.
         if self.remote_endpoint.port != 0 &&
-           self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) }
+            self.remote_endpoint.port != repr.src_port { return false }
         if !self.remote_endpoint.addr.is_unspecified() &&
-           self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) }
+            self.remote_endpoint.addr != ip_repr.src_addr() { return false }
+
+        true
+    }
+
+    pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) ->
+                         Result<Option<(IpRepr, TcpRepr<'static>)>> {
+        debug_assert!(self.accepts(ip_repr, repr));
 
         // Consider how much the sequence number space differs from the transmit buffer space.
         let (sent_syn, sent_fin) = match self.state {
@@ -1241,6 +1247,7 @@ mod test {
 
     const LOCAL_IP:     IpAddress    = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
     const REMOTE_IP:    IpAddress    = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
+    const THIRD_PARTY_IP: IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 3]));
     const LOCAL_PORT:   u16          = 80;
     const REMOTE_PORT:  u16          = 49500;
     const LOCAL_END:    IpEndpoint   = IpEndpoint { addr: LOCAL_IP,  port: LOCAL_PORT  };
@@ -1272,6 +1279,11 @@ mod test {
             protocol:    IpProtocol::Tcp,
             payload_len: repr.buffer_len()
         };
+
+        if !socket.accepts(&ip_repr, repr) {
+            return Err(Error::Rejected);
+        }
+
         match socket.process(timestamp, &ip_repr, repr) {
             Ok(Some((_ip_repr, repr))) => {
                 trace!("recv: {}", repr);
@@ -2930,4 +2942,66 @@ mod test {
             ..RECV_TEMPL
         }]);
     }
+
+    // =========================================================================================//
+    // Tests for packet filtering
+    // =========================================================================================//
+
+    #[test]
+    fn test_doesnt_accept_wrong_port() {
+        let mut s = socket_established();
+        s.rx_buffer = SocketBuffer::new(vec![0; 6]);
+
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            dst_port:   LOCAL_PORT + 1,
+            ..SEND_TEMPL
+        }, Err(Error::Rejected));
+
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            src_port:   REMOTE_PORT + 1,
+            ..SEND_TEMPL
+        }, Err(Error::Rejected));
+    }
+
+    #[test]
+    fn test_doesnt_accept_wrong_ip() {
+        let s = socket_established();
+
+        let tcp_repr = TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1),
+            payload:    &b"abcdef"[..],
+            ..SEND_TEMPL
+        };
+
+        let ip_repr = IpRepr::Unspecified {
+            src_addr:    REMOTE_IP,
+            dst_addr:    LOCAL_IP,
+            protocol:    IpProtocol::Tcp,
+            payload_len: tcp_repr.buffer_len()
+        };
+        assert!(s.accepts(&ip_repr, &tcp_repr));
+
+        let ip_repr_wrong_src = IpRepr::Unspecified {
+            src_addr:    THIRD_PARTY_IP,
+            dst_addr:    LOCAL_IP,
+            protocol:    IpProtocol::Tcp,
+            payload_len: tcp_repr.buffer_len()
+        };
+        assert!(!s.accepts(&ip_repr_wrong_src, &tcp_repr));
+
+        let ip_repr_wrong_dst = IpRepr::Unspecified {
+            src_addr:    REMOTE_IP,
+            dst_addr:    THIRD_PARTY_IP,
+            protocol:    IpProtocol::Tcp,
+            payload_len: tcp_repr.buffer_len()
+        };
+        assert!(!s.accepts(&ip_repr_wrong_dst, &tcp_repr));
+    }
 }