فهرست منبع

socket: make dispatch infallible, except for emit errors.

socket dispatch can't be allowed to fail, it's too late to handle
errors there. For example, you can push a malformed packet to a
raw::Socket. Its dispatch() would detect that and error out, but *not*
pop the packet from the queue, resulting in the socket getting stuck forever!

Instead, all errors are now logged (and the offending packet is dropped if applicable).

Also, Err(Exhausted) was returned as "no error, but I have no packets to emit", but
that's not really useful to Interface, because it can already see if the
socket emitted packets or not.
Dario Nieuwenhuis 2 سال پیش
والد
کامیت
629f0bce79
6فایلهای تغییر یافته به همراه250 افزوده شده و 192 حذف شده
  1. 8 9
      src/socket/dhcpv4.rs
  2. 4 8
      src/socket/dns.rs
  3. 97 64
      src/socket/icmp.rs
  4. 82 62
      src/socket/raw.rs
  5. 8 8
      src/socket/tcp.rs
  6. 51 41
      src/socket/udp.rs

+ 8 - 9
src/socket/dhcpv4.rs

@@ -6,7 +6,6 @@ use crate::wire::{
     DhcpMessageType, DhcpPacket, DhcpRepr, IpAddress, IpProtocol, Ipv4Address, Ipv4Cidr, Ipv4Repr,
     DhcpMessageType, DhcpPacket, DhcpRepr, IpAddress, IpProtocol, Ipv4Address, Ipv4Cidr, Ipv4Repr,
     UdpRepr, DHCP_CLIENT_PORT, DHCP_MAX_DNS_SERVER_COUNT, DHCP_SERVER_PORT, UDP_HEADER_LEN,
     UdpRepr, DHCP_CLIENT_PORT, DHCP_MAX_DNS_SERVER_COUNT, DHCP_SERVER_PORT, UDP_HEADER_LEN,
 };
 };
-use crate::{Error, Result};
 
 
 use super::PollAt;
 use super::PollAt;
 
 
@@ -376,16 +375,16 @@ impl Socket {
         0x12345678
         0x12345678
     }
     }
 
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
     where
