Răsfoiți Sursa

Factor out UdpSocket::accepts.

Egor Karavaev 7 ani în urmă
părinte
comite
44cf21b91b
2 a modificat fișierele cu 45 adăugiri și 7 ștergeri
  1. 3 3
      src/iface/ethernet.rs
  2. 42 4
      src/socket/udp.rs

+ 3 - 3
src/iface/ethernet.rs

@@ -371,12 +371,12 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
         for udp_socket in sockets.iter_mut().filter_map(
                 <Socket as AsSocket<UdpSocket>>::try_as_socket) {
+            if !udp_socket.accepts(&ip_repr, &udp_repr) { continue }
+
             match udp_socket.process(&ip_repr, &udp_repr) {
                 // The packet is valid and handled by socket.
                 Ok(()) => return Ok(Packet::None),
-                // The packet isn't addressed to the socket.
-                Err(Error::Rejected) => continue,
-                // The packet is malformed, or addressed to the socket but cannot be accepted.
+                // The packet is malformed, or the socket buffer is full.
                 Err(e) => return Err(e)
             }
         }

+ 42 - 4
src/socket/udp.rs

@@ -179,11 +179,16 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         Ok((length, endpoint))
     }
 
-    pub(crate) fn process(&mut self, ip_repr: &IpRepr, repr: &UdpRepr) -> Result<()> {
-        // Reject packets with a wrong destination.
-        if self.endpoint.port != repr.dst_port { return Err(Error::Rejected) }
+    pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &UdpRepr) -> bool {
+        if self.endpoint.port != repr.dst_port { return false }
         if !self.endpoint.addr.is_unspecified() &&
-           self.endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
+           self.endpoint.addr != ip_repr.dst_addr() { return false }
+
+        true
+    }
+
+    pub(crate) fn process(&mut self, ip_repr: &IpRepr, repr: &UdpRepr) -> Result<()> {
+        debug_assert!(self.accepts(ip_repr, repr));
 
         let packet_buf = self.rx_buffer.enqueue_one_with(|buf| buf.resize(repr.payload.len()))?;
         packet_buf.as_mut().copy_from_slice(repr.payload);
@@ -351,10 +356,12 @@ mod test {
         assert!(!socket.can_recv());
         assert_eq!(socket.recv(), Err(Error::Exhausted));
 
+        assert!(socket.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));
         assert_eq!(socket.process(&REMOTE_IP_REPR, &REMOTE_UDP_REPR),
                    Ok(()));
         assert!(socket.can_recv());
 
+        assert!(socket.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));
         assert_eq!(socket.process(&REMOTE_IP_REPR, &REMOTE_UDP_REPR),
                    Err(Error::Exhausted));
         assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END)));
@@ -366,6 +373,7 @@ mod test {
         let mut socket = socket(buffer(1), buffer(0));
         assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
 
+        assert!(socket.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));
         assert_eq!(socket.process(&REMOTE_IP_REPR, &REMOTE_UDP_REPR),
                    Ok(()));
 
@@ -380,7 +388,37 @@ mod test {
         assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
 
         let udp_repr = UdpRepr { payload: &[0; 100][..], ..REMOTE_UDP_REPR };
+        assert!(socket.accepts(&REMOTE_IP_REPR, &udp_repr));
         assert_eq!(socket.process(&REMOTE_IP_REPR, &udp_repr),
                    Err(Error::Truncated));
     }
+
+    #[test]
+    fn test_doesnt_accept_wrong_port() {
+        let mut socket = socket(buffer(1), buffer(0));
+        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
+
+        let mut udp_repr = REMOTE_UDP_REPR;
+        assert!(socket.accepts(&REMOTE_IP_REPR, &udp_repr));
+        udp_repr.dst_port += 1;
+        assert!(!socket.accepts(&REMOTE_IP_REPR, &udp_repr));
+    }
+
+    #[test]
+    fn test_doesnt_accept_wrong_ip() {
+        let ip_repr = IpRepr::Ipv4(Ipv4Repr {
+            src_addr: Ipv4Address([10, 0, 0, 2]),
+            dst_addr: Ipv4Address([10, 0, 0, 10]),
+            protocol: IpProtocol::Udp,
+            payload_len: 8 + 6
+        });
+
+        let mut port_bound_socket = socket(buffer(1), buffer(0));
+        assert_eq!(port_bound_socket.bind(LOCAL_PORT), Ok(()));
+        assert!(port_bound_socket.accepts(&ip_repr, &REMOTE_UDP_REPR));
+
+        let mut ip_bound_socket = socket(buffer(1), buffer(0));
+        assert_eq!(ip_bound_socket.bind(LOCAL_END), Ok(()));
+        assert!(!ip_bound_socket.accepts(&ip_repr, &REMOTE_UDP_REPR));
+    }
 }