浏览代码

Remove IpAddress::Unspecified, assign src addr in sockets.

This continues work started in #579, with the goal of "remove unspecified variants from wire".

"unspecified" variants are bad, they've been a source of bugs in the past. The issue with them is that
depending on the context it may or may not make sense for the value to be unspecified.
It's better to not have them, and then use Option only where the value can really be unspecified.

This removes lots of `Address::Unspecified => unreachable!()` and similar match arms, which shows the
unreachable variant was actively unwanted in many places.

This fixes the "unspecified src addr" panic:

- Picking src addr is now the resposibility of the sockets, not the interface. All sockets now emit IpReprs with properly assigned src addr.
- Removed the `assert!(!ip_repr.src_addr().is_unspecified());`. This assert is WRONG even if
  now sockets pick the source address, because there ARE cases where we indeed want to send a
  packet with zero src addr, for example in DHCP.
Dario Nieuwenhuis 2 年之前
父节点
当前提交
eb41d077e0
共有 11 个文件被更改,包括 378 次插入480 次删除
  1. 4 1
      CHANGELOG.md
  2. 1 5
      examples/loopback.rs
  3. 0 2
      examples/ping.rs
  4. 39 5
      src/iface/interface.rs
  5. 0 1
      src/iface/route.rs
  6. 38 32
      src/socket/icmp.rs
  7. 0 1
      src/socket/raw.rs
  8. 163 331
      src/socket/tcp.rs
  9. 23 12
      src/socket/udp.rs
  10. 107 88
      src/wire/ip.rs
  11. 3 2
      src/wire/mod.rs

+ 4 - 1
CHANGELOG.md

@@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
-- No unreleased changes.
+- Remove IpRepr::Unspecified (#579)
+- Remove IpVersion::Unspecified
+- Remove IpAddress::Unspecified
+- When sending packets with a raw socket, the source IP address is sent unmodified (it was previously replaced with the interface's address if it was unspecified).
 
 ## [0.8.1] - 2022-05-12
 

+ 1 - 5
examples/loopback.rs

@@ -150,11 +150,7 @@ fn main() {
             if !did_connect {
                 debug!("connecting");
                 socket
-                    .connect(
-                        cx,
-                        (IpAddress::v4(127, 0, 0, 1), 1234),
-                        (IpAddress::Unspecified, 65000),
-                    )
+                    .connect(cx, (IpAddress::v4(127, 0, 0, 1), 1234), 65000)
                     .unwrap();
                 did_connect = true;
             }

+ 0 - 2
examples/ping.rs

@@ -195,7 +195,6 @@ fn main() {
                         &device_caps.checksum,
                     );
                 }
-                _ => unimplemented!(),
             }
 
             waiting_queue.insert(seq_no, timestamp);
@@ -239,7 +238,6 @@ fn main() {
                         received
                     );
                 }
-                _ => unimplemented!(),
             }
         }
 

+ 39 - 5
src/iface/interface.rs

@@ -641,6 +641,7 @@ where
                 }
             }
             // Multicast is not yet implemented for other address families
+            #[allow(unreachable_patterns)]
             _ => Err(Error::Unaddressable),
         }
     }
@@ -672,6 +673,7 @@ where
                 }
             }
             // Multicast is not yet implemented for other address families
+            #[allow(unreachable_patterns)]
             _ => Err(Error::Unaddressable),
         }
     }
@@ -693,6 +695,7 @@ where
             .iter()
             .filter_map(|cidr| match cidr.address() {
                 IpAddress::Ipv4(addr) => Some(addr),
+                #[allow(unreachable_patterns)]
                 _ => None,
             })
             .next()
