Ver código fonte

Merge #734

734: Add peek and peek_slice functions to RawSocket r=Dirbaio a=ngc0202

This adds the `peek` and `peek_slice` functions to raw sockets, same as was added to UDP sockets in #278 

Co-authored-by: Nicholas Cyprus <nicholas.cyprus@caci.com>
bors[bot] 2 anos atrás
pai
commit
fe7713871e
1 arquivos alterados com 99 adições e 0 exclusões
  1. 99 0
      src/socket/raw.rs

+ 99 - 0
src/socket/raw.rs

@@ -247,6 +247,36 @@ impl<'a> Socket<'a> {
         Ok(length)
     }
 
+    /// Peek at a packet in the receive buffer and return a pointer to the
+    /// payload without removing the packet from the receive buffer.
+    /// This function otherwise behaves identically to [recv](#method.recv).
+    ///
+    /// It returns `Err(Error::Exhausted)` if the receive buffer is empty.
+    pub fn peek(&mut self) -> Result<&[u8], RecvError> {
+        let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?;
+
+        net_trace!(
+            "raw:{}:{}: receive {} buffered octets",
+            self.ip_version,
+            self.ip_protocol,
+            packet_buf.len()
+        );
+
+        Ok(packet_buf)
+    }
+
+    /// Peek at a packet in the receive buffer, copy the payload into the given slice,
+    /// and return the amount of octets copied without removing the packet from the receive buffer.
+    /// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
+    ///
+    /// See also [peek](#method.peek).
+    pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
+        let buffer = self.peek()?;
+        let length = min(data.len(), buffer.len());
+        data[..length].copy_from_slice(&buffer[..length]);
+        Ok(length)
+    }
+
     pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
         if ip_repr.version() != self.ip_version {
             return false;
@@ -536,6 +566,22 @@ mod test {
                     assert!(socket.accepts(&$hdr));
                     socket.process(&mut cx, &$hdr, &buffer);
                 }
+
+                #[test]
+                fn test_peek_truncated_slice() {
+                    let mut socket = $socket(buffer(1), buffer(0));
+                    let mut cx = Context::mock();
+
+                    assert!(socket.accepts(&$hdr));
+                    socket.process(&mut cx, &$hdr, &$payload);
+
+                    let mut slice = [0; 4];
+                    assert_eq!(socket.peek_slice(&mut slice[..]), Ok(4));
+                    assert_eq!(&slice, &$packet[..slice.len()]);
+                    assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4));
+                    assert_eq!(&slice, &$packet[..slice.len()]);
+                    assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));
+                }
             }
         };
     }
@@ -664,6 +710,59 @@ mod test {
         }
     }
 
+    #[test]
+    fn test_peek_process() {
+        #[cfg(feature = "proto-ipv4")]
+        {
+            let mut socket = ipv4_locals::socket(buffer(1), buffer(0));
+            let mut cx = Context::mock();
+
+            let mut cksumd_packet = ipv4_locals::PACKET_BYTES;
+            Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum();
+
+            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
+            assert!(socket.accepts(&ipv4_locals::HEADER_REPR));
+            socket.process(
+                &mut cx,
+                &ipv4_locals::HEADER_REPR,
+                &ipv4_locals::PACKET_PAYLOAD,
+            );
+
+            assert!(socket.accepts(&ipv4_locals::HEADER_REPR));
+            socket.process(
+                &mut cx,
+                &ipv4_locals::HEADER_REPR,
+                &ipv4_locals::PACKET_PAYLOAD,
+            );
+            assert_eq!(socket.peek(), Ok(&cksumd_packet[..]));
+            assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
+            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
+        }
+        #[cfg(feature = "proto-ipv6")]
+        {
+            let mut socket = ipv6_locals::socket(buffer(1), buffer(0));
+            let mut cx = Context::mock();
+
+            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
+            assert!(socket.accepts(&ipv6_locals::HEADER_REPR));
+            socket.process(
+                &mut cx,
+                &ipv6_locals::HEADER_REPR,
+                &ipv6_locals::PACKET_PAYLOAD,
+            );
+
+            assert!(socket.accepts(&ipv6_locals::HEADER_REPR));
+            socket.process(
+                &mut cx,
+                &ipv6_locals::HEADER_REPR,
+                &ipv6_locals::PACKET_PAYLOAD,
+            );
+            assert_eq!(socket.peek(), Ok(&ipv6_locals::PACKET_BYTES[..]));
+            assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..]));
+            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
+        }
+    }
+
     #[test]
     fn test_doesnt_accept_wrong_proto() {
         #[cfg(feature = "proto-ipv4")]