-        F: FnOnce(&mut Context, (Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<()>,
+        F: FnOnce(&mut Context, (Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<(), E>,
     {
     {
         // note: Dhcpv4Socket is only usable in ethernet mediums, so the
         // note: Dhcpv4Socket is only usable in ethernet mediums, so the
         // unwrap can never fail.
         // unwrap can never fail.
         let ethernet_addr = if let Some(HardwareAddress::Ethernet(addr)) = cx.hardware_addr() {
         let ethernet_addr = if let Some(HardwareAddress::Ethernet(addr)) = cx.hardware_addr() {
             addr
             addr
         } else {
         } else {
-            return Err(Error::Malformed);
+            panic!("using DHCPv4 socket with a non-ethernet hardware address.");
         };
         };
 
 
         // Worst case biggest IPv4 header length.
         // Worst case biggest IPv4 header length.
@@ -432,7 +431,7 @@ impl Socket {
         match &mut self.state {
         match &mut self.state {
             ClientState::Discovering(state) => {
             ClientState::Discovering(state) => {
                 if cx.now() < state.retry_at {
                 if cx.now() < state.retry_at {
-                    return Err(Error::Exhausted);
+                    return Ok(());
                 }
                 }
 
 
                 // send packet
                 // send packet
@@ -451,13 +450,12 @@ impl Socket {
             }
             }
             ClientState::Requesting(state) => {
             ClientState::Requesting(state) => {
                 if cx.now() < state.retry_at {
                 if cx.now() < state.retry_at {
-                    return Err(Error::Exhausted);
+                    return Ok(());
                 }
                 }
 
 
                 if state.retry >= REQUEST_RETRIES {
                 if state.retry >= REQUEST_RETRIES {
                     net_debug!("DHCP request retries exceeded, restarting discovery");
                     net_debug!("DHCP request retries exceeded, restarting discovery");
                     self.reset();
                     self.reset();
-                    // return Ok so we get polled again
                     return Ok(());
                     return Ok(());
                 }
                 }
 
 
@@ -489,7 +487,7 @@ impl Socket {
                 }
                 }
 
 
                 if cx.now() < state.renew_at {
                 if cx.now() < state.renew_at {
-                    return Err(Error::Exhausted);
+                    return Ok(());
                 }
                 }
 
 
                 ipv4_repr.src_addr = state.config.address.address();
                 ipv4_repr.src_addr = state.config.address.address();
@@ -553,6 +551,7 @@ mod test {
 
 
     use super::*;
     use super::*;
     use crate::wire::EthernetAddress;
     use crate::wire::EthernetAddress;
+    use crate::Error;
 
 
     // =========================================================================================//
     // =========================================================================================//
     // Helper functions
     // Helper functions
@@ -622,7 +621,7 @@ mod test {
                         None => panic!("Too many reprs emitted"),
                         None => panic!("Too many reprs emitted"),
                     }
                     }
                     i += 1;
                     i += 1;
-                    Ok(())
+                    Ok::<_, Error>(())
                 });
                 });
         }
         }
 
 

+ 4 - 8
src/socket/dns.rs

@@ -8,7 +8,6 @@ use crate::socket::{Context, PollAt};
 use crate::time::{Duration, Instant};
 use crate::time::{Duration, Instant};
 use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
 use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
 use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
 use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
-use crate::Error;
 
 
 #[cfg(feature = "async")]
 #[cfg(feature = "async")]
 use super::WakerRegistration;
 use super::WakerRegistration;
@@ -474,9 +473,9 @@ impl<'a> Socket<'a> {
         net_trace!("no query matched");
         net_trace!("no query matched");
     }
     }
 
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
     where
-        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
     {
     {
         let hop_limit = self.hop_limit.unwrap_or(64);
         let hop_limit = self.hop_limit.unwrap_or(64);
 
 
@@ -556,10 +555,7 @@ impl<'a> Socket<'a> {
                     udp_repr.src_port
                     udp_repr.src_port
                 );
                 );
 
 
-                if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
-                    net_trace!("DNS emit error {:?}", e);
-                    return Ok(());
-                }
+                emit(cx, (ip_repr, udp_repr, payload))?;
 
 
                 pq.retransmit_at = cx.now() + pq.delay;
                 pq.retransmit_at = cx.now() + pq.delay;
                 pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
                 pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
@@ -569,7 +565,7 @@ impl<'a> Socket<'a> {
         }
         }
 
 
         // Nothing to dispatch
         // Nothing to dispatch
-        Err(Error::Exhausted)
+        Ok(())
     }
     }
 
 
     pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
     pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {

+ 97 - 64
src/socket/icmp.rs

@@ -6,8 +6,8 @@ use crate::phy::ChecksumCapabilities;
 #[cfg(feature = "async")]
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
 use crate::socket::WakerRegistration;
 use crate::socket::{Context, PollAt};
 use crate::socket::{Context, PollAt};
-use crate::Error;
 
 
+use crate::storage::Empty;
 use crate::wire::IcmpRepr;
 use crate::wire::IcmpRepr;
 #[cfg(feature = "proto-ipv4")]
 #[cfg(feature = "proto-ipv4")]
 use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr};
 use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr};
@@ -451,66 +451,98 @@ impl<'a> Socket<'a> {
         self.rx_waker.wake();
         self.rx_waker.wake();
     }
     }
 
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
     where