@@ -1061,16 +1064,46 @@ impl<'a> InterfaceInner<'a> {
 
     #[allow(unused)] // unused depending on which sockets are enabled
     pub(crate) fn get_source_address(&mut self, dst_addr: IpAddress) -> Option<IpAddress> {
-        let v = dst_addr.version().unwrap();
+        let v = dst_addr.version();
         for cidr in self.ip_addrs.iter() {
             let addr = cidr.address();
-            if addr.version() == Some(v) {
+            if addr.version() == v {
                 return Some(addr);
             }
         }
         None
     }
 
+    #[cfg(feature = "proto-ipv4")]
+    #[allow(unused)]
+    pub(crate) fn get_source_address_ipv4(
+        &mut self,
+        _dst_addr: Ipv4Address,
+    ) -> Option<Ipv4Address> {
+        for cidr in self.ip_addrs.iter() {
+            #[allow(irrefutable_let_patterns)] // if only ipv4 is enabled
+            if let IpCidr::Ipv4(cidr) = cidr {
+                return Some(cidr.address());
+            }
+        }
+        None
+    }
+
+    #[cfg(feature = "proto-ipv6")]
+    #[allow(unused)]
+    pub(crate) fn get_source_address_ipv6(
+        &mut self,
+        _dst_addr: Ipv6Address,
+    ) -> Option<Ipv6Address> {
+        for cidr in self.ip_addrs.iter() {
+            #[allow(irrefutable_let_patterns)] // if only ipv6 is enabled
+            if let IpCidr::Ipv6(cidr) = cidr {
+                return Some(cidr.address());
+            }
+        }
+        None
+    }
+
     #[cfg(test)]
     pub(crate) fn mock() -> Self {
         Self {
@@ -1207,6 +1240,7 @@ impl<'a> InterfaceInner<'a> {
                 key == Ipv4Address::MULTICAST_ALL_SYSTEMS
                     || self.ipv4_multicast_groups.get(&key).is_some()
             }
+            #[allow(unreachable_patterns)]
             _ => false,
         }
     }
@@ -2301,7 +2335,6 @@ impl<'a> InterfaceInner<'a> {
         if dst_addr.is_multicast() {
             let b = dst_addr.as_bytes();
             let hardware_addr = match *dst_addr {
-                IpAddress::Unspecified => unreachable!(),
                 #[cfg(feature = "proto-ipv4")]
                 IpAddress::Ipv4(_addr) => {
                     HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[
@@ -2401,6 +2434,7 @@ impl<'a> InterfaceInner<'a> {
                 self.dispatch_ip(tx_token, packet)?;
             }
 
+            #[allow(unreachable_patterns)]
             _ => (),
         }
         // The request got dispatched, limit the rate on the cache.
@@ -2417,7 +2451,6 @@ impl<'a> InterfaceInner<'a> {
 
     fn dispatch_ip<Tx: TxToken>(&mut self, tx_token: Tx, packet: IpPacket) -> Result<()> {
         let ip_repr = packet.ip_repr();
-        assert!(!ip_repr.src_addr().is_unspecified());
         assert!(!ip_repr.dst_addr().is_unspecified());
 
         match self.caps.medium {
@@ -2471,7 +2504,6 @@ impl<'a> InterfaceInner<'a> {
     #[cfg(feature = "medium-ieee802154")]
     fn dispatch_ieee802154<Tx: TxToken>(&mut self, tx_token: Tx, packet: IpPacket) -> Result<()> {
         let ip_repr = packet.ip_repr();
-        assert!(!ip_repr.src_addr().is_unspecified());
         assert!(!ip_repr.dst_addr().is_unspecified());
 
         match self.caps.medium {
@@ -2518,6 +2550,7 @@ impl<'a> InterfaceInner<'a> {
 
                 let (src_addr, dst_addr) = match (ip_repr.src_addr(), ip_repr.dst_addr()) {
                     (IpAddress::Ipv6(src_addr), IpAddress::Ipv6(dst_addr)) => (src_addr, dst_addr),
+                    #[allow(unreachable_patterns)]
                     _ => return Err(Error::Unaddressable),
                 };
 
@@ -2608,6 +2641,7 @@ impl<'a> InterfaceInner<'a> {
                     Ok(())
                 })
             }
+            #[allow(unreachable_patterns)]
             _ => Err(Error::NotSupported),
         }
     }

+ 0 - 1
src/iface/route.rs

@@ -134,7 +134,6 @@ impl<'a> Routes<'a> {
             IpAddress::Ipv4(addr) => IpCidr::Ipv4(Ipv4Cidr::new(*addr, 32)),
             #[cfg(feature = "proto-ipv6")]
             IpAddress::Ipv6(addr) => IpCidr::Ipv6(Ipv6Cidr::new(*addr, 128)),
-            _ => unimplemented!(),
         };
 
         for (prefix, route) in self

+ 38 - 32
src/socket/icmp.rs

@@ -11,10 +11,10 @@ use crate::{Error, Result};
 
 use crate::wire::IcmpRepr;
 #[cfg(feature = "proto-ipv4")]
-use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Address, Ipv4Repr};
+use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr};
 #[cfg(feature = "proto-ipv6")]
-use crate::wire::{Icmpv6Packet, Icmpv6Repr, Ipv6Address, Ipv6Repr};
-use crate::wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr};
+use crate::wire::{Icmpv6Packet, Icmpv6Repr, Ipv6Repr};
+use crate::wire::{IpAddress, IpListenEndpoint, IpProtocol, IpRepr};
 use crate::wire::{UdpPacket, UdpRepr};
 
 /// Type of endpoint to bind the ICMP socket to. See [IcmpSocket::bind] for
@@ -26,7 +26,7 @@ use crate::wire::{UdpPacket, UdpRepr};
 pub enum Endpoint {
     Unspecified,
     Ident(u16),
-    Udp(IpEndpoint),
+    Udp(IpListenEndpoint),
 }
 
 impl Endpoint {
@@ -80,7 +80,7 @@ impl<'a> IcmpSocket<'a> {
         IcmpSocket {
             rx_buffer: rx_buffer,
             tx_buffer: tx_buffer,
-            endpoint: Endpoint::default(),
+            endpoint: Default::default(),
             hop_limit: None,
             #[cfg(feature = "async")]
             rx_waker: WakerRegistration::new(),
@@ -170,14 +170,14 @@ impl<'a> IcmpSocket<'a> {
     /// # use smoltcp::socket::{Socket, IcmpSocket, IcmpSocketBuffer, IcmpPacketMetadata};
     /// # let rx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 20]);
     /// # let tx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 20]);
-    /// use smoltcp::wire::IpEndpoint;
+    /// use smoltcp::wire::IpListenEndpoint;
     /// use smoltcp::socket::IcmpEndpoint;
     ///
     /// let mut icmp_socket = // ...
     /// # IcmpSocket::new(rx_buffer, tx_buffer);
     ///
     /// // Bind to ICMP error responses for UDP packets sent from port 53.
-    /// let endpoint = IpEndpoint::from(53);
+    /// let endpoint = IpListenEndpoint::from(53);
     /// icmp_socket.bind(IcmpEndpoint::Udp(endpoint)).unwrap();
     /// ```
     ///
@@ -332,7 +332,7 @@ impl<'a> IcmpSocket<'a> {
             (
                 &Endpoint::Udp(endpoint),
                 &IcmpRepr::Ipv4(Icmpv4Repr::DstUnreachable { data, .. }),
-            ) if endpoint.addr.is_unspecified() || endpoint.addr == ip_repr.dst_addr() => {
+            ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => {
                 let packet = UdpPacket::new_unchecked(data);
                 match UdpRepr::parse(
                     &packet,
@@ -348,7 +348,7 @@ impl<'a> IcmpSocket<'a> {
             (
                 &Endpoint::Udp(endpoint),
                 &IcmpRepr::Ipv6(Icmpv6Repr::DstUnreachable { data, .. }),
-            ) if endpoint.addr.is_unspecified() || endpoint.addr == ip_repr.dst_addr() => {
+            ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => {
                 let packet = UdpPacket::new_unchecked(data);
                 match UdpRepr::parse(
                     &packet,
@@ -447,12 +447,16 @@ impl<'a> IcmpSocket<'a> {
             );
             match *remote_endpoint {
                 #[cfg(feature = "proto-ipv4")]
-                IpAddress::Ipv4(ipv4_addr) => {
+                IpAddress::Ipv4(dst_addr) => {
+                    let src_addr = match cx.get_source_address_ipv4(dst_addr) {
+                        Some(addr) => addr,
+                        None => return Err(Error::Unaddressable),
+                    };
                     let packet = Icmpv4Packet::new_unchecked(&*packet_buf);
                     let repr = Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored())?;
                     let ip_repr = IpRepr::Ipv4(Ipv4Repr {
-                        src_addr: Ipv4Address::default(),
-                        dst_addr: ipv4_addr,
+                        src_addr,
+                        dst_addr,
                         next_header: IpProtocol::Icmp,
                         payload_len: repr.buffer_len(),
                         hop_limit: hop_limit,
@@ -460,25 +464,27 @@ impl<'a> IcmpSocket<'a> {
                     emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
                 }
                 #[cfg(feature = "proto-ipv6")]
-                IpAddress::Ipv6(ipv6_addr) => {
+                IpAddress::Ipv6(dst_addr) => {
+                    let src_addr = match cx.get_source_address_ipv6(dst_addr) {
+                        Some(addr) => addr,
+                        None => return Err(Error::Unaddressable),
+                    };
                     let packet = Icmpv6Packet::new_unchecked(&*packet_buf);
-                    let src_addr = Ipv6Address::default();
                     let repr = Icmpv6Repr::parse(
                         &src_addr.into(),
-                        &ipv6_addr.into(),
+                        &dst_addr.into(),
                         &packet,
                         &ChecksumCapabilities::ignored(),
                     )?;
                     let ip_repr = IpRepr::Ipv6(Ipv6Repr {
-                        src_addr: src_addr,
-                        dst_addr: ipv6_addr,
+                        src_addr,
+                        dst_addr,
                         next_header: IpProtocol::Icmpv6,
                         payload_len: repr.buffer_len(),
                         hop_limit: hop_limit,
                     });
                     emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
                 }
-                _ => Err(Error::Unaddressable),
             }
         })?;
 
@@ -531,10 +537,10 @@ mod tests_common {
 mod test_ipv4 {
     use super::tests_common::*;
 
-    use crate::wire::Icmpv4DstUnreachable;
+    use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address};
 
-    const REMOTE_IPV4: Ipv4Address = Ipv4Address([0x7f, 0x00, 0x00, 0x02]);
-    const LOCAL_IPV4: Ipv4Address = Ipv4Address([0x7f, 0x00, 0x00, 0x01]);
+    const REMOTE_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 2]);
+    const LOCAL_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
     const LOCAL_END_V4: IpEndpoint = IpEndpoint {
         addr: IpAddress::Ipv4(LOCAL_IPV4),
         port: LOCAL_PORT,
@@ -547,7 +553,7 @@ mod test_ipv4 {
     };
 
     static LOCAL_IPV4_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
-        src_addr: Ipv4Address::UNSPECIFIED,
+        src_addr: LOCAL_IPV4,
         dst_addr: REMOTE_IPV4,
         next_header: IpProtocol::Icmp,
         payload_len: 24,
@@ -566,7 +572,7 @@ mod test_ipv4 {
     fn test_send_unaddressable() {
         let mut socket = socket(buffer(0), buffer(1));
         assert_eq!(
-            socket.send_slice(b"abcdef", IpAddress::default()),
+            socket.send_slice(b"abcdef", IpAddress::Ipv4(Ipv4Address::default())),
             Err(Error::Unaddressable)
         );
         assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV4.into()), Ok(()));
@@ -648,7 +654,7 @@ mod test_ipv4 {
                 assert_eq!(
                     ip_repr,
                     IpRepr::Ipv4(Ipv4Repr {
-                        src_addr: Ipv4Address::UNSPECIFIED,
+                        src_addr: LOCAL_IPV4,
                         dst_addr: REMOTE_IPV4,
                         next_header: IpProtocol::Icmp,
                         payload_len: ECHOV4_REPR.buffer_len(),
@@ -719,7 +725,7 @@ mod test_ipv4 {
     fn test_accepts_udp() {
         let mut socket = socket(buffer(1), buffer(1));
         let mut cx = Context::mock();
-        assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V4)), Ok(()));
+        assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V4.into())), Ok(()));
 
         let checksum = ChecksumCapabilities::default();
 
@@ -778,12 +784,12 @@ mod test_ipv4 {
 mod test_ipv6 {
     use super::tests_common::*;
 
-    use crate::wire::Icmpv6DstUnreachable;
+    use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address};
 
     const REMOTE_IPV6: Ipv6Address =
-        Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
-    const LOCAL_IPV6: Ipv6Address =
         Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]);
+    const LOCAL_IPV6: Ipv6Address =
+        Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
     const LOCAL_END_V6: IpEndpoint = IpEndpoint {
         addr: IpAddress::Ipv6(LOCAL_IPV6),
         port: LOCAL_PORT,
@@ -795,7 +801,7 @@ mod test_ipv6 {
     };
 
     static LOCAL_IPV6_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr {
-        src_addr: Ipv6Address::UNSPECIFIED,
+        src_addr: LOCAL_IPV6,
         dst_addr: REMOTE_IPV6,
         next_header: IpProtocol::Icmpv6,
         payload_len: 24,
@@ -814,7 +820,7 @@ mod test_ipv6 {
     fn test_send_unaddressable() {
         let mut socket = socket(buffer(0), buffer(1));
         assert_eq!(
-            socket.send_slice(b"abcdef", IpAddress::default()),
+            socket.send_slice(b"abcdef", IpAddress::Ipv6(Ipv6Address::default())),
             Err(Error::Unaddressable)
         );
         assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV6.into()), Ok(()));
@@ -906,7 +912,7 @@ mod test_ipv6 {
                 assert_eq!(
                     ip_repr,
                     IpRepr::Ipv6(Ipv6Repr {
-                        src_addr: Ipv6Address::UNSPECIFIED,
+                        src_addr: LOCAL_IPV6,
                         dst_addr: REMOTE_IPV6,
                         next_header: IpProtocol::Icmpv6,
                         payload_len: ECHOV6_REPR.buffer_len(),
@@ -987,7 +993,7 @@ mod test_ipv6 {
     fn test_accepts_udp() {
         let mut socket = socket(buffer(1), buffer(1));
         let mut cx = Context::mock();
-        assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V6)), Ok(()));
+        assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V6.into())), Ok(()));
 
         let checksum = ChecksumCapabilities::default();
 

+ 0 - 1
src/socket/raw.rs

@@ -277,7 +277,6 @@ impl<'a> RawSocket<'a> {
                     let ipv6_repr = Ipv6Repr::parse(&packet)?;
                     Ok((IpRepr::Ipv6(ipv6_repr), packet.payload()))
                 }
-                IpVersion::Unspecified => unreachable!(),
             }
         }
 

文件差异内容过多而无法显示
+ 163 - 331
src/socket/tcp.rs


+ 23 - 12
src/socket/udp.rs

@@ -7,7 +7,7 @@ use crate::socket::PollAt;
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
 use crate::storage::{PacketBuffer, PacketMetadata};
-use crate::wire::{IpEndpoint, IpProtocol, IpRepr, UdpRepr};
+use crate::wire::{IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr};
 use crate::{Error, Result};
 
 /// A UDP packet metadata.
@@ -22,7 +22,7 @@ pub type UdpSocketBuffer<'a> = PacketBuffer<'a, IpEndpoint>;
 /// packet buffers.
 #[derive(Debug)]
 pub struct UdpSocket<'a> {
-    endpoint: IpEndpoint,
+    endpoint: IpListenEndpoint,
     rx_buffer: UdpSocketBuffer<'a>,
     tx_buffer: UdpSocketBuffer<'a>,
     /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
@@ -37,7 +37,7 @@ impl<'a> UdpSocket<'a> {
     /// Create an UDP socket with the given buffers.
     pub fn new(rx_buffer: UdpSocketBuffer<'a>, tx_buffer: UdpSocketBuffer<'a>) -> UdpSocket<'a> {
         UdpSocket {
-            endpoint: IpEndpoint::default(),
+            endpoint: IpListenEndpoint::default(),
             rx_buffer,
             tx_buffer,
             hop_limit: None,
@@ -85,7 +85,7 @@ impl<'a> UdpSocket<'a> {
 
     /// Return the bound endpoint.
     #[inline]
-    pub fn endpoint(&self) -> IpEndpoint {
+    pub fn endpoint(&self) -> IpListenEndpoint {
         self.endpoint
     }
 
@@ -121,7 +121,7 @@ impl<'a> UdpSocket<'a> {
     /// This function returns `Err(Error::Illegal)` if the socket was open
     /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
     /// if the port in the given endpoint is zero.
-    pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<()> {
+    pub fn bind<T: Into<IpListenEndpoint>>(&mut self, endpoint: T) -> Result<()> {
         let endpoint = endpoint.into();
         if endpoint.port == 0 {
             return Err(Error::Unaddressable);
@@ -145,7 +145,7 @@ impl<'a> UdpSocket<'a> {
     /// Close the socket.
     pub fn close(&mut self) {
         // Clear the bound endpoint of the socket.
-        self.endpoint = IpEndpoint::default();
+        self.endpoint = IpListenEndpoint::default();
 
         // Reset the RX and TX buffers of the socket.
         self.tx_buffer.reset();
@@ -211,7 +211,10 @@ impl<'a> UdpSocket<'a> {
         if self.endpoint.port == 0 {
             return Err(Error::Unaddressable);
         }
-        if !remote_endpoint.is_specified() {
+        if remote_endpoint.addr.is_unspecified() {
+            return Err(Error::Unaddressable);
+        }
+        if remote_endpoint.port == 0 {
             return Err(Error::Unaddressable);
         }
 
@@ -297,8 +300,8 @@ impl<'a> UdpSocket<'a> {
         if self.endpoint.port != repr.dst_port {
             return false;
         }
-        if !self.endpoint.addr.is_unspecified()
-            && self.endpoint.addr != ip_repr.dst_addr()
+        if !self.endpoint.addr.is_none()
+            && self.endpoint.addr != Some(ip_repr.dst_addr())
             && !ip_repr.dst_addr().is_broadcast()
             && !ip_repr.dst_addr().is_multicast()
         {
@@ -349,6 +352,14 @@ impl<'a> UdpSocket<'a> {
 
         self.tx_buffer
             .dequeue_with(|remote_endpoint, payload_buf| {
+                let src_addr = match endpoint.addr {
+                    Some(addr) => addr,
+                    None => match cx.get_source_address(remote_endpoint.addr) {
+                        Some(addr) => addr,
+                        None => return Err(Error::Unaddressable),
+                    },
+                };
+
                 net_trace!(
                     "udp:{}:{}: sending {} octets",
                     endpoint,
@@ -361,7 +372,7 @@ impl<'a> UdpSocket<'a> {
                     dst_port: remote_endpoint.port,
                 };
                 let ip_repr = IpRepr::new(
-                    endpoint.addr,
+                    src_addr,
                     remote_endpoint.addr,
                     IpProtocol::Udp,
                     repr.header_len() + payload_buf.len(),
@@ -388,7 +399,7 @@ impl<'a> UdpSocket<'a> {
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::wire::{IpAddress, IpRepr, UdpRepr};
+    use crate::wire::{IpRepr, UdpRepr};
 
     fn buffer(packets: usize) -> UdpSocketBuffer<'static> {
         UdpSocketBuffer::new(
@@ -511,7 +522,7 @@ mod test {
             socket.send_slice(
                 b"abcdef",
                 IpEndpoint {
-                    addr: IpAddress::Unspecified,
+                    addr: IpvXAddress::UNSPECIFIED.into(),
                     ..REMOTE_END
                 }
             ),

+ 107 - 88
src/wire/ip.rs

@@ -11,9 +11,7 @@ use crate::{Error, Result};
 /// Internet protocol version.
 #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
-#[non_exhaustive]
 pub enum Version {
-    Unspecified,
     #[cfg(feature = "proto-ipv4")]
     Ipv4,
     #[cfg(feature = "proto-ipv6")]
@@ -39,7 +37,6 @@ impl Version {
 impl fmt::Display for Version {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match *self {
-            Version::Unspecified => write!(f, "IPv?"),
             #[cfg(feature = "proto-ipv4")]
             Version::Ipv4 => write!(f, "IPv4"),
             #[cfg(feature = "proto-ipv6")]
@@ -84,11 +81,7 @@ impl fmt::Display for Protocol {
 
 /// An internetworking address.
 #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
-#[non_exhaustive]
 pub enum Address {
-    /// An unspecified address.
-    /// May be used as a placeholder for storage where the address is not assigned yet.
-    Unspecified,
     /// An IPv4 address.
     #[cfg(feature = "proto-ipv4")]
     Ipv4(Ipv4Address),
@@ -112,20 +105,18 @@ impl Address {
     }
 
     /// Return the protocol version.
-    pub fn version(&self) -> Option<Version> {
+    pub fn version(&self) -> Version {
         match self {
-            Address::Unspecified => None,
             #[cfg(feature = "proto-ipv4")]
-            Address::Ipv4(_) => Some(Version::Ipv4),
+            Address::Ipv4(_) => Version::Ipv4,
             #[cfg(feature = "proto-ipv6")]
-            Address::Ipv6(_) => Some(Version::Ipv6),
+            Address::Ipv6(_) => Version::Ipv6,
         }
     }
 
     /// Return an address as a sequence of octets, in big-endian.
     pub fn as_bytes(&self) -> &[u8] {
         match *self {
-            Address::Unspecified => &[],
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(ref addr) => addr.as_bytes(),
             #[cfg(feature = "proto-ipv6")]
@@ -136,7 +127,6 @@ impl Address {
     /// Query whether the address is a valid unicast address.
     pub fn is_unicast(&self) -> bool {
         match *self {
-            Address::Unspecified => false,
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(addr) => addr.is_unicast(),
             #[cfg(feature = "proto-ipv6")]
@@ -147,7 +137,6 @@ impl Address {
     /// Query whether the address is a valid multicast address.
     pub fn is_multicast(&self) -> bool {
         match *self {
-            Address::Unspecified => false,
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(addr) => addr.is_multicast(),
             #[cfg(feature = "proto-ipv6")]
@@ -158,7 +147,6 @@ impl Address {
     /// Query whether the address is the broadcast address.
     pub fn is_broadcast(&self) -> bool {
         match *self {
-            Address::Unspecified => false,
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(addr) => addr.is_broadcast(),
             #[cfg(feature = "proto-ipv6")]
@@ -169,7 +157,6 @@ impl Address {
     /// Query whether the address falls into the "unspecified" range.
     pub fn is_unspecified(&self) -> bool {
         match *self {
-            Address::Unspecified => true,
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(addr) => addr.is_unspecified(),
             #[cfg(feature = "proto-ipv6")]
@@ -177,17 +164,6 @@ impl Address {
         }
     }
 
-    /// Return an unspecified address that has the same IP version as `self`.
-    pub fn as_unspecified(&self) -> Address {
-        match *self {
-            Address::Unspecified => Address::Unspecified,
-            #[cfg(feature = "proto-ipv4")]
-            Address::Ipv4(_) => Address::Ipv4(Ipv4Address::UNSPECIFIED),
-            #[cfg(feature = "proto-ipv6")]
-            Address::Ipv6(_) => Address::Ipv6(Ipv6Address::UNSPECIFIED),
-        }
-    }
-
     /// If `self` is a CIDR-compatible subnet mask, return `Some(prefix_len)`,
     /// where `prefix_len` is the number of leading zeroes. Return `None` otherwise.
     pub fn prefix_len(&self) -> Option<u8> {
@@ -233,7 +209,6 @@ impl From<Address> for ::std::net::IpAddr {
             Address::Ipv4(ipv4) => ::std::net::IpAddr::V4(ipv4.into()),
             #[cfg(feature = "proto-ipv6")]
             Address::Ipv6(ipv6) => ::std::net::IpAddr::V6(ipv6.into()),
-            _ => unreachable!(),
         }
     }
 }
@@ -252,12 +227,6 @@ impl From<::std::net::Ipv6Addr> for Address {
     }
 }
 
-impl Default for Address {
-    fn default() -> Address {
-        Address::Unspecified
-    }
-}
-
 #[cfg(feature = "proto-ipv4")]
 impl From<Ipv4Address> for Address {
     fn from(addr: Ipv4Address) -> Self {
@@ -275,7 +244,6 @@ impl From<Ipv6Address> for Address {
 impl fmt::Display for Address {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match *self {
-            Address::Unspecified => write!(f, "*"),
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(addr) => write!(f, "{}", addr),
             #[cfg(feature = "proto-ipv6")]
@@ -288,7 +256,6 @@ impl fmt::Display for Address {
 impl defmt::Format for Address {
     fn format(&self, f: defmt::Formatter) {
         match self {
-            &Address::Unspecified => defmt::write!(f, "{:?}", "*"),
             #[cfg(feature = "proto-ipv4")]
             &Address::Ipv4(addr) => defmt::write!(f, "{:?}", addr),
             #[cfg(feature = "proto-ipv6")]
@@ -300,7 +267,6 @@ impl defmt::Format for Address {
 /// A specification of a CIDR block, containing an address and a variable-length
 /// subnet masking prefix length.
 #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
-#[non_exhaustive]
 pub enum Cidr {
     #[cfg(feature = "proto-ipv4")]
     Ipv4(Ipv4Cidr),
@@ -312,17 +278,13 @@ impl Cidr {
     /// Create a CIDR block from the given address and prefix length.
     ///
     /// # Panics
-    /// This function panics if the given address is unspecified, or
-    /// the given prefix length is invalid for the given address.
+    /// This function panics if the given prefix length is invalid for the given address.
     pub fn new(addr: Address, prefix_len: u8) -> Cidr {
         match addr {
             #[cfg(feature = "proto-ipv4")]
             Address::Ipv4(addr) => Cidr::Ipv4(Ipv4Cidr::new(addr, prefix_len)),
             #[cfg(feature = "proto-ipv6")]
             Address::Ipv6(addr) => Cidr::Ipv6(Ipv6Cidr::new(addr, prefix_len)),
-            Address::Unspecified => {
-                panic!("a CIDR block cannot be based on an unspecified address")
-            }
         }
     }
 
@@ -354,14 +316,8 @@ impl Cidr {
             (&Cidr::Ipv4(ref cidr), &Address::Ipv4(ref addr)) => cidr.contains_addr(addr),
             #[cfg(feature = "proto-ipv6")]
             (&Cidr::Ipv6(ref cidr), &Address::Ipv6(ref addr)) => cidr.contains_addr(addr),
-            #[cfg(all(feature = "proto-ipv6", feature = "proto-ipv4"))]
-            (&Cidr::Ipv4(_), &Address::Ipv6(_)) | (&Cidr::Ipv6(_), &Address::Ipv4(_)) => false,
-            (_, &Address::Unspecified) =>
-            // a fully unspecified address covers both IPv4 and IPv6,
-            // and no CIDR block can do that.
-            {
-                false
-            }
+            #[allow(unreachable_patterns)]
+            _ => false,
         }
     }
 
@@ -373,8 +329,8 @@ impl Cidr {
             (&Cidr::Ipv4(ref cidr), &Cidr::Ipv4(ref other)) => cidr.contains_subnet(other),
             #[cfg(feature = "proto-ipv6")]
             (&Cidr::Ipv6(ref cidr), &Cidr::Ipv6(ref other)) => cidr.contains_subnet(other),
-            #[cfg(all(feature = "proto-ipv6", feature = "proto-ipv4"))]
-            (&Cidr::Ipv4(_), &Cidr::Ipv6(_)) | (&Cidr::Ipv6(_), &Cidr::Ipv4(_)) => false,
+            #[allow(unreachable_patterns)]
+            _ => false,
         }
     }
 }
@@ -418,28 +374,20 @@ impl defmt::Format for Cidr {
 
 /// An internet endpoint address.
 ///
-/// An endpoint can be constructed from a port, in which case the address is unspecified.
-#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
+/// `Endpoint` always fully specifies both the address and the port.
+///
+/// See also ['ListenEndpoint'], which allows not specifying the address
+/// in order to listen on a given port on any address.
+#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
 pub struct Endpoint {
     pub addr: Address,
     pub port: u16,
 }
 
 impl Endpoint {
-    /// An endpoint with unspecified address and port.
-    pub const UNSPECIFIED: Endpoint = Endpoint {
-        addr: Address::Unspecified,
-        port: 0,
-    };
-
     /// Create an endpoint address from given address and port.
     pub fn new(addr: Address, port: u16) -> Endpoint {
-        Endpoint { addr, port }
-    }
-
-    /// Query whether the endpoint has a specified address and port.
-    pub fn is_specified(&self) -> bool {
-        !self.addr.is_unspecified() && self.port != 0
+        Endpoint { addr: addr, port }
     }
 }
 
@@ -486,19 +434,100 @@ impl defmt::Format for Endpoint {
     }
 }
 
-impl From<u16> for Endpoint {
-    fn from(port: u16) -> Endpoint {
+impl<T: Into<Address>> From<(T, u16)> for Endpoint {
+    fn from((addr, port): (T, u16)) -> Endpoint {
         Endpoint {
-            addr: Address::Unspecified,
+            addr: addr.into(),
             port,
         }
     }
 }
 
-impl<T: Into<Address>> From<(T, u16)> for Endpoint {
-    fn from((addr, port): (T, u16)) -> Endpoint {
-        Endpoint {
-            addr: addr.into(),
+/// An internet endpoint address for listening.
+///
+/// In contrast with [`Endpoint`], `ListenEndpoint` allows not specifying the address,
+/// in order to listen on a given port at all our addresses.
+///
+/// An endpoint can be constructed from a port, in which case the address is unspecified.
+#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
+pub struct ListenEndpoint {
+    pub addr: Option<Address>,
+    pub port: u16,
+}
+
+impl ListenEndpoint {
+    /// Query whether the endpoint has a specified address and port.
+    pub fn is_specified(&self) -> bool {
+        self.addr.is_some() && self.port != 0
+    }
+}
+
+#[cfg(all(feature = "std", feature = "proto-ipv4", feature = "proto-ipv6"))]
+impl From<::std::net::SocketAddr> for ListenEndpoint {
+    fn from(x: ::std::net::SocketAddr) -> ListenEndpoint {
+        ListenEndpoint {
+            addr: Some(x.ip().into()),
+            port: x.port(),
+        }
+    }
+}
+
+#[cfg(all(feature = "std", feature = "proto-ipv4"))]
+impl From<::std::net::SocketAddrV4> for ListenEndpoint {
+    fn from(x: ::std::net::SocketAddrV4) -> ListenEndpoint {
+        ListenEndpoint {
+            addr: Some((*x.ip()).into()),
+            port: x.port(),
+        }
+    }
+}
+
+#[cfg(all(feature = "std", feature = "proto-ipv6"))]
+impl From<::std::net::SocketAddrV6> for ListenEndpoint {
+    fn from(x: ::std::net::SocketAddrV6) -> ListenEndpoint {
+        ListenEndpoint {
+            addr: Some((*x.ip()).into()),
+            port: x.port(),
+        }
+    }
+}
+
+impl fmt::Display for ListenEndpoint {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        if let Some(addr) = self.addr {
+            write!(f, "{}:{}", addr, self.port)
+        } else {
+            write!(f, "*:{}", self.port)
+        }
+    }
+}
+
+#[cfg(feature = "defmt")]
+impl defmt::Format for ListenEndpoint {
+    fn format(&self, f: defmt::Formatter) {
+        defmt::write!(f, "{:?}:{=u16}", self.addr, self.port);
+    }
+}
+
+impl From<u16> for ListenEndpoint {
+    fn from(port: u16) -> ListenEndpoint {
+        ListenEndpoint { addr: None, port }
+    }
+}
+
+impl From<Endpoint> for ListenEndpoint {
+    fn from(endpoint: Endpoint) -> ListenEndpoint {
+        ListenEndpoint {
+            addr: Some(endpoint.addr),
+            port: endpoint.port,
+        }
+    }
+}
+
+impl<T: Into<Address>> From<(T, u16)> for ListenEndpoint {
+    fn from((addr, port): (T, u16)) -> ListenEndpoint {
+        ListenEndpoint {
+            addr: Some(addr.into()),
             port,
         }
     }
@@ -510,7 +539,6 @@ impl<T: Into<Address>> From<(T, u16)> for Endpoint {
 /// or IPv6 concrete high-level representation.
 #[derive(Debug, Clone, PartialEq, Eq)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
-#[non_exhaustive]
 pub enum Repr {
     #[cfg(feature = "proto-ipv4")]
     Ipv4(Ipv4Repr),
@@ -562,6 +590,7 @@ impl Repr {
                 payload_len,
                 hop_limit,
             }),
+            #[allow(unreachable_patterns)]
             _ => panic!("IP version mismatch: src={:?} dst={:?}", src_addr, dst_addr),
         }
     }
@@ -618,17 +647,11 @@ impl Repr {
 
     /// Set the payload length.
     pub fn set_payload_len(&mut self, length: usize) {
-        match *self {
+        match self {
             #[cfg(feature = "proto-ipv4")]
-            Repr::Ipv4(Ipv4Repr {
-                ref mut payload_len,
-                ..
-            }) => *payload_len = length,
+            Repr::Ipv4(Ipv4Repr { payload_len, .. }) => *payload_len = length,
             #[cfg(feature = "proto-ipv6")]
-            Repr::Ipv6(Ipv6Repr {
-                ref mut payload_len,
-                ..
-            }) => *payload_len = length,
+            Repr::Ipv6(Ipv6Repr { payload_len, .. }) => *payload_len = length,
         }
     }
 
@@ -759,6 +782,7 @@ pub mod checksum {
                 ])
             }
 
+            #[allow(unreachable_patterns)]
             _ => panic!(
                 "Unexpected pseudo header addresses: {}, {}",
                 src_addr, dst_addr
@@ -892,11 +916,6 @@ pub(crate) mod test {
     #[cfg(feature = "proto-ipv4")]
     use crate::wire::{Ipv4Address, Ipv4Repr};
 
-    #[test]
-    fn endpoint_unspecified() {
-        assert!(!Endpoint::UNSPECIFIED.is_specified());
-    }
-
     #[test]
     #[cfg(feature = "proto-ipv4")]
     fn to_prefix_len_ipv4() {

+ 3 - 2
src/wire/mod.rs

@@ -157,8 +157,9 @@ pub use self::ieee802154::{
 };
 
 pub use self::ip::{
-    Address as IpAddress, Cidr as IpCidr, Endpoint as IpEndpoint, Protocol as IpProtocol,
-    Repr as IpRepr, Version as IpVersion,
+    Address as IpAddress, Cidr as IpCidr, Endpoint as IpEndpoint,
+    ListenEndpoint as IpListenEndpoint, Protocol as IpProtocol, Repr as IpRepr,
+    Version as IpVersion,
 };
 
 #[cfg(feature = "proto-ipv4")]

部分文件因为文件数量过多而无法显示