Selaa lähdekoodia

Factor out RawSocket::accepts.

Egor Karavaev 7 vuotta sitten
vanhempi
commit
5303501173
2 muutettua tiedostoa jossa 32 lisäystä ja 8 poistoa
  1. 6 4
      src/iface/ethernet.rs
  2. 26 4
      src/socket/raw.rs

+ 6 - 4
src/iface/ethernet.rs

@@ -289,13 +289,15 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         let mut handled_by_raw_socket = false;
         for raw_socket in sockets.iter_mut().filter_map(
                 <Socket as AsSocket<RawSocket>>::try_as_socket) {
+            if !raw_socket.accepts(&ip_repr) { continue }
+
             match raw_socket.process(&ip_repr, ip_payload) {
                 // The packet is valid and handled by socket.
                 Ok(()) => handled_by_raw_socket = true,
-                // The packet isn't addressed to the socket, or cannot be accepted by it.
-                Err(Error::Rejected) | Err(Error::Exhausted) => (),
-                // Raw sockets either accept or reject packets, not parse them.
-                _ => unreachable!(),
+                // The socket buffer is full.
+                Err(Error::Exhausted) => (),
+                // Raw sockets don't validate the packets in any way.
+                Err(_) => unreachable!(),
             }
         }
 

+ 26 - 4
src/socket/raw.rs

@@ -166,9 +166,15 @@ impl<'a, 'b> RawSocket<'a, 'b> {
         Ok(length)
     }
 
+    pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
+        if ip_repr.version() != self.ip_version { return false }
+        if ip_repr.protocol() != self.ip_protocol { return false }
+
+        true
+    }
+
     pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<()> {
-        if ip_repr.version() != self.ip_version { return Err(Error::Rejected) }
-        if ip_repr.protocol() != self.ip_protocol { return Err(Error::Rejected) }
+        debug_assert!(self.accepts(ip_repr));
 
         let header_len = ip_repr.buffer_len();
         let total_len = header_len + payload.len();
@@ -246,17 +252,18 @@ mod test {
     fn socket(rx_buffer: SocketBuffer<'static, 'static>,
               tx_buffer: SocketBuffer<'static, 'static>)
             -> RawSocket<'static, 'static> {
-        match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(63),
+        match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(IP_PROTO),
                              rx_buffer, tx_buffer) {
             Socket::Raw(socket) => socket,
             _ => unreachable!()
         }
     }
 
+    const IP_PROTO: u8 = 63;
     const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
         src_addr: Ipv4Address([10, 0, 0, 1]),
         dst_addr: Ipv4Address([10, 0, 0, 2]),
-        protocol: IpProtocol::Unknown(63),
+        protocol: IpProtocol::Unknown(IP_PROTO),
         payload_len: 4
     });
     const PACKET_BYTES: [u8; 24] = [
@@ -332,10 +339,12 @@ mod test {
         Ipv4Packet::new(&mut cksumd_packet).fill_checksum();
 
         assert_eq!(socket.recv(), Err(Error::Exhausted));
+        assert!(socket.accepts(&HEADER_REPR));
         assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
                    Ok(()));
         assert!(socket.can_recv());
 
+        assert!(socket.accepts(&HEADER_REPR));
         assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
                    Err(Error::Exhausted));
         assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
@@ -346,6 +355,7 @@ mod test {
     fn test_recv_truncated_slice() {
         let mut socket = socket(buffer(1), buffer(0));
 
+        assert!(socket.accepts(&HEADER_REPR));
         assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
                    Ok(()));
 
@@ -361,7 +371,19 @@ mod test {
         let mut buffer = vec![0; 128];
         buffer[..PACKET_BYTES.len()].copy_from_slice(&PACKET_BYTES[..]);
 
+        assert!(socket.accepts(&HEADER_REPR));
         assert_eq!(socket.process(&HEADER_REPR, &buffer),
                    Err(Error::Truncated));
     }
+
+    #[test]
+    fn test_doesnt_accept_wrong_proto() {
+        let socket = match RawSocket::new(IpVersion::Ipv4,
+                                          IpProtocol::Unknown(IP_PROTO+1),
+                                          buffer(1), buffer(1)) {
+            Socket::Raw(socket) => socket,
+            _ => unreachable!()
+        };
+        assert!(!socket.accepts(&HEADER_REPR));
+    }
 }