-        F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), E>,
     {
     {
         let hop_limit = self.hop_limit.unwrap_or(64);
         let hop_limit = self.hop_limit.unwrap_or(64);
-        self.tx_buffer
-            .dequeue_with(|remote_endpoint, packet_buf| {
-                net_trace!(
-                    "icmp:{}: sending {} octets",
-                    remote_endpoint,
-                    packet_buf.len()
-                );
-                match *remote_endpoint {
-                    #[cfg(feature = "proto-ipv4")]
-                    IpAddress::Ipv4(dst_addr) => {
-                        let src_addr = match cx.get_source_address_ipv4(dst_addr) {
-                            Some(addr) => addr,
-                            None => return Err(Error::Unaddressable),
-                        };
-                        let packet = Icmpv4Packet::new_unchecked(&*packet_buf);
-                        let repr = Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored())?;
-                        let ip_repr = IpRepr::Ipv4(Ipv4Repr {
-                            src_addr,
-                            dst_addr,
-                            next_header: IpProtocol::Icmp,
-                            payload_len: repr.buffer_len(),
-                            hop_limit: hop_limit,
-                        });
-                        emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
-                    }
-                    #[cfg(feature = "proto-ipv6")]
-                    IpAddress::Ipv6(dst_addr) => {
-                        let src_addr = match cx.get_source_address_ipv6(dst_addr) {
-                            Some(addr) => addr,
-                            None => return Err(Error::Unaddressable),
-                        };
-                        let packet = Icmpv6Packet::new_unchecked(&*packet_buf);
-                        let repr = Icmpv6Repr::parse(
-                            &src_addr.into(),
-                            &dst_addr.into(),
-                            &packet,
-                            &ChecksumCapabilities::ignored(),
-                        )?;
-                        let ip_repr = IpRepr::Ipv6(Ipv6Repr {
-                            src_addr,
-                            dst_addr,
-                            next_header: IpProtocol::Icmpv6,
-                            payload_len: repr.buffer_len(),
-                            hop_limit: hop_limit,
-                        });
-                        emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
-                    }
+        let res = self.tx_buffer.dequeue_with(|remote_endpoint, packet_buf| {
+            net_trace!(
+                "icmp:{}: sending {} octets",
+                remote_endpoint,
+                packet_buf.len()
+            );
+            match *remote_endpoint {
+                #[cfg(feature = "proto-ipv4")]
+                IpAddress::Ipv4(dst_addr) => {
+                    let src_addr = match cx.get_source_address_ipv4(dst_addr) {
+                        Some(addr) => addr,
+                        None => {
+                            net_trace!(
+                                "icmp:{}: not find suitable source address, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let packet = Icmpv4Packet::new_unchecked(&*packet_buf);
+                    let repr = match Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored()) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!(
+                                "icmp:{}: malformed packet in queue, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let ip_repr = IpRepr::Ipv4(Ipv4Repr {
+                        src_addr,
+                        dst_addr,
+                        next_header: IpProtocol::Icmp,
+                        payload_len: repr.buffer_len(),
+                        hop_limit: hop_limit,
+                    });
+                    emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
                 }
                 }
-            })
-            .map_err(|_| Error::Exhausted)??;
-
-        #[cfg(feature = "async")]
-        self.tx_waker.wake();
-
-        Ok(())
+                #[cfg(feature = "proto-ipv6")]
+                IpAddress::Ipv6(dst_addr) => {
+                    let src_addr = match cx.get_source_address_ipv6(dst_addr) {
+                        Some(addr) => addr,
+                        None => {
+                            net_trace!(
+                                "icmp:{}: not find suitable source address, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let packet = Icmpv6Packet::new_unchecked(&*packet_buf);
+                    let repr = match Icmpv6Repr::parse(
+                        &src_addr.into(),
+                        &dst_addr.into(),
+                        &packet,
+                        &ChecksumCapabilities::ignored(),
+                    ) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!(
+                                "icmp:{}: malformed packet in queue, dropping",
+                                remote_endpoint
+                            );
+                            return Ok(());
+                        }
+                    };
+                    let ip_repr = IpRepr::Ipv6(Ipv6Repr {
+                        src_addr,
+                        dst_addr,
+                        next_header: IpProtocol::Icmpv6,
+                        payload_len: repr.buffer_len(),
+                        hop_limit: hop_limit,
+                    });
+                    emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
+                }
+            }
+        });
+        match res {
+            Err(Empty) => Ok(()),
+            Ok(Err(e)) => Err(e),
+            Ok(Ok(())) => {
+                #[cfg(feature = "async")]
+                self.tx_waker.wake();
+                Ok(())
+            }
+        }
     }
     }
 
 
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
@@ -552,8 +584,8 @@ mod tests_common {
 #[cfg(all(test, feature = "proto-ipv4"))]
 #[cfg(all(test, feature = "proto-ipv4"))]
 mod test_ipv4 {
 mod test_ipv4 {
     use super::tests_common::*;
     use super::tests_common::*;
-
     use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address};
     use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address};
+    use crate::Error;
 
 
     const REMOTE_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 2]);
     const REMOTE_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 2]);
     const LOCAL_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
     const LOCAL_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
