Bläddra i källkod

Merge pull request #859 from thvdveld/fix-recv_slice-truncation

return RecvError::Truncated when calling recv_slice with a small buffer
Thibaut Vandervelden 1 år sedan
förälder
incheckning
9cb718e41d
4 ändrade filer med 91 tillägg och 26 borttagningar
  1. 50 4
      src/socket/icmp.rs
  2. 19 6
      src/socket/raw.rs
  3. 1 1
      src/socket/tcp.rs
  4. 21 15
      src/socket/udp.rs

+ 50 - 4
src/socket/icmp.rs

@@ -61,12 +61,14 @@ impl std::error::Error for SendError {}
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub enum RecvError {
     Exhausted,
+    Truncated,
 }
 
 impl core::fmt::Display for RecvError {
     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
         match self {
             RecvError::Exhausted => write!(f, "exhausted"),
+            RecvError::Truncated => write!(f, "truncated"),
         }
     }
 }
@@ -130,8 +132,8 @@ impl<'a> Socket<'a> {
     /// Create an ICMP socket with the given buffers.
     pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> {
         Socket {
-            rx_buffer: rx_buffer,
-            tx_buffer: tx_buffer,
+            rx_buffer,
+            tx_buffer,
             endpoint: Default::default(),
             hop_limit: None,
             #[cfg(feature = "async")]
@@ -394,9 +396,17 @@ impl<'a> Socket<'a> {
     /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
     /// and return the amount of octets copied as well as the `IpAddress`
     ///
+    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
+    /// the packet is dropped and a `RecvError::Truncated` error is returned.
+    ///
     /// See also [recv](#method.recv).
     pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> {
         let (buffer, endpoint) = self.recv()?;
+
+        if data.len() < buffer.len() {
+            return Err(RecvError::Truncated);
+        }
+
         let length = cmp::min(data.len(), buffer.len());
         data[..length].copy_from_slice(&buffer[..length]);
         Ok((length, endpoint))
@@ -555,7 +565,7 @@ impl<'a> Socket<'a> {
                         dst_addr,
                         next_header: IpProtocol::Icmp,
                         payload_len: repr.buffer_len(),
-                        hop_limit: hop_limit,
+                        hop_limit,
                     });
                     emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
                 }
@@ -592,7 +602,7 @@ impl<'a> Socket<'a> {
                         dst_addr,
                         next_header: IpProtocol::Icmpv6,
                         payload_len: repr.buffer_len(),
-                        hop_limit: hop_limit,
+                        hop_limit,
                     });
                     emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
                 }
@@ -1096,6 +1106,42 @@ mod test_ipv6 {
         assert!(!socket.can_recv());
     }
 
+    #[rstest]
+    #[case::ethernet(Medium::Ethernet)]
+    #[cfg(feature = "medium-ethernet")]
+    fn test_truncated_recv_slice(#[case] medium: Medium) {
+        let (mut iface, _, _) = setup(medium);
+        let cx = iface.context();
+
+        let mut socket = socket(buffer(1), buffer(1));
+        assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));
+
+        let checksum = ChecksumCapabilities::default();
+
+        let mut bytes = [0xff; 24];
+        let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]);
+        ECHOV6_REPR.emit(
+            &LOCAL_IPV6.into(),
+            &REMOTE_IPV6.into(),
+            &mut packet,
+            &checksum,
+        );
+
+        assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()));
+        socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into());
+        assert!(socket.can_recv());
+
+        assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()));
+        socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into());
+
+        let mut buffer = [0u8; 1];
+        assert_eq!(
+            socket.recv_slice(&mut buffer[..]),
+            Err(RecvError::Truncated)
+        );
+        assert!(!socket.can_recv());
+    }
+
     #[rstest]
     #[case::ethernet(Medium::Ethernet)]
     #[cfg(feature = "medium-ethernet")]

+ 19 - 6
src/socket/raw.rs

@@ -57,12 +57,14 @@ impl std::error::Error for SendError {}
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub enum RecvError {
     Exhausted,
+    Truncated,
 }
 
 impl core::fmt::Display for RecvError {
     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
         match self {
             RecvError::Exhausted => write!(f, "exhausted"),
+            RecvError::Truncated => write!(f, "truncated"),
         }
     }
 }
