瀏覽代碼

socket: make dispatch infallible, except for emit errors.

socket dispatch can't be allowed to fail, it's too late to handle
errors there. For example, you can push a malformed packet to a
raw::Socket. Its dispatch() would detect that and error out, but *not*
pop the packet from the queue, resulting in the socket getting stuck forever!

Instead, all errors are now logged (and the offending packet is dropped if applicable).

Also, Err(Exhausted) was returned as "no error, but I have no packets to emit", but
that's not really useful to Interface, because it can already see if the
socket emitted packets or not.
Dario Nieuwenhuis 2 年之前
父節點
當前提交
629f0bce79
共有 6 個文件被更改,包括 250 次插入192 次删除
  1. 8 9
      src/socket/dhcpv4.rs
  2. 4 8
      src/socket/dns.rs
  3. 97 64
      src/socket/icmp.rs
  4. 82 62
      src/socket/raw.rs
  5. 8 8
      src/socket/tcp.rs
  6. 51 41
      src/socket/udp.rs

+ 8 - 9
src/socket/dhcpv4.rs

@@ -6,7 +6,6 @@ use crate::wire::{
     DhcpMessageType, DhcpPacket, DhcpRepr, IpAddress, IpProtocol, Ipv4Address, Ipv4Cidr, Ipv4Repr,
     UdpRepr, DHCP_CLIENT_PORT, DHCP_MAX_DNS_SERVER_COUNT, DHCP_SERVER_PORT, UDP_HEADER_LEN,
 };
-use crate::{Error, Result};
 
 use super::PollAt;
 