@@ -602,7 +634,7 @@ mod test_ipv4 {
 
 
         assert_eq!(
         assert_eq!(
             socket.dispatch(&mut cx, |_, _| unreachable!()),
             socket.dispatch(&mut cx, |_, _| unreachable!()),
-            Err(Error::Exhausted)
+            Ok::<_, ()>(())
         );
         );
 
 
         // This buffer is too long
         // This buffer is too long
@@ -641,7 +673,7 @@ mod test_ipv4 {
             socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| {
             socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| {
                 assert_eq!(ip_repr, LOCAL_IPV4_REPR);
                 assert_eq!(ip_repr, LOCAL_IPV4_REPR);
                 assert_eq!(icmp_repr, ECHOV4_REPR.into());
                 assert_eq!(icmp_repr, ECHOV4_REPR.into());
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );
@@ -677,7 +709,7 @@ mod test_ipv4 {
                         hop_limit: 0x2a,
                         hop_limit: 0x2a,
                     })
                     })
                 );
                 );
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );
@@ -795,6 +827,7 @@ mod test_ipv6 {
     use super::tests_common::*;
     use super::tests_common::*;
 
 
     use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address};
     use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address};
+    use crate::Error;
 
 
     const REMOTE_IPV6: Ipv6Address =
     const REMOTE_IPV6: Ipv6Address =
         Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]);
         Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]);
@@ -844,7 +877,7 @@ mod test_ipv6 {
 
 
         assert_eq!(
         assert_eq!(
             socket.dispatch(&mut cx, |_, _| unreachable!()),
             socket.dispatch(&mut cx, |_, _| unreachable!()),
-            Err(Error::Exhausted)
+            Ok::<_, Error>(())
         );
         );
 
 
         // This buffer is too long
         // This buffer is too long
@@ -888,7 +921,7 @@ mod test_ipv6 {
             socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| {
             socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| {
                 assert_eq!(ip_repr, LOCAL_IPV6_REPR);
                 assert_eq!(ip_repr, LOCAL_IPV6_REPR);
                 assert_eq!(icmp_repr, ECHOV6_REPR.into());
                 assert_eq!(icmp_repr, ECHOV6_REPR.into());
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );
@@ -929,7 +962,7 @@ mod test_ipv6 {
                         hop_limit: 0x2a,
                         hop_limit: 0x2a,
                     })
                     })
                 );
                 );
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );

+ 82 - 62
src/socket/raw.rs

@@ -3,12 +3,11 @@ use core::cmp::min;
 use core::task::Waker;
 use core::task::Waker;
 
 
 use crate::iface::Context;
 use crate::iface::Context;
-use crate::phy::ChecksumCapabilities;
 use crate::socket::PollAt;
 use crate::socket::PollAt;
 #[cfg(feature = "async")]
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
 use crate::socket::WakerRegistration;
-use crate::Error;
 
 
+use crate::storage::Empty;
 use crate::wire::{IpProtocol, IpRepr, IpVersion};
 use crate::wire::{IpProtocol, IpRepr, IpVersion};
 #[cfg(feature = "proto-ipv4")]
 #[cfg(feature = "proto-ipv4")]
 use crate::wire::{Ipv4Packet, Ipv4Repr};
 use crate::wire::{Ipv4Packet, Ipv4Repr};
@@ -265,21 +264,27 @@ impl<'a> Socket<'a> {
         self.rx_waker.wake();
         self.rx_waker.wake();
     }
     }
 
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
     where
