Browse Source

Merge #609

609: Finish ICMPv4 TimeExceeded support r=Dirbaio a=ngc0202

Will now Display and PrettyPrint properly. Its structure is exactly the same as that of the Destination Unreachable message's, so this mostly mirrors that implementation.

Co-authored-by: ngc0202 <ngc0202@gmail.com>
bors[bot] 2 years ago
parent
commit
7f0eb580af
2 changed files with 77 additions and 10 deletions
  1. 15 8
      src/socket/icmp.rs
  2. 62 2
      src/wire/icmpv4.rs

+ 15 - 8
src/socket/icmp.rs

@@ -350,18 +350,22 @@ impl<'a> Socket<'a> {
     pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) -> bool {
         match (&self.endpoint, icmp_repr) {
             // If we are bound to ICMP errors associated to a UDP port, only
-            // accept Destination Unreachable messages with the data containing
-            // a UDP packet send from the local port we are bound to.
+            // accept Destination Unreachable or Time Exceeded messages with
+            // the data containing a UDP packet send from the local port we
+            // are bound to.
             #[cfg(feature = "proto-ipv4")]
             (
                 &Endpoint::Udp(endpoint),
-                &IcmpRepr::Ipv4(Icmpv4Repr::DstUnreachable { data, .. }),
+                &IcmpRepr::Ipv4(
+                    Icmpv4Repr::DstUnreachable { data, header, .. }
+                    | Icmpv4Repr::TimeExceeded { data, header, .. },
+                ),
             ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => {
                 let packet = UdpPacket::new_unchecked(data);
                 match UdpRepr::parse(
                     &packet,
-                    &ip_repr.src_addr(),
-                    &ip_repr.dst_addr(),
+                    &header.src_addr.into(),
+                    &header.dst_addr.into(),
                     &cx.checksum_caps(),
                 ) {
                     Ok(repr) => endpoint.port == repr.src_port,
@@ -371,13 +375,16 @@ impl<'a> Socket<'a> {
             #[cfg(feature = "proto-ipv6")]
             (
                 &Endpoint::Udp(endpoint),
-                &IcmpRepr::Ipv6(Icmpv6Repr::DstUnreachable { data, .. }),
+                &IcmpRepr::Ipv6(
+                    Icmpv6Repr::DstUnreachable { data, header, .. }
+                    | Icmpv6Repr::TimeExceeded { data, header, .. },
+                ),
             ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => {
                 let packet = UdpPacket::new_unchecked(data);
                 match UdpRepr::parse(
                     &packet,
-                    &ip_repr.src_addr(),
-                    &ip_repr.dst_addr(),
+                    &header.src_addr.into(),
+                    &header.dst_addr.into(),
                     &cx.checksum_caps(),
                 ) {
                     Ok(repr) => endpoint.port == repr.src_port,

+ 62 - 2
src/wire/icmpv4.rs

@@ -138,6 +138,16 @@ enum_with_unknown! {
     }
 }
 
+impl fmt::Display for TimeExceeded {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match *self {
+            TimeExceeded::TtlExpired => write!(f, "time-to-live exceeded in transit"),
+            TimeExceeded::FragExpired => write!(f, "fragment reassembly time exceeded"),
+            TimeExceeded::Unknown(id) => write!(f, "{}", id),
+        }
+    }
+}
+
 enum_with_unknown! {
     /// Internet protocol control message subtype for type "Parameter Problem".
     pub enum ParamProblem(u8) {
@@ -372,6 +382,11 @@ pub enum Repr<'a> {
         header: Ipv4Repr,
         data: &'a [u8],
     },
+    TimeExceeded {
+        reason: TimeExceeded,
+        header: Ipv4Repr,
+        data: &'a [u8],
+    },
 }
 
 impl<'a> Repr<'a> {
@@ -424,6 +439,30 @@ impl<'a> Repr<'a> {
                     data: payload,
                 })
             }
+
+            (Message::TimeExceeded, code) => {
+                let ip_packet = Ipv4Packet::new_checked(packet.data())?;
+
+                let payload = &packet.data()[ip_packet.header_len() as usize..];
+                // RFC 792 requires exactly eight bytes to be returned.
+                // We allow more, since there isn't a reason not to, but require at least eight.
+                if payload.len() < 8 {
+                    return Err(Error);
+                }
+
+                Ok(Repr::TimeExceeded {
+                    reason: TimeExceeded::from(code),
+                    header: Ipv4Repr {
+                        src_addr: ip_packet.src_addr(),
+                        dst_addr: ip_packet.dst_addr(),
+                        next_header: ip_packet.next_header(),
+                        payload_len: payload.len(),
+                        hop_limit: ip_packet.hop_limit(),
+                    },
+                    data: payload,
+                })
+            }
+
             _ => Err(Error),
         }
     }
@@ -434,7 +473,8 @@ impl<'a> Repr<'a> {
             &Repr::EchoRequest { data, .. } | &Repr::EchoReply { data, .. } => {
                 field::ECHO_SEQNO.end + data.len()
             }
-            &Repr::DstUnreachable { header, data, .. } => {
+            &Repr::DstUnreachable { header, data, .. }
+            | &Repr::TimeExceeded { header, data, .. } => {
                 field::UNUSED.end + header.buffer_len() + data.len()
             }
         }
@@ -487,6 +527,20 @@ impl<'a> Repr<'a> {
                 let payload = &mut ip_packet.into_inner()[header.buffer_len()..];
                 payload.copy_from_slice(data)
             }
+
+            Repr::TimeExceeded {
+                reason,
+                header,
+                data,
+            } => {
+                packet.set_msg_type(Message::TimeExceeded);
+                packet.set_msg_code(reason.into());
+
+                let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut());
+                header.emit(&mut ip_packet, checksum_caps);
+                let payload = &mut ip_packet.into_inner()[header.buffer_len()..];
+                payload.copy_from_slice(data)
+            }
         }
 
         if checksum_caps.icmpv4.tx() {
@@ -510,6 +564,9 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
                     Message::DstUnreachable => {
                         write!(f, " code={:?}", DstUnreachable::from(self.msg_code()))
                     }
+                    Message::TimeExceeded => {
+                        write!(f, " code={:?}", TimeExceeded::from(self.msg_code()))
+                    }
                     _ => write!(f, " code={}", self.msg_code()),
                 }
             }
@@ -545,6 +602,9 @@ impl<'a> fmt::Display for Repr<'a> {
             Repr::DstUnreachable { reason, .. } => {
                 write!(f, "ICMPv4 destination unreachable ({})", reason)
             }
+            Repr::TimeExceeded { reason, .. } => {
+                write!(f, "ICMPv4 time exceeded ({})", reason)
+            }
         }
     }
 }
@@ -564,7 +624,7 @@ impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
         write!(f, "{}{}", indent, packet)?;
 
         match packet.msg_type() {
-            Message::DstUnreachable => {
+            Message::DstUnreachable | Message::TimeExceeded => {
                 indent.increase(f)?;
                 super::Ipv4Packet::<&[u8]>::pretty_print(&packet.data(), f, indent)
             }