@@ -273,9 +275,16 @@ impl<'a> Socket<'a> {
 
     /// Dequeue a packet, and copy the payload into the given slice.
     ///
+    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
+    /// the packet is dropped and a `RecvError::Truncated` error is returned.
+    ///
     /// See also [recv](#method.recv).
     pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
         let buffer = self.recv()?;
+        if data.len() < buffer.len() {
+            return Err(RecvError::Truncated);
+        }
+
         let length = min(data.len(), buffer.len());
         data[..length].copy_from_slice(&buffer[..length]);
         Ok(length)
@@ -303,9 +312,16 @@ impl<'a> Socket<'a> {
     /// and return the amount of octets copied without removing the packet from the receive buffer.
     /// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
     ///
+    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
+    /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned.
+    ///
     /// See also [peek](#method.peek).
     pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
         let buffer = self.peek()?;
+        if data.len() < buffer.len() {
+            return Err(RecvError::Truncated);
+        }
+
         let length = min(data.len(), buffer.len());
         data[..length].copy_from_slice(&buffer[..length]);
         Ok(length)
@@ -602,8 +618,7 @@ mod test {
                     socket.process(&mut cx, &$hdr, &$payload);
 
                     let mut slice = [0; 4];
-                    assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4));
-                    assert_eq!(&slice, &$packet[..slice.len()]);
+                    assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
                 }
 
                 #[rstest]
@@ -641,10 +656,8 @@ mod test {
                     socket.process(&mut cx, &$hdr, &$payload);
 
                     let mut slice = [0; 4];
-                    assert_eq!(socket.peek_slice(&mut slice[..]), Ok(4));
-                    assert_eq!(&slice, &$packet[..slice.len()]);
-                    assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4));
-                    assert_eq!(&slice, &$packet[..slice.len()]);
+                    assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated));
+                    assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
                     assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));
                 }
             }

+ 1 - 1
src/socket/tcp.rs

@@ -1331,7 +1331,7 @@ impl<'a> Socket<'a> {
         // Rate-limit to 1 per second max.
         self.challenge_ack_timer = cx.now() + Duration::from_secs(1);
 
-        return Some(self.ack_reply(ip_repr, repr));
+        Some(self.ack_reply(ip_repr, repr))
     }
 
     pub(crate) fn accepts(&self, _cx: &mut Context, ip_repr: &IpRepr, repr: &TcpRepr) -> bool {

+ 21 - 15
src/socket/udp.rs

@@ -88,12 +88,14 @@ impl std::error::Error for SendError {}
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub enum RecvError {
     Exhausted,
+    Truncated,
 }
 
 impl core::fmt::Display for RecvError {
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         match self {
             RecvError::Exhausted => write!(f, "exhausted"),
+            RecvError::Truncated => write!(f, "truncated"),
         }
     }
 }
@@ -393,9 +395,17 @@ impl<'a> Socket<'a> {
     /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
     /// and return the amount of octets copied as well as the endpoint.
     ///
+    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
+    /// the packet is dropped and a `RecvError::Truncated` error is returned.
+    ///
     /// See also [recv](#method.recv).
     pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> {
         let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?;
+
+        if data.len() < buffer.len() {
+            return Err(RecvError::Truncated);
+        }
+
         let length = min(data.len(), buffer.len());
         data[..length].copy_from_slice(&buffer[..length]);
         Ok((length, endpoint))
@@ -426,9 +436,17 @@ impl<'a> Socket<'a> {
     /// packet from the receive buffer.
     /// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
     ///
+    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
+    /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned.
+    ///
     /// See also [peek](#method.peek).
     pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> {
         let (buffer, endpoint) = self.peek()?;
+
+        if data.len() < buffer.len() {
+            return Err(RecvError::Truncated);
+        }
+
         let length = min(data.len(), buffer.len());
         data[..length].copy_from_slice(&buffer[..length]);
         Ok((length, endpoint))
@@ -851,11 +869,7 @@ mod test {
         );
 
         let mut slice = [0; 4];
-        assert_eq!(
-            socket.recv_slice(&mut slice[..]),
-            Ok((4, REMOTE_END.into()))
-        );
-        assert_eq!(&slice, b"abcd");
+        assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
     }
 
     #[rstest]
@@ -882,16 +896,8 @@ mod test {
         );
 
         let mut slice = [0; 4];
-        assert_eq!(
-            socket.peek_slice(&mut slice[..]),
-            Ok((4, &REMOTE_END.into()))
-        );
-        assert_eq!(&slice, b"abcd");
-        assert_eq!(
-            socket.recv_slice(&mut slice[..]),
-            Ok((4, REMOTE_END.into()))
-        );
-        assert_eq!(&slice, b"abcd");
+        assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated));
+        assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
         assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));
     }