-        F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), E>,
     {
     {
-        fn prepare<'a>(
-            next_header: IpProtocol,
-            buffer: &'a mut [u8],
-            _checksum_caps: &ChecksumCapabilities,
-        ) -> Result<(IpRepr, &'a [u8]), Error> {
-            match IpVersion::of_packet(buffer)? {
+        let ip_protocol = self.ip_protocol;
+        let ip_version = self.ip_version;
+        let _checksum_caps = &cx.checksum_caps();
+        let res = self.tx_buffer.dequeue_with(|&mut (), buffer| {
+            match IpVersion::of_packet(buffer) {
                 #[cfg(feature = "proto-ipv4")]
                 #[cfg(feature = "proto-ipv4")]
-                IpVersion::Ipv4 => {
-                    let mut packet = Ipv4Packet::new_checked(buffer)?;
-                    if packet.next_header() != next_header {
-                        return Err(Error::Unaddressable);
+                Ok(IpVersion::Ipv4) => {
+                    let mut packet = match Ipv4Packet::new_checked(buffer) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+                    if packet.next_header() != ip_protocol {
+                        net_trace!("raw: sent packet with wrong ip protocol, dropping.");
+                        return Ok(());
                     }
                     }
                     if _checksum_caps.ipv4.tx() {
                     if _checksum_caps.ipv4.tx() {
                         packet.fill_checksum();
                         packet.fill_checksum();
@@ -289,55 +294,57 @@ impl<'a> Socket<'a> {
                         packet.set_checksum(0);
                         packet.set_checksum(0);
                     }
                     }
 
 
-                    let packet = Ipv4Packet::new_checked(&*packet.into_inner())?;
-                    let ipv4_repr = Ipv4Repr::parse(&packet, _checksum_caps)?;
-                    Ok((IpRepr::Ipv4(ipv4_repr), packet.payload()))
+                    let packet = Ipv4Packet::new_unchecked(&*packet.into_inner());
+                    let ipv4_repr = match Ipv4Repr::parse(&packet, _checksum_caps) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv4 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+                    net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
+                    emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload()))
                 }
                 }
                 #[cfg(feature = "proto-ipv6")]
                 #[cfg(feature = "proto-ipv6")]
-                IpVersion::Ipv6 => {
-                    let packet = Ipv6Packet::new_checked(buffer)?;
-                    if packet.next_header() != next_header {
-                        return Err(Error::Unaddressable);
+                Ok(IpVersion::Ipv6) => {
+                    let packet = match Ipv6Packet::new_checked(buffer) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+                    if packet.next_header() != ip_protocol {
+                        net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping.");
+                        return Ok(());
                     }
                     }
                     let packet = Ipv6Packet::new_unchecked(&*packet.into_inner());
                     let packet = Ipv6Packet::new_unchecked(&*packet.into_inner());
-                    let ipv6_repr = Ipv6Repr::parse(&packet)?;
-                    Ok((IpRepr::Ipv6(ipv6_repr), packet.payload()))
+                    let ipv6_repr = match Ipv6Repr::parse(&packet) {
+                        Ok(x) => x,
+                        Err(_) => {
+                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
+                            return Ok(());
+                        }
+                    };
+
+                    net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
+                    emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload()))
+                }
+                Err(_) => {
+                    net_trace!("raw: sent packet with invalid IP version, dropping.");
+                    Ok(())
                 }
                 }
             }
             }
+        });
+        match res {
+            Err(Empty) => Ok(()),
+            Ok(Err(e)) => Err(e),
+            Ok(Ok(())) => {
+                #[cfg(feature = "async")]
+                self.tx_waker.wake();
+                Ok(())
+            }
         }
         }