@@ -376,16 +375,16 @@ impl Socket {
         0x12345678
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
-        F: FnOnce(&mut Context, (Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<()>,
+        F: FnOnce(&mut Context, (Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<(), E>,
     {
         // note: Dhcpv4Socket is only usable in ethernet mediums, so the
         // unwrap can never fail.
         let ethernet_addr = if let Some(HardwareAddress::Ethernet(addr)) = cx.hardware_addr() {
             addr
         } else {
-            return Err(Error::Malformed);
+            panic!("using DHCPv4 socket with a non-ethernet hardware address.");
         };
 
         // Worst case biggest IPv4 header length.
@@ -432,7 +431,7 @@ impl Socket {
         match &mut self.state {
             ClientState::Discovering(state) => {
                 if cx.now() < state.retry_at {
-                    return Err(Error::Exhausted);
+                    return Ok(());
                 }
 
                 // send packet
@@ -451,13 +450,12 @@ impl Socket {
             }
             ClientState::Requesting(state) => {
                 if cx.now() < state.retry_at {
-                    return Err(Error::Exhausted);
+                    return Ok(());
                 }
 
                 if state.retry >= REQUEST_RETRIES {
                     net_debug!("DHCP request retries exceeded, restarting discovery");
                     self.reset();
-                    // return Ok so we get polled again
                     return Ok(());
                 }
 
@@ -489,7 +487,7 @@ impl Socket {
                 }
 
                 if cx.now() < state.renew_at {
-                    return Err(Error::Exhausted);
+                    return Ok(());
                 }
 
                 ipv4_repr.src_addr = state.config.address.address();
@@ -553,6 +551,7 @@ mod test {
 
     use super::*;
     use crate::wire::EthernetAddress;
+    use crate::Error;
 
     // =========================================================================================//
     // Helper functions
@@ -622,7 +621,7 @@ mod test {
                         None => panic!("Too many reprs emitted"),
                     }
                     i += 1;
-                    Ok(())
+                    Ok::<_, Error>(())
                 });
         }
 

+ 4 - 8
src/socket/dns.rs

@@ -8,7 +8,6 @@ use crate::socket::{Context, PollAt};
 use crate::time::{Duration, Instant};
 use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
 use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
-use crate::Error;
 
 #[cfg(feature = "async")]
 use super::WakerRegistration;
@@ -474,9 +473,9 @@ impl<'a> Socket<'a> {
         net_trace!("no query matched");
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
-        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
     {
         let hop_limit = self.hop_limit.unwrap_or(64);
 
@@ -556,10 +555,7 @@ impl<'a> Socket<'a> {
                     udp_repr.src_port
                 );
 
-                if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
-                    net_trace!("DNS emit error {:?}", e);
-                    return Ok(());
-                }
+                emit(cx, (ip_repr, udp_repr, payload))?;
 
                 pq.retransmit_at = cx.now() + pq.delay;
                 pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
@@ -569,7 +565,7 @@ impl<'a> Socket<'a> {
         }
 
         // Nothing to dispatch
-        Err(Error::Exhausted)
+        Ok(())
     }
 
     pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {

+ 97 - 64
src/socket/icmp.rs

@@ -6,8 +6,8 @@ use crate::phy::ChecksumCapabilities;
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
 use crate::socket::{Context, PollAt};
-use crate::Error;
 
+use crate::storage::Empty;
 use crate::wire::IcmpRepr;
 #[cfg(feature = "proto-ipv4")]
 use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr};
@@ -451,66 +451,98 @@ impl<'a> Socket<'a> {
         self.rx_waker.wake();
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
-        F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), E>,
     {
         let hop_limit = self.hop_limit.unwrap_or(64);
-        self.tx_buffer
-            .dequeue_with(|remote_endpoint, packet_buf| {
-                net_trace!(
-                    "icmp:{}: sending {} octets",
-                    remote_endpoint,
-                    packet_buf.len()
-                );
-                match *remote_endpoint {
-                    #[cfg(feature = "proto-ipv4")]
-                    IpAddress::Ipv4(dst_addr) => {
-                        let src_addr = match cx.get_source_address_ipv4(dst_addr) {
-                            Some(addr) => addr,
-                            None => return Err(Error::Unaddressable),
-                        };
-                        let packet = Icmpv4Packet::new_unchecked(&*packet_buf);
-                        let repr = Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored())?;
-                        let ip_repr = IpRepr::Ipv4(Ipv4Repr {
-                            src_addr,
-                            dst_addr,
-                            next_header: IpProtocol::Icmp,
-                            payload_len: repr.buffer_len(),
-                            hop_limit: hop_limit,
-                        });
-                        emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
-                    }
-                    #[cfg(feature = "proto-ipv6")]
-                    IpAddress::Ipv6(dst_addr) => {
-                        let src_addr = match cx.get_source_address_ipv6(dst_addr) {
-                            Some(addr) => addr,
-                            None => return Err(Error::Unaddressable),
-                        };
-                        let packet = Icmpv6Packet::new_unchecked(&*packet_buf);
-                        let repr = Icmpv6Repr::parse(
-                            &src_addr.into(),
-                            &dst_addr.into(),
-                            &packet,
-                            &ChecksumCapabilities::ignored(),
-                        )?;
-                        let ip_repr = IpRepr::Ipv6(Ipv6Repr {
-                            src_addr,
-                            dst_addr,
-                            next_header: IpProtocol::Icmpv6,
-                            payload_len: repr.buffer_len(),
-                            hop_limit: hop_limit,
-                        });
-                        emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
-                    }
+        let res = self.tx_buffer.dequeue_with(|remote_endpoint, packet_buf| {
+            net_trace!(
+                "icmp:{}: sending {} octets",
+                remote_endpoint,
+                packet_buf.len()
+            );
+            match *remote_endpoint {
+                #[cfg(feature = "proto-ipv4")]
+                IpAddress::Ipv4(dst_addr) => {
+                    let src_addr = match cx.get_source_address_ipv4(dst_addr) {
+                        Some(addr) => addr,
+                        None => {
+                            net_trace!(
+                                "icmp:{}: not find suitable source address, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let packet = Icmpv4Packet::new_unchecked(&*packet_buf);
+                    let repr = match Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored()) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!(
+                                "icmp:{}: malformed packet in queue, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let ip_repr = IpRepr::Ipv4(Ipv4Repr {
+                        src_addr,
+                        dst_addr,
+                        next_header: IpProtocol::Icmp,
+                        payload_len: repr.buffer_len(),
+                        hop_limit: hop_limit,
+                    });
+                    emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
                 }
-            })
-            .map_err(|_| Error::Exhausted)??;
-
-        #[cfg(feature = "async")]
-        self.tx_waker.wake();
-
-        Ok(())
+                #[cfg(feature = "proto-ipv6")]
+                IpAddress::Ipv6(dst_addr) => {
+                    let src_addr = match cx.get_source_address_ipv6(dst_addr) {
+                        Some(addr) => addr,
+                        None => {
+                            net_trace!(
+                                "icmp:{}: not find suitable source address, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let packet = Icmpv6Packet::new_unchecked(&*packet_buf);
+                    let repr = match Icmpv6Repr::parse(
+                        &src_addr.into(),
+                        &dst_addr.into(),
+                        &packet,
+                        &ChecksumCapabilities::ignored(),
+                    ) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!(
+                                "icmp:{}: malformed packet in queue, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let ip_repr = IpRepr::Ipv6(Ipv6Repr {
+                        src_addr,
+                        dst_addr,
+                        next_header: IpProtocol::Icmpv6,
+                        payload_len: repr.buffer_len(),
+                        hop_limit: hop_limit,
+                    });
+                    emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
+                }
+            }
+        });
+        match res {
+            Err(Empty) => Ok(()),
+            Ok(Err(e)) => Err(e),
+            Ok(Ok(())) => {
+                #[cfg(feature = "async")]
+                self.tx_waker.wake();
+                Ok(())
+            }
+        }
     }
 
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
@@ -552,8 +584,8 @@ mod tests_common {
 #[cfg(all(test, feature = "proto-ipv4"))]
 mod test_ipv4 {
     use super::tests_common::*;
-
     use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address};
+    use crate::Error;
 
     const REMOTE_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 2]);
     const LOCAL_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
@@ -602,7 +634,7 @@ mod test_ipv4 {
 
         assert_eq!(
             socket.dispatch(&mut cx, |_, _| unreachable!()),
-            Err(Error::Exhausted)
+            Ok::<_, ()>(())
         );
 
         // This buffer is too long
@@ -641,7 +673,7 @@ mod test_ipv4 {
             socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| {
                 assert_eq!(ip_repr, LOCAL_IPV4_REPR);
                 assert_eq!(icmp_repr, ECHOV4_REPR.into());
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );
@@ -677,7 +709,7 @@ mod test_ipv4 {
                         hop_limit: 0x2a,
                     })
                 );
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );
@@ -795,6 +827,7 @@ mod test_ipv6 {
     use super::tests_common::*;
 
     use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address};
+    use crate::Error;
 
     const REMOTE_IPV6: Ipv6Address =
         Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]);
@@ -844,7 +877,7 @@ mod test_ipv6 {
 
         assert_eq!(
             socket.dispatch(&mut cx, |_, _| unreachable!()),
-            Err(Error::Exhausted)
+            Ok::<_, Error>(())
         );
 
         // This buffer is too long
@@ -888,7 +921,7 @@ mod test_ipv6 {
             socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| {
                 assert_eq!(ip_repr, LOCAL_IPV6_REPR);
                 assert_eq!(icmp_repr, ECHOV6_REPR.into());
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );
@@ -929,7 +962,7 @@ mod test_ipv6 {
                         hop_limit: 0x2a,
                     })
                 );
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );

+ 82 - 62
src/socket/raw.rs

@@ -3,12 +3,11 @@ use core::cmp::min;
 use core::task::Waker;
 
 use crate::iface::Context;
-use crate::phy::ChecksumCapabilities;
 use crate::socket::PollAt;
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
-use crate::Error;
 
+use crate::storage::Empty;
 use crate::wire::{IpProtocol, IpRepr, IpVersion};
 #[cfg(feature = "proto-ipv4")]
 use crate::wire::{Ipv4Packet, Ipv4Repr};
@@ -265,21 +264,27 @@ impl<'a> Socket<'a> {
         self.rx_waker.wake();
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
-        F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), E>,
     {
-        fn prepare<'a>(
-            next_header: IpProtocol,
-            buffer: &'a mut [u8],
-            _checksum_caps: &ChecksumCapabilities,
-        ) -> Result<(IpRepr, &'a [u8]), Error> {
-            match IpVersion::of_packet(buffer)? {
+        let ip_protocol = self.ip_protocol;
+        let ip_version = self.ip_version;
+        let _checksum_caps = &cx.checksum_caps();
+        let res = self.tx_buffer.dequeue_with(|&mut (), buffer| {
+            match IpVersion::of_packet(buffer) {
                 #[cfg(feature = "proto-ipv4")]
-                IpVersion::Ipv4 => {
-                    let mut packet = Ipv4Packet::new_checked(buffer)?;
-                    if packet.next_header() != next_header {
-                        return Err(Error::Unaddressable);
+                Ok(IpVersion::Ipv4) => {
+                    let mut packet = match Ipv4Packet::new_checked(buffer) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+                    if packet.next_header() != ip_protocol {
+                        net_trace!("raw: sent packet with wrong ip protocol, dropping.");
+                        return Ok(());
                     }
                     if _checksum_caps.ipv4.tx() {
                         packet.fill_checksum();
@@ -289,55 +294,57 @@ impl<'a> Socket<'a> {
                         packet.set_checksum(0);
                     }
 
-                    let packet = Ipv4Packet::new_checked(&*packet.into_inner())?;
-                    let ipv4_repr = Ipv4Repr::parse(&packet, _checksum_caps)?;
-                    Ok((IpRepr::Ipv4(ipv4_repr), packet.payload()))
+                    let packet = Ipv4Packet::new_unchecked(&*packet.into_inner());
+                    let ipv4_repr = match Ipv4Repr::parse(&packet, _checksum_caps) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv4 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+                    net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
+                    emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload()))
                 }
                 #[cfg(feature = "proto-ipv6")]
-                IpVersion::Ipv6 => {
-                    let packet = Ipv6Packet::new_checked(buffer)?;
-                    if packet.next_header() != next_header {
-                        return Err(Error::Unaddressable);
+                Ok(IpVersion::Ipv6) => {
+                    let packet = match Ipv6Packet::new_checked(buffer) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+                    if packet.next_header() != ip_protocol {
+                        net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping.");
+                        return Ok(());
                     }
                     let packet = Ipv6Packet::new_unchecked(&*packet.into_inner());
-                    let ipv6_repr = Ipv6Repr::parse(&packet)?;
-                    Ok((IpRepr::Ipv6(ipv6_repr), packet.payload()))
+                    let ipv6_repr = match Ipv6Repr::parse(&packet) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+
+                    net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
+                    emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload()))
+                }
+                Err(_) => {
+                    net_trace!("raw: sent packet with invalid IP version, dropping.");
+                    Ok(())
                 }
             }
+        });
+        match res {
+            Err(Empty) => Ok(()),
+            Ok(Err(e)) => Err(e),
+            Ok(Ok(())) => {
+                #[cfg(feature = "async")]
+                self.tx_waker.wake();
+                Ok(())
+            }
         }
-
-        let ip_protocol = self.ip_protocol;
-        let ip_version = self.ip_version;
-        self.tx_buffer
-            .dequeue_with(|&mut (), packet_buf| {
-                match prepare(ip_protocol, packet_buf, &cx.checksum_caps()) {
-                    Ok((ip_repr, raw_packet)) => {
-                        net_trace!(
-                            "raw:{}:{}: sending {} octets",
-                            ip_version,
-                            ip_protocol,
-                            ip_repr.buffer_len() + raw_packet.len()
-                        );
-                        emit(cx, (ip_repr, raw_packet))
-                    }
-                    Err(error) => {
-                        net_debug!(
-                            "raw:{}:{}: dropping outgoing packet ({})",
-                            ip_version,
-                            ip_protocol,
-                            error
-                        );
-                        // Return Ok(()) so the packet is dequeued.
-                        Ok(())
-                    }
-                }
-            })
-            .map_err(|_| Error::Exhausted)??;
-
-        #[cfg(feature = "async")]
-        self.tx_waker.wake();
-
-        Ok(())
     }
 
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
@@ -357,6 +364,7 @@ mod test {
     use crate::wire::{Ipv4Address, Ipv4Repr};
     #[cfg(feature = "proto-ipv6")]
     use crate::wire::{Ipv6Address, Ipv6Repr};
+    use crate::Error;
 
     fn buffer(packets: usize) -> PacketBuffer<'static> {
         PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets])
@@ -453,7 +461,7 @@ mod test {
                     assert!(socket.can_send());
                     assert_eq!(
                         socket.dispatch(&mut cx, |_, _| unreachable!()),
-                        Err(Error::Exhausted)
+                        Ok::<_, Error>(())
                     );
 
                     assert_eq!(socket.send_slice(&$packet[..]), Ok(()));
@@ -474,7 +482,7 @@ mod test {
                         socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| {
                             assert_eq!(ip_repr, $hdr);
                             assert_eq!(ip_payload, &$payload);
-                            Ok(())
+                            Ok::<_, Error>(())
                         }),
                         Ok(())
                     );
@@ -539,13 +547,19 @@ mod test {
             Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6);
 
             assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
 
             let mut wrong_protocol = ipv4_locals::PACKET_BYTES;
             Ipv4Packet::new_unchecked(&mut wrong_protocol).set_next_header(IpProtocol::Tcp);
 
             assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
         }
         #[cfg(feature = "proto-ipv6")]
         {
@@ -556,13 +570,19 @@ mod test {
             Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4);
 
             assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
 
             let mut wrong_protocol = ipv6_locals::PACKET_BYTES;
             Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp);
 
             assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
         }
     }
 

+ 8 - 8
src/socket/tcp.rs

@@ -16,7 +16,6 @@ use crate::wire::{
     IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, TcpControl, TcpRepr, TcpSeqNumber,
     TCP_HEADER_LEN,
 };
-use crate::Error;
 
 macro_rules! tcp_trace {
     ($($arg:expr),*) => (net_log!(trace, $($arg),*));
@@ -1934,12 +1933,12 @@ impl<'a> Socket<'a> {
         }
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
-        F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), E>,
     {
         if self.tuple.is_none() {
-            return Err(Error::Exhausted);
+            return Ok(());
         }
 
         if self.remote_last_ts.is_none() {
@@ -1999,9 +1998,9 @@ impl<'a> Socket<'a> {
             // If we have spent enough time in the TIME-WAIT state, close the socket.
             tcp_trace!("TIME-WAIT timer expired");
             self.reset();
-            return Err(Error::Exhausted);
+            return Ok(());
         } else {
-            return Err(Error::Exhausted);
+            return Ok(());
         }
 
         // NOTE(unwrap): we check tuple is not None the first thing in this function.
@@ -2041,7 +2040,7 @@ impl<'a> Socket<'a> {
             }
 
             // We never transmit anything in the LISTEN state.
-            State::Listen => return Err(Error::Exhausted),
+            State::Listen => return Ok(()),
 
             // We transmit a SYN in the SYN-SENT state.
             // We transmit a SYN|ACK in the SYN-RECEIVED state.
@@ -2270,6 +2269,7 @@ impl<'a> fmt::Write for Socket<'a> {
 mod test {
     use super::*;
     use crate::wire::IpRepr;
+    use crate::Error;
     use core::i32;
     use std::ops::{Deref, DerefMut};
     use std::vec::Vec;
@@ -6168,7 +6168,7 @@ mod test {
         assert_eq!(
             s.socket.dispatch(&mut s.cx, |_, (ip_repr, _)| {
                 assert_eq!(ip_repr.hop_limit(), 0x2a);
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );

+ 51 - 41
src/socket/udp.rs

@@ -6,8 +6,8 @@ use crate::iface::Context;
 use crate::socket::PollAt;
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
+use crate::storage::Empty;
 use crate::wire::{IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr};
-use crate::Error;
 
 /// A UDP packet metadata.
 pub type PacketMetadata = crate::storage::PacketMetadata<IpEndpoint>;
@@ -383,49 +383,58 @@ impl<'a> Socket<'a> {
         self.rx_waker.wake();
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
-        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
     {
         let endpoint = self.endpoint;
         let hop_limit = self.hop_limit.unwrap_or(64);
 
-        self.tx_buffer
-            .dequeue_with(|remote_endpoint, payload_buf| {
-                let src_addr = match endpoint.addr {
+        let res = self.tx_buffer.dequeue_with(|remote_endpoint, payload_buf| {
+            let src_addr = match endpoint.addr {
+                Some(addr) => addr,
+                None => match cx.get_source_address(remote_endpoint.addr) {
                     Some(addr) => addr,
-                    None => match cx.get_source_address(remote_endpoint.addr) {
-                        Some(addr) => addr,
-                        None => return Err(Error::Unaddressable),
-                    },
-                };
-
-                net_trace!(
-                    "udp:{}:{}: sending {} octets",
-                    endpoint,
-                    remote_endpoint,
-                    payload_buf.len()
-                );
-
-                let repr = UdpRepr {
-                    src_port: endpoint.port,
-                    dst_port: remote_endpoint.port,
-                };
-                let ip_repr = IpRepr::new(
-                    src_addr,
-                    remote_endpoint.addr,
-                    IpProtocol::Udp,
-                    repr.header_len() + payload_buf.len(),
-                    hop_limit,
-                );
-                emit(cx, (ip_repr, repr, payload_buf))
-            })
-            .map_err(|_| Error::Exhausted)??;
-
-        #[cfg(feature = "async")]
-        self.tx_waker.wake();
-
-        Ok(())
+                    None => {
+                        net_trace!(
+                            "udp:{}:{}: cannot find suitable source address, dropping.",
+                            endpoint,
+                            remote_endpoint
+                        );
+                        return Ok(());
+                    }
+                },
+            };
+
+            net_trace!(
+                "udp:{}:{}: sending {} octets",
+                endpoint,
+                remote_endpoint,
+                payload_buf.len()
+            );
+
+            let repr = UdpRepr {
+                src_port: endpoint.port,
+                dst_port: remote_endpoint.port,
+            };
+            let ip_repr = IpRepr::new(
+                src_addr,
+                remote_endpoint.addr,
+                IpProtocol::Udp,
+                repr.header_len() + payload_buf.len(),
+                hop_limit,
+            );
+            emit(cx, (ip_repr, repr, payload_buf))
+        });
+        match res {
+            Err(Empty) => Ok(()),
+            Ok(Err(e)) => Err(e),
+            Ok(Ok(())) => {
+                #[cfg(feature = "async")]
+                self.tx_waker.wake();
+                Ok(())
+            }
+        }
     }
 
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
@@ -441,6 +450,7 @@ impl<'a> Socket<'a> {
 mod test {
     use super::*;
     use crate::wire::{IpRepr, UdpRepr};
+    use crate::Error;
 
     fn buffer(packets: usize) -> PacketBuffer<'static> {
         PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 16 * packets])
@@ -589,7 +599,7 @@ mod test {
         assert!(socket.can_send());
         assert_eq!(
             socket.dispatch(&mut cx, |_, _| unreachable!()),
-            Err(Error::Exhausted)
+            Ok::<_, Error>(())
         );
 
         assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
@@ -615,7 +625,7 @@ mod test {
                 assert_eq!(ip_repr, LOCAL_IP_REPR);
                 assert_eq!(udp_repr, LOCAL_UDP_REPR);
                 assert_eq!(payload, PAYLOAD);
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );
@@ -711,7 +721,7 @@ mod test {
                         hop_limit: 0x2a,
                     })
                 );
-                Ok(())
+                Ok::<_, Error>(())
             }),
             Ok(())
         );