-
-        let ip_protocol = self.ip_protocol;
-        let ip_version = self.ip_version;
-        self.tx_buffer
-            .dequeue_with(|&mut (), packet_buf| {
-                match prepare(ip_protocol, packet_buf, &cx.checksum_caps()) {
-                    Ok((ip_repr, raw_packet)) => {
-                        net_trace!(
-                            "raw:{}:{}: sending {} octets",
-                            ip_version,
-                            ip_protocol,
-                            ip_repr.buffer_len() + raw_packet.len()
-                        );
-                        emit(cx, (ip_repr, raw_packet))
-                    }
-                    Err(error) => {
-                        net_debug!(
-                            "raw:{}:{}: dropping outgoing packet ({})",
-                            ip_version,
-                            ip_protocol,
-                            error
-                        );
-                        // Return Ok(()) so the packet is dequeued.
-                        Ok(())
-                    }
-                }
-            })
-            .map_err(|_| Error::Exhausted)??;
-
-        #[cfg(feature = "async")]
-        self.tx_waker.wake();
-
-        Ok(())
     }
     }
 
 
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
@@ -357,6 +364,7 @@ mod test {
     use crate::wire::{Ipv4Address, Ipv4Repr};
     use crate::wire::{Ipv4Address, Ipv4Repr};
     #[cfg(feature = "proto-ipv6")]
     #[cfg(feature = "proto-ipv6")]
     use crate::wire::{Ipv6Address, Ipv6Repr};
     use crate::wire::{Ipv6Address, Ipv6Repr};
+    use crate::Error;
 
 
     fn buffer(packets: usize) -> PacketBuffer<'static> {
     fn buffer(packets: usize) -> PacketBuffer<'static> {
         PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets])
         PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets])
@@ -453,7 +461,7 @@ mod test {
                     assert!(socket.can_send());
                     assert!(socket.can_send());
                     assert_eq!(
                     assert_eq!(
                         socket.dispatch(&mut cx, |_, _| unreachable!()),
                         socket.dispatch(&mut cx, |_, _| unreachable!()),
-                        Err(Error::Exhausted)
+                        Ok::<_, Error>(())
                     );
                     );
 
 
                     assert_eq!(socket.send_slice(&$packet[..]), Ok(()));
                     assert_eq!(socket.send_slice(&$packet[..]), Ok(()));
@@ -474,7 +482,7 @@ mod test {
                         socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| {
                         socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| {
                             assert_eq!(ip_repr, $hdr);
                             assert_eq!(ip_repr, $hdr);
                             assert_eq!(ip_payload, &$payload);
                             assert_eq!(ip_payload, &$payload);
-                            Ok(())
+                            Ok::<_, Error>(())
                         }),
                         }),
                         Ok(())
                         Ok(())
                     );
                     );
@@ -539,13 +547,19 @@ mod test {
             Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6);
             Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6);
 
 
             assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
             assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
 
 
             let mut wrong_protocol = ipv4_locals::PACKET_BYTES;
             let mut wrong_protocol = ipv4_locals::PACKET_BYTES;
             Ipv4Packet::new_unchecked(&mut wrong_protocol).set_next_header(IpProtocol::Tcp);
             Ipv4Packet::new_unchecked(&mut wrong_protocol).set_next_header(IpProtocol::Tcp);
 
 
             assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
             assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
         }
         }
         #[cfg(feature = "proto-ipv6")]
         #[cfg(feature = "proto-ipv6")]
         {
         {
@@ -556,13 +570,19 @@ mod test {
             Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4);
             Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4);
 
 
             assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
             assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
 
 
             let mut wrong_protocol = ipv6_locals::PACKET_BYTES;
             let mut wrong_protocol = ipv6_locals::PACKET_BYTES;
             Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp);
             Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp);
 
 
             assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
             assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-            assert_eq!(socket.dispatch(&mut cx, |_, _| unreachable!()), Ok(()));
+            assert_eq!(
+                socket.dispatch(&mut cx, |_, _| unreachable!()),
+                Ok::<_, Error>(())
+            );
         }
         }
     }
     }
 
 

+ 8 - 8
src/socket/tcp.rs

@@ -16,7 +16,6 @@ use crate::wire::{
     IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, TcpControl, TcpRepr, TcpSeqNumber,
     IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, TcpControl, TcpRepr, TcpSeqNumber,
     TCP_HEADER_LEN,
     TCP_HEADER_LEN,
 };
 };
-use crate::Error;
 
 
 macro_rules! tcp_trace {
 macro_rules! tcp_trace {
     ($($arg:expr),*) => (net_log!(trace, $($arg),*));
     ($($arg:expr),*) => (net_log!(trace, $($arg),*));
@@ -1934,12 +1933,12 @@ impl<'a> Socket<'a> {
         }
         }
     }
     }
 
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
     where
-        F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), E>,
     {
     {
         if self.tuple.is_none() {
         if self.tuple.is_none() {
-            return Err(Error::Exhausted);
+            return Ok(());
         }
         }
 
 
         if self.remote_last_ts.is_none() {
         if self.remote_last_ts.is_none() {
@@ -1999,9 +1998,9 @@ impl<'a> Socket<'a> {
             // If we have spent enough time in the TIME-WAIT state, close the socket.
             // If we have spent enough time in the TIME-WAIT state, close the socket.
             tcp_trace!("TIME-WAIT timer expired");
             tcp_trace!("TIME-WAIT timer expired");
             self.reset();
             self.reset();
-            return Err(Error::Exhausted);
+            return Ok(());
         } else {
         } else {
-            return Err(Error::Exhausted);
+            return Ok(());
         }
         }
 
 
         // NOTE(unwrap): we check tuple is not None the first thing in this function.
         // NOTE(unwrap): we check tuple is not None the first thing in this function.
@@ -2041,7 +2040,7 @@ impl<'a> Socket<'a> {
             }
             }
 
 
             // We never transmit anything in the LISTEN state.
             // We never transmit anything in the LISTEN state.
-            State::Listen => return Err(Error::Exhausted),
+            State::Listen => return Ok(()),
 
 
             // We transmit a SYN in the SYN-SENT state.
             // We transmit a SYN in the SYN-SENT state.
             // We transmit a SYN|ACK in the SYN-RECEIVED state.
             // We transmit a SYN|ACK in the SYN-RECEIVED state.
@@ -2270,6 +2269,7 @@ impl<'a> fmt::Write for Socket<'a> {
 mod test {
 mod test {
     use super::*;
     use super::*;
     use crate::wire::IpRepr;
     use crate::wire::IpRepr;
+    use crate::Error;
     use core::i32;
     use core::i32;
     use std::ops::{Deref, DerefMut};
     use std::ops::{Deref, DerefMut};
     use std::vec::Vec;
     use std::vec::Vec;
@@ -6168,7 +6168,7 @@ mod test {
         assert_eq!(
         assert_eq!(
             s.socket.dispatch(&mut s.cx, |_, (ip_repr, _)| {
             s.socket.dispatch(&mut s.cx, |_, (ip_repr, _)| {
                 assert_eq!(ip_repr.hop_limit(), 0x2a);
                 assert_eq!(ip_repr.hop_limit(), 0x2a);
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );

+ 51 - 41
src/socket/udp.rs

@@ -6,8 +6,8 @@ use crate::iface::Context;
 use crate::socket::PollAt;
 use crate::socket::PollAt;
 #[cfg(feature = "async")]
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
 use crate::socket::WakerRegistration;
+use crate::storage::Empty;
 use crate::wire::{IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr};
 use crate::wire::{IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr};
-use crate::Error;
 
 
 /// A UDP packet metadata.
 /// A UDP packet metadata.
 pub type PacketMetadata = crate::storage::PacketMetadata<IpEndpoint>;
 pub type PacketMetadata = crate::storage::PacketMetadata<IpEndpoint>;
@@ -383,49 +383,58 @@ impl<'a> Socket<'a> {
         self.rx_waker.wake();
         self.rx_waker.wake();
     }
     }
 
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
+    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
     where
     where
-        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), Error>,
+        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
     {
     {
         let endpoint = self.endpoint;
         let endpoint = self.endpoint;
         let hop_limit = self.hop_limit.unwrap_or(64);
         let hop_limit = self.hop_limit.unwrap_or(64);
 
 
-        self.tx_buffer
-            .dequeue_with(|remote_endpoint, payload_buf| {
-                let src_addr = match endpoint.addr {
+        let res = self.tx_buffer.dequeue_with(|remote_endpoint, payload_buf| {
+            let src_addr = match endpoint.addr {
+                Some(addr) => addr,
+                None => match cx.get_source_address(remote_endpoint.addr) {
                     Some(addr) => addr,
                     Some(addr) => addr,
-                    None => match cx.get_source_address(remote_endpoint.addr) {
-                        Some(addr) => addr,
-                        None => return Err(Error::Unaddressable),
-                    },
-                };
-
-                net_trace!(
-                    "udp:{}:{}: sending {} octets",
-                    endpoint,
-                    remote_endpoint,
-                    payload_buf.len()
-                );
-
-                let repr = UdpRepr {
-                    src_port: endpoint.port,
-                    dst_port: remote_endpoint.port,
-                };
-                let ip_repr = IpRepr::new(
-                    src_addr,
-                    remote_endpoint.addr,
-                    IpProtocol::Udp,
-                    repr.header_len() + payload_buf.len(),
-                    hop_limit,
-                );
-                emit(cx, (ip_repr, repr, payload_buf))
-            })
-            .map_err(|_| Error::Exhausted)??;
-
-        #[cfg(feature = "async")]
-        self.tx_waker.wake();
-
-        Ok(())
+                    None => {
+                        net_trace!(
+                            "udp:{}:{}: cannot find suitable source address, dropping.",
+                            endpoint,
+                            remote_endpoint
+                        );
+                        return Ok(());
+                    }
+                },
+            };
+
+            net_trace!(
+                "udp:{}:{}: sending {} octets",
+                endpoint,
+                remote_endpoint,
+                payload_buf.len()
+            );
+
+            let repr = UdpRepr {
+                src_port: endpoint.port,
+                dst_port: remote_endpoint.port,
+            };
+            let ip_repr = IpRepr::new(
+                src_addr,
+                remote_endpoint.addr,
+                IpProtocol::Udp,
+                repr.header_len() + payload_buf.len(),
+                hop_limit,
+            );
+            emit(cx, (ip_repr, repr, payload_buf))
+        });
+        match res {
+            Err(Empty) => Ok(()),
+            Ok(Err(e)) => Err(e),
+            Ok(Ok(())) => {
+                #[cfg(feature = "async")]
+                self.tx_waker.wake();
+                Ok(())
+            }
+        }
     }
     }
 
 
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
     pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
@@ -441,6 +450,7 @@ impl<'a> Socket<'a> {
 mod test {
 mod test {
     use super::*;
     use super::*;
     use crate::wire::{IpRepr, UdpRepr};
     use crate::wire::{IpRepr, UdpRepr};
+    use crate::Error;
 
 
     fn buffer(packets: usize) -> PacketBuffer<'static> {
     fn buffer(packets: usize) -> PacketBuffer<'static> {
         PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 16 * packets])
         PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 16 * packets])
@@ -589,7 +599,7 @@ mod test {
         assert!(socket.can_send());
         assert!(socket.can_send());
         assert_eq!(
         assert_eq!(
             socket.dispatch(&mut cx, |_, _| unreachable!()),
             socket.dispatch(&mut cx, |_, _| unreachable!()),
-            Err(Error::Exhausted)
+            Ok::<_, Error>(())
         );
         );
 
 
         assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
         assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
@@ -615,7 +625,7 @@ mod test {
                 assert_eq!(ip_repr, LOCAL_IP_REPR);
                 assert_eq!(ip_repr, LOCAL_IP_REPR);
                 assert_eq!(udp_repr, LOCAL_UDP_REPR);
                 assert_eq!(udp_repr, LOCAL_UDP_REPR);
                 assert_eq!(payload, PAYLOAD);
                 assert_eq!(payload, PAYLOAD);
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );
@@ -711,7 +721,7 @@ mod test {
                         hop_limit: 0x2a,
                         hop_limit: 0x2a,
                     })
                     })
                 );
                 );
-                Ok(())
+                Ok::<_, Error>(())
             }),
             }),
             Ok(())
             Ok(())
         );
         );