Browse Source

Merge #557

557: Simplify socket handling r=Dirbaio a=Dirbaio

See individual commit messages

Co-authored-by: Dario Nieuwenhuis <[email protected]>
bors[bot] 3 years ago
parent
commit
a17c167bd9

+ 31 - 37
examples/benchmark.rs

@@ -13,7 +13,6 @@ use std::thread;
 
 use smoltcp::iface::{InterfaceBuilder, NeighborCache};
 use smoltcp::phy::{wait as phy_wait, Device, Medium};
-use smoltcp::socket::SocketSet;
 use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 use smoltcp::time::{Duration, Instant};
 use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
@@ -97,7 +96,7 @@ fn main() {
     let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
     let ip_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)];
     let medium = device.capabilities().medium;
-    let mut builder = InterfaceBuilder::new(device).ip_addrs(ip_addrs);
+    let mut builder = InterfaceBuilder::new(device, vec![]).ip_addrs(ip_addrs);
     if medium == Medium::Ethernet {
         builder = builder
             .hardware_addr(ethernet_addr.into())
@@ -105,15 +104,14 @@ fn main() {
     }
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
-    let tcp1_handle = sockets.add(tcp1_socket);
-    let tcp2_handle = sockets.add(tcp2_socket);
+    let tcp1_handle = iface.add_socket(tcp1_socket);
+    let tcp2_handle = iface.add_socket(tcp2_socket);
     let default_timeout = Some(Duration::from_millis(1000));
 
     let mut processed = 0;
     while !CLIENT_DONE.load(Ordering::SeqCst) {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
@@ -121,46 +119,42 @@ fn main() {
         }
 
         // tcp:1234: emit data
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp1_handle);
-            if !socket.is_open() {
-                socket.listen(1234).unwrap();
-            }
+        let socket = iface.get_socket::<TcpSocket>(tcp1_handle);
+        if !socket.is_open() {
+            socket.listen(1234).unwrap();
+        }
 
-            if socket.can_send() {
-                if processed < AMOUNT {
-                    let length = socket
-                        .send(|buffer| {
-                            let length = cmp::min(buffer.len(), AMOUNT - processed);
-                            (length, length)
-                        })
-                        .unwrap();
-                    processed += length;
-                }
+        if socket.can_send() {
+            if processed < AMOUNT {
+                let length = socket
+                    .send(|buffer| {
+                        let length = cmp::min(buffer.len(), AMOUNT - processed);
+                        (length, length)
+                    })
+                    .unwrap();
+                processed += length;
             }
         }
 
         // tcp:1235: sink data
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp2_handle);
-            if !socket.is_open() {
-                socket.listen(1235).unwrap();
-            }
+        let socket = iface.get_socket::<TcpSocket>(tcp2_handle);
+        if !socket.is_open() {
+            socket.listen(1235).unwrap();
+        }
 
-            if socket.can_recv() {
-                if processed < AMOUNT {
-                    let length = socket
-                        .recv(|buffer| {
-                            let length = cmp::min(buffer.len(), AMOUNT - processed);
-                            (length, length)
-                        })
-                        .unwrap();
-                    processed += length;
-                }
+        if socket.can_recv() {
+            if processed < AMOUNT {
+                let length = socket
+                    .recv(|buffer| {
+                        let length = cmp::min(buffer.len(), AMOUNT - processed);
+                        (length, length)
+                    })
+                    .unwrap();
+                processed += length;
             }
         }
 
-        match iface.poll_at(&sockets, timestamp) {
+        match iface.poll_at(timestamp) {
             Some(poll_at) if timestamp < poll_at => {
                 phy_wait(fd, Some(poll_at - timestamp)).expect("wait error");
             }

+ 41 - 46
examples/client.rs

@@ -7,7 +7,7 @@ use std::str::{self, FromStr};
 
 use smoltcp::iface::{InterfaceBuilder, NeighborCache, Routes};
 use smoltcp::phy::{wait as phy_wait, Device, Medium};
-use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
+use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 use smoltcp::time::Instant;
 use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address};
 
@@ -42,7 +42,7 @@ fn main() {
     routes.add_default_ipv4_route(default_v4_gw).unwrap();
 
     let medium = device.capabilities().medium;
-    let mut builder = InterfaceBuilder::new(device)
+    let mut builder = InterfaceBuilder::new(device, vec![])
         .ip_addrs(ip_addrs)
         .routes(routes);
     if medium == Medium::Ethernet {
@@ -52,63 +52,58 @@ fn main() {
     }
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
-    let tcp_handle = sockets.add(tcp_socket);
+    let tcp_handle = iface.add_socket(tcp_socket);
 
-    {
-        let mut socket = sockets.get::<TcpSocket>(tcp_handle);
-        socket.connect((address, port), 49500).unwrap();
-    }
+    let socket = iface.get_socket::<TcpSocket>(tcp_handle);
+    socket.connect((address, port), 49500).unwrap();
 
     let mut tcp_active = false;
     loop {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
             }
         }
 
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp_handle);
-            if socket.is_active() && !tcp_active {
-                debug!("connected");
-            } else if !socket.is_active() && tcp_active {
-                debug!("disconnected");
-                break;
-            }
-            tcp_active = socket.is_active();
-
-            if socket.may_recv() {
-                let data = socket
-                    .recv(|data| {
-                        let mut data = data.to_owned();
-                        if !data.is_empty() {
-                            debug!(
-                                "recv data: {:?}",
-                                str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
-                            );
-                            data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
-                            data.reverse();
-                            data.extend(b"\n");
-                        }
-                        (data.len(), data)
-                    })
-                    .unwrap();
-                if socket.can_send() && !data.is_empty() {
-                    debug!(
-                        "send data: {:?}",
-                        str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
-                    );
-                    socket.send_slice(&data[..]).unwrap();
-                }
-            } else if socket.may_send() {
-                debug!("close");
-                socket.close();
+        let socket = iface.get_socket::<TcpSocket>(tcp_handle);
+        if socket.is_active() && !tcp_active {
+            debug!("connected");
+        } else if !socket.is_active() && tcp_active {
+            debug!("disconnected");
+            break;
+        }
+        tcp_active = socket.is_active();
+
+        if socket.may_recv() {
+            let data = socket
+                .recv(|data| {
+                    let mut data = data.to_owned();
+                    if !data.is_empty() {
+                        debug!(
+                            "recv data: {:?}",
+                            str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
+                        );
+                        data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
+                        data.reverse();
+                        data.extend(b"\n");
+                    }
+                    (data.len(), data)
+                })
+                .unwrap();
+            if socket.can_send() && !data.is_empty() {
+                debug!(
+                    "send data: {:?}",
+                    str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
+                );
+                socket.send_slice(&data[..]).unwrap();
             }
+        } else if socket.may_send() {
+            debug!("close");
+            socket.close();
         }
 
-        phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error");
+        phy_wait(fd, iface.poll_delay(timestamp)).expect("wait error");
     }
 }

+ 7 - 7
examples/dhcp_client.rs

@@ -6,7 +6,7 @@ use std::collections::BTreeMap;
 use std::os::unix::io::AsRawFd;
 
 use smoltcp::iface::{Interface, InterfaceBuilder, NeighborCache, Routes};
-use smoltcp::socket::{Dhcpv4Event, Dhcpv4Socket, SocketSet};
+use smoltcp::socket::{Dhcpv4Event, Dhcpv4Socket};
 use smoltcp::time::Instant;
 use smoltcp::wire::{EthernetAddress, IpCidr, Ipv4Address, Ipv4Cidr};
 use smoltcp::{
@@ -34,7 +34,7 @@ fn main() {
     let routes = Routes::new(&mut routes_storage[..]);
 
     let medium = device.capabilities().medium;
-    let mut builder = InterfaceBuilder::new(device)
+    let mut builder = InterfaceBuilder::new(device, vec![])
         .ip_addrs(ip_addrs)
         .routes(routes);
     if medium == Medium::Ethernet {
@@ -44,7 +44,6 @@ fn main() {
     }
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
     let mut dhcp_socket = Dhcpv4Socket::new();
 
     // Set a ridiculously short max lease time to show DHCP renews work properly.
@@ -53,15 +52,16 @@ fn main() {
     // IMPORTANT: This should be removed in production.
     dhcp_socket.set_max_lease_duration(Some(Duration::from_secs(10)));
 
-    let dhcp_handle = sockets.add(dhcp_socket);
+    let dhcp_handle = iface.add_socket(dhcp_socket);
 
     loop {
         let timestamp = Instant::now();
-        if let Err(e) = iface.poll(&mut sockets, timestamp) {
+        if let Err(e) = iface.poll(timestamp) {
             debug!("poll error: {}", e);
         }
 
-        match sockets.get::<Dhcpv4Socket>(dhcp_handle).poll() {
+        let event = iface.get_socket::<Dhcpv4Socket>(dhcp_handle).poll();
+        match event {
             None => {}
             Some(Dhcpv4Event::Configured(config)) => {
                 debug!("DHCP config acquired!");
@@ -90,7 +90,7 @@ fn main() {
             }
         }
 
-        phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error");
+        phy_wait(fd, iface.poll_delay(timestamp)).expect("wait error");
     }
 }
 

+ 41 - 44
examples/httpclient.rs

@@ -8,7 +8,7 @@ use url::Url;
 
 use smoltcp::iface::{InterfaceBuilder, NeighborCache, Routes};
 use smoltcp::phy::{wait as phy_wait, Device, Medium};
-use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
+use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 use smoltcp::time::Instant;
 use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address};
 
@@ -48,7 +48,7 @@ fn main() {
     routes.add_default_ipv6_route(default_v6_gw).unwrap();
 
     let medium = device.capabilities().medium;
-    let mut builder = InterfaceBuilder::new(device)
+    let mut builder = InterfaceBuilder::new(device, vec![])
         .ip_addrs(ip_addrs)
         .routes(routes);
     if medium == Medium::Ethernet {
@@ -58,8 +58,7 @@ fn main() {
     }
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
-    let tcp_handle = sockets.add(tcp_socket);
+    let tcp_handle = iface.add_socket(tcp_socket);
 
     enum State {
         Connect,
@@ -70,54 +69,52 @@ fn main() {
 
     loop {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
             }
         }
 
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp_handle);
+        let socket = iface.get_socket::<TcpSocket>(tcp_handle);
 
-            state = match state {
-                State::Connect if !socket.is_active() => {
-                    debug!("connecting");
-                    let local_port = 49152 + rand::random::<u16>() % 16384;
-                    socket
-                        .connect((address, url.port().unwrap_or(80)), local_port)
-                        .unwrap();
-                    State::Request
-                }
-                State::Request if socket.may_send() => {
-                    debug!("sending request");
-                    let http_get = "GET ".to_owned() + url.path() + " HTTP/1.1\r\n";
-                    socket.send_slice(http_get.as_ref()).expect("cannot send");
-                    let http_host = "Host: ".to_owned() + url.host_str().unwrap() + "\r\n";
-                    socket.send_slice(http_host.as_ref()).expect("cannot send");
-                    socket
-                        .send_slice(b"Connection: close\r\n")
-                        .expect("cannot send");
-                    socket.send_slice(b"\r\n").expect("cannot send");
-                    State::Response
-                }
-                State::Response if socket.can_recv() => {
-                    socket
-                        .recv(|data| {
-                            println!("{}", str::from_utf8(data).unwrap_or("(invalid utf8)"));
-                            (data.len(), ())
-                        })
-                        .unwrap();
-                    State::Response
-                }
-                State::Response if !socket.may_recv() => {
-                    debug!("received complete response");
-                    break;
-                }
-                _ => state,
+        state = match state {
+            State::Connect if !socket.is_active() => {
+                debug!("connecting");
+                let local_port = 49152 + rand::random::<u16>() % 16384;
+                socket
+                    .connect((address, url.port().unwrap_or(80)), local_port)
+                    .unwrap();
+                State::Request
             }
-        }
+            State::Request if socket.may_send() => {
+                debug!("sending request");
+                let http_get = "GET ".to_owned() + url.path() + " HTTP/1.1\r\n";
+                socket.send_slice(http_get.as_ref()).expect("cannot send");
+                let http_host = "Host: ".to_owned() + url.host_str().unwrap() + "\r\n";
+                socket.send_slice(http_host.as_ref()).expect("cannot send");
+                socket
+                    .send_slice(b"Connection: close\r\n")
+                    .expect("cannot send");
+                socket.send_slice(b"\r\n").expect("cannot send");
+                State::Response
+            }
+            State::Response if socket.can_recv() => {
+                socket
+                    .recv(|data| {
+                        println!("{}", str::from_utf8(data).unwrap_or("(invalid utf8)"));
+                        (data.len(), ())
+                    })
+                    .unwrap();
+                State::Response
+            }
+            State::Response if !socket.may_recv() => {
+                debug!("received complete response");
+                break;
+            }
+            _ => state,
+        };
 
-        phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error");
+        phy_wait(fd, iface.poll_delay(timestamp)).expect("wait error");
     }
 }

+ 37 - 42
examples/loopback.rs

@@ -11,7 +11,7 @@ use log::{debug, error, info};
 
 use smoltcp::iface::{InterfaceBuilder, NeighborCache};
 use smoltcp::phy::{Loopback, Medium};
-use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
+use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 use smoltcp::time::{Duration, Instant};
 use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
 
@@ -94,7 +94,8 @@ fn main() {
     let mut neighbor_cache = NeighborCache::new(&mut neighbor_cache_entries[..]);
 
     let mut ip_addrs = [IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)];
-    let mut iface = InterfaceBuilder::new(device)
+    let mut sockets: [_; 2] = Default::default();
+    let mut iface = InterfaceBuilder::new(device, &mut sockets[..])
         .hardware_addr(EthernetAddress::default().into())
         .neighbor_cache(neighbor_cache)
         .ip_addrs(ip_addrs)
@@ -120,65 +121,59 @@ fn main() {
         TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer)
     };
 
-    let mut socket_set_entries: [_; 2] = Default::default();
-    let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
-    let server_handle = socket_set.add(server_socket);
-    let client_handle = socket_set.add(client_socket);
+    let server_handle = iface.add_socket(server_socket);
+    let client_handle = iface.add_socket(client_socket);
 
     let mut did_listen = false;
     let mut did_connect = false;
     let mut done = false;
     while !done && clock.elapsed() < Instant::from_millis(10_000) {
-        match iface.poll(&mut socket_set, clock.elapsed()) {
+        match iface.poll(clock.elapsed()) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
             }
         }
 
-        {
-            let mut socket = socket_set.get::<TcpSocket>(server_handle);
-            if !socket.is_active() && !socket.is_listening() {
-                if !did_listen {
-                    debug!("listening");
-                    socket.listen(1234).unwrap();
-                    did_listen = true;
-                }
+        let mut socket = iface.get_socket::<TcpSocket>(server_handle);
+        if !socket.is_active() && !socket.is_listening() {
+            if !did_listen {
+                debug!("listening");
+                socket.listen(1234).unwrap();
+                did_listen = true;
             }
+        }
 
-            if socket.can_recv() {
-                debug!(
-                    "got {:?}",
-                    socket.recv(|buffer| { (buffer.len(), str::from_utf8(buffer).unwrap()) })
-                );
-                socket.close();
-                done = true;
-            }
+        if socket.can_recv() {
+            debug!(
+                "got {:?}",
+                socket.recv(|buffer| { (buffer.len(), str::from_utf8(buffer).unwrap()) })
+            );
+            socket.close();
+            done = true;
         }
 
-        {
-            let mut socket = socket_set.get::<TcpSocket>(client_handle);
-            if !socket.is_open() {
-                if !did_connect {
-                    debug!("connecting");
-                    socket
-                        .connect(
-                            (IpAddress::v4(127, 0, 0, 1), 1234),
-                            (IpAddress::Unspecified, 65000),
-                        )
-                        .unwrap();
-                    did_connect = true;
-                }
+        let mut socket = iface.get_socket::<TcpSocket>(client_handle);
+        if !socket.is_open() {
+            if !did_connect {
+                debug!("connecting");
+                socket
+                    .connect(
+                        (IpAddress::v4(127, 0, 0, 1), 1234),
+                        (IpAddress::Unspecified, 65000),
+                    )
+                    .unwrap();
+                did_connect = true;
             }
+        }
 
-            if socket.can_send() {
-                debug!("sending");
-                socket.send_slice(b"0123456789abcdef").unwrap();
-                socket.close();
-            }
+        if socket.can_send() {
+            debug!("sending");
+            socket.send_slice(b"0123456789abcdef").unwrap();
+            socket.close();
         }
 
-        match iface.poll_delay(&socket_set, clock.elapsed()) {
+        match iface.poll_delay(clock.elapsed()) {
             Some(Duration::ZERO) => debug!("resuming"),
             Some(delay) => {
                 debug!("sleeping for {} ms", delay);

+ 30 - 36
examples/multicast.rs

@@ -7,8 +7,7 @@ use std::os::unix::io::AsRawFd;
 use smoltcp::iface::{InterfaceBuilder, NeighborCache};
 use smoltcp::phy::wait as phy_wait;
 use smoltcp::socket::{
-    RawPacketMetadata, RawSocket, RawSocketBuffer, SocketSet, UdpPacketMetadata, UdpSocket,
-    UdpSocketBuffer,
+    RawPacketMetadata, RawSocket, RawSocketBuffer, UdpPacketMetadata, UdpSocket, UdpSocketBuffer,
 };
 use smoltcp::time::Instant;
 use smoltcp::wire::{
@@ -37,7 +36,7 @@ fn main() {
     let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
     let ip_addr = IpCidr::new(IpAddress::from(local_addr), 24);
     let mut ipv4_multicast_storage = [None; 1];
-    let mut iface = InterfaceBuilder::new(device)
+    let mut iface = InterfaceBuilder::new(device, vec![])
         .hardware_addr(ethernet_addr.into())
         .neighbor_cache(neighbor_cache)
         .ip_addrs([ip_addr])
@@ -50,8 +49,6 @@ fn main() {
         .join_multicast_group(Ipv4Address::from_bytes(&MDNS_GROUP), now)
         .unwrap();
 
-    let mut sockets = SocketSet::new(vec![]);
-
     // Must fit at least one IGMP packet
     let raw_rx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; 2], vec![0; 512]);
     // Will not send IGMP
@@ -62,55 +59,52 @@ fn main() {
         raw_rx_buffer,
         raw_tx_buffer,
     );
-    let raw_handle = sockets.add(raw_socket);
+    let raw_handle = iface.add_socket(raw_socket);
 
     // Must fit mDNS payload of at least one packet
     let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY; 4], vec![0; 1024]);
     // Will not send mDNS
     let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 0]);
     let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
-    let udp_handle = sockets.add(udp_socket);
+    let udp_handle = iface.add_socket(udp_socket);
 
     loop {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
             }
         }
 
-        {
-            let mut socket = sockets.get::<RawSocket>(raw_handle);
-
-            if socket.can_recv() {
-                // For display purposes only - normally we wouldn't process incoming IGMP packets
-                // in the application layer
-                socket
-                    .recv()
-                    .and_then(Ipv4Packet::new_checked)
-                    .and_then(|ipv4_packet| IgmpPacket::new_checked(ipv4_packet.payload()))
-                    .and_then(|igmp_packet| IgmpRepr::parse(&igmp_packet))
-                    .map(|igmp_repr| println!("IGMP packet: {:?}", igmp_repr))
-                    .unwrap_or_else(|e| println!("Recv IGMP error: {:?}", e));
-            }
+        let socket = iface.get_socket::<RawSocket>(raw_handle);
+
+        if socket.can_recv() {
+            // For display purposes only - normally we wouldn't process incoming IGMP packets
+            // in the application layer
+            socket
+                .recv()
+                .and_then(Ipv4Packet::new_checked)
+                .and_then(|ipv4_packet| IgmpPacket::new_checked(ipv4_packet.payload()))
+                .and_then(|igmp_packet| IgmpRepr::parse(&igmp_packet))
+                .map(|igmp_repr| println!("IGMP packet: {:?}", igmp_repr))
+                .unwrap_or_else(|e| println!("Recv IGMP error: {:?}", e));
         }
-        {
-            let mut socket = sockets.get::<UdpSocket>(udp_handle);
-            if !socket.is_open() {
-                socket.bind(MDNS_PORT).unwrap()
-            }
 
-            if socket.can_recv() {
-                socket
-                    .recv()
-                    .map(|(data, sender)| {
-                        println!("mDNS traffic: {} UDP bytes from {}", data.len(), sender)
-                    })
-                    .unwrap_or_else(|e| println!("Recv UDP error: {:?}", e));
-            }
+        let socket = iface.get_socket::<UdpSocket>(udp_handle);
+        if !socket.is_open() {
+            socket.bind(MDNS_PORT).unwrap()
+        }
+
+        if socket.can_recv() {
+            socket
+                .recv()
+                .map(|(data, sender)| {
+                    println!("mDNS traffic: {} UDP bytes from {}", data.len(), sender)
+                })
+                .unwrap_or_else(|e| println!("Recv UDP error: {:?}", e));
         }
 
-        phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error");
+        phy_wait(fd, iface.poll_delay(timestamp)).expect("wait error");
     }
 }

+ 94 - 98
examples/ping.rs

@@ -11,7 +11,7 @@ use std::str::FromStr;
 use smoltcp::iface::{InterfaceBuilder, NeighborCache, Routes};
 use smoltcp::phy::wait as phy_wait;
 use smoltcp::phy::Device;
-use smoltcp::socket::{IcmpEndpoint, IcmpPacketMetadata, IcmpSocket, IcmpSocketBuffer, SocketSet};
+use smoltcp::socket::{IcmpEndpoint, IcmpPacketMetadata, IcmpSocket, IcmpSocketBuffer};
 use smoltcp::wire::{
     EthernetAddress, Icmpv4Packet, Icmpv4Repr, Icmpv6Packet, Icmpv6Repr, IpAddress, IpCidr,
     Ipv4Address, Ipv6Address,
@@ -127,7 +127,7 @@ fn main() {
     routes.add_default_ipv6_route(default_v6_gw).unwrap();
 
     let medium = device.capabilities().medium;
-    let mut builder = InterfaceBuilder::new(device)
+    let mut builder = InterfaceBuilder::new(device, vec![])
         .ip_addrs(ip_addrs)
         .routes(routes);
     if medium == Medium::Ethernet {
@@ -137,8 +137,7 @@ fn main() {
     }
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
-    let icmp_handle = sockets.add(icmp_socket);
+    let icmp_handle = iface.add_socket(icmp_socket);
 
     let mut send_at = Instant::from_millis(0);
     let mut seq_no = 0;
@@ -149,119 +148,116 @@ fn main() {
 
     loop {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
             }
         }
 
-        {
-            let timestamp = Instant::now();
-            let mut socket = sockets.get::<IcmpSocket>(icmp_handle);
-            if !socket.is_open() {
-                socket.bind(IcmpEndpoint::Ident(ident)).unwrap();
-                send_at = timestamp;
-            }
+        let timestamp = Instant::now();
+        let socket = iface.get_socket::<IcmpSocket>(icmp_handle);
+        if !socket.is_open() {
+            socket.bind(IcmpEndpoint::Ident(ident)).unwrap();
+            send_at = timestamp;
+        }
 
-            if socket.can_send() && seq_no < count as u16 && send_at <= timestamp {
-                NetworkEndian::write_i64(&mut echo_payload, timestamp.total_millis());
+        if socket.can_send() && seq_no < count as u16 && send_at <= timestamp {
+            NetworkEndian::write_i64(&mut echo_payload, timestamp.total_millis());
 
-                match remote_addr {
-                    IpAddress::Ipv4(_) => {
-                        let (icmp_repr, mut icmp_packet) = send_icmp_ping!(
-                            Icmpv4Repr,
-                            Icmpv4Packet,
-                            ident,
-                            seq_no,
-                            echo_payload,
-                            socket,
-                            remote_addr
-                        );
-                        icmp_repr.emit(&mut icmp_packet, &device_caps.checksum);
-                    }
-                    IpAddress::Ipv6(_) => {
-                        let (icmp_repr, mut icmp_packet) = send_icmp_ping!(
-                            Icmpv6Repr,
-                            Icmpv6Packet,
-                            ident,
-                            seq_no,
-                            echo_payload,
-                            socket,
-                            remote_addr
-                        );
-                        icmp_repr.emit(
-                            &src_ipv6,
-                            &remote_addr,
-                            &mut icmp_packet,
-                            &device_caps.checksum,
-                        );
-                    }
-                    _ => unimplemented!(),
+            match remote_addr {
+                IpAddress::Ipv4(_) => {
+                    let (icmp_repr, mut icmp_packet) = send_icmp_ping!(
+                        Icmpv4Repr,
+                        Icmpv4Packet,
+                        ident,
+                        seq_no,
+                        echo_payload,
+                        socket,
+                        remote_addr
+                    );
+                    icmp_repr.emit(&mut icmp_packet, &device_caps.checksum);
                 }
-
-                waiting_queue.insert(seq_no, timestamp);
-                seq_no += 1;
-                send_at += interval;
+                IpAddress::Ipv6(_) => {
+                    let (icmp_repr, mut icmp_packet) = send_icmp_ping!(
+                        Icmpv6Repr,
+                        Icmpv6Packet,
+                        ident,
+                        seq_no,
+                        echo_payload,
+                        socket,
+                        remote_addr
+                    );
+                    icmp_repr.emit(
+                        &src_ipv6,
+                        &remote_addr,
+                        &mut icmp_packet,
+                        &device_caps.checksum,
+                    );
+                }
+                _ => unimplemented!(),
             }
 
-            if socket.can_recv() {
-                let (payload, _) = socket.recv().unwrap();
+            waiting_queue.insert(seq_no, timestamp);
+            seq_no += 1;
+            send_at += interval;
+        }
+
+        if socket.can_recv() {
+            let (payload, _) = socket.recv().unwrap();
 
-                match remote_addr {
-                    IpAddress::Ipv4(_) => {
-                        let icmp_packet = Icmpv4Packet::new_checked(&payload).unwrap();
-                        let icmp_repr =
-                            Icmpv4Repr::parse(&icmp_packet, &device_caps.checksum).unwrap();
-                        get_icmp_pong!(
-                            Icmpv4Repr,
-                            icmp_repr,
-                            payload,
-                            waiting_queue,
-                            remote_addr,
-                            timestamp,
-                            received
-                        );
-                    }
-                    IpAddress::Ipv6(_) => {
-                        let icmp_packet = Icmpv6Packet::new_checked(&payload).unwrap();
-                        let icmp_repr = Icmpv6Repr::parse(
-                            &remote_addr,
-                            &src_ipv6,
-                            &icmp_packet,
-                            &device_caps.checksum,
-                        )
-                        .unwrap();
-                        get_icmp_pong!(
-                            Icmpv6Repr,
-                            icmp_repr,
-                            payload,
-                            waiting_queue,
-                            remote_addr,
-                            timestamp,
-                            received
-                        );
-                    }
-                    _ => unimplemented!(),
+            match remote_addr {
+                IpAddress::Ipv4(_) => {
+                    let icmp_packet = Icmpv4Packet::new_checked(&payload).unwrap();
+                    let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &device_caps.checksum).unwrap();
+                    get_icmp_pong!(
+                        Icmpv4Repr,
+                        icmp_repr,
+                        payload,
+                        waiting_queue,
+                        remote_addr,
+                        timestamp,
+                        received
+                    );
                 }
-            }
-
-            waiting_queue.retain(|seq, from| {
-                if timestamp - *from < timeout {
-                    true
-                } else {
-                    println!("From {} icmp_seq={} timeout", remote_addr, seq);
-                    false
+                IpAddress::Ipv6(_) => {
+                    let icmp_packet = Icmpv6Packet::new_checked(&payload).unwrap();
+                    let icmp_repr = Icmpv6Repr::parse(
+                        &remote_addr,
+                        &src_ipv6,
+                        &icmp_packet,
+                        &device_caps.checksum,
+                    )
+                    .unwrap();
+                    get_icmp_pong!(
+                        Icmpv6Repr,
+                        icmp_repr,
+                        payload,
+                        waiting_queue,
+                        remote_addr,
+                        timestamp,
+                        received
+                    );
                 }
-            });
+                _ => unimplemented!(),
+            }
+        }
 
-            if seq_no == count as u16 && waiting_queue.is_empty() {
-                break;
+        waiting_queue.retain(|seq, from| {
+            if timestamp - *from < timeout {
+                true
+            } else {
+                println!("From {} icmp_seq={} timeout", remote_addr, seq);
+                false
             }
+        });
+
+        if seq_no == count as u16 && waiting_queue.is_empty() {
+            break;
         }
 
         let timestamp = Instant::now();
-        match iface.poll_at(&sockets, timestamp) {
+        match iface.poll_at(timestamp) {
             Some(poll_at) if timestamp < poll_at => {
                 let resume_at = cmp::min(poll_at, send_at);
                 phy_wait(fd, Some(resume_at - timestamp)).expect("wait error");

+ 105 - 117
examples/server.rs

@@ -8,7 +8,6 @@ use std::str;
 
 use smoltcp::iface::{InterfaceBuilder, NeighborCache};
 use smoltcp::phy::{wait as phy_wait, Device, Medium};
-use smoltcp::socket::SocketSet;
 use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
 use smoltcp::time::{Duration, Instant};
@@ -56,7 +55,7 @@ fn main() {
     ];
 
     let medium = device.capabilities().medium;
-    let mut builder = InterfaceBuilder::new(device).ip_addrs(ip_addrs);
+    let mut builder = InterfaceBuilder::new(device, vec![]).ip_addrs(ip_addrs);
     if medium == Medium::Ethernet {
         builder = builder
             .hardware_addr(ethernet_addr.into())
@@ -64,17 +63,16 @@ fn main() {
     }
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
-    let udp_handle = sockets.add(udp_socket);
-    let tcp1_handle = sockets.add(tcp1_socket);
-    let tcp2_handle = sockets.add(tcp2_socket);
-    let tcp3_handle = sockets.add(tcp3_socket);
-    let tcp4_handle = sockets.add(tcp4_socket);
+    let udp_handle = iface.add_socket(udp_socket);
+    let tcp1_handle = iface.add_socket(tcp1_socket);
+    let tcp2_handle = iface.add_socket(tcp2_socket);
+    let tcp3_handle = iface.add_socket(tcp3_socket);
+    let tcp4_handle = iface.add_socket(tcp4_socket);
 
     let mut tcp_6970_active = false;
     loop {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
@@ -82,137 +80,127 @@ fn main() {
         }
 
         // udp:6969: respond "hello"
-        {
-            let mut socket = sockets.get::<UdpSocket>(udp_handle);
-            if !socket.is_open() {
-                socket.bind(6969).unwrap()
-            }
+        let socket = iface.get_socket::<UdpSocket>(udp_handle);
+        if !socket.is_open() {
+            socket.bind(6969).unwrap()
+        }
 
-            let client = match socket.recv() {
-                Ok((data, endpoint)) => {
-                    debug!(
-                        "udp:6969 recv data: {:?} from {}",
-                        str::from_utf8(data).unwrap(),
-                        endpoint
-                    );
-                    Some(endpoint)
-                }
-                Err(_) => None,
-            };
-            if let Some(endpoint) = client {
-                let data = b"hello\n";
+        let client = match socket.recv() {
+            Ok((data, endpoint)) => {
                 debug!(
-                    "udp:6969 send data: {:?}",
-                    str::from_utf8(data.as_ref()).unwrap()
+                    "udp:6969 recv data: {:?} from {}",
+                    str::from_utf8(data).unwrap(),
+                    endpoint
                 );
-                socket.send_slice(data, endpoint).unwrap();
+                Some(endpoint)
             }
+            Err(_) => None,
+        };
+        if let Some(endpoint) = client {
+            let data = b"hello\n";
+            debug!(
+                "udp:6969 send data: {:?}",
+                str::from_utf8(data.as_ref()).unwrap()
+            );
+            socket.send_slice(data, endpoint).unwrap();
         }
 
         // tcp:6969: respond "hello"
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp1_handle);
-            if !socket.is_open() {
-                socket.listen(6969).unwrap();
-            }
+        let socket = iface.get_socket::<TcpSocket>(tcp1_handle);
+        if !socket.is_open() {
+            socket.listen(6969).unwrap();
+        }
 
-            if socket.can_send() {
-                debug!("tcp:6969 send greeting");
-                writeln!(socket, "hello").unwrap();
-                debug!("tcp:6969 close");
-                socket.close();
-            }
+        if socket.can_send() {
+            debug!("tcp:6969 send greeting");
+            writeln!(socket, "hello").unwrap();
+            debug!("tcp:6969 close");
+            socket.close();
         }
 
         // tcp:6970: echo with reverse
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp2_handle);
-            if !socket.is_open() {
-                socket.listen(6970).unwrap()
-            }
+        let socket = iface.get_socket::<TcpSocket>(tcp2_handle);
+        if !socket.is_open() {
+            socket.listen(6970).unwrap()
+        }
 
-            if socket.is_active() && !tcp_6970_active {
-                debug!("tcp:6970 connected");
-            } else if !socket.is_active() && tcp_6970_active {
-                debug!("tcp:6970 disconnected");
-            }
-            tcp_6970_active = socket.is_active();
-
-            if socket.may_recv() {
-                let data = socket
-                    .recv(|buffer| {
-                        let recvd_len = buffer.len();
-                        let mut data = buffer.to_owned();
-                        if !data.is_empty() {
-                            debug!(
-                                "tcp:6970 recv data: {:?}",
-                                str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
-                            );
-                            data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
-                            data.reverse();
-                            data.extend(b"\n");
-                        }
-                        (recvd_len, data)
-                    })
-                    .unwrap();
-                if socket.can_send() && !data.is_empty() {
-                    debug!(
-                        "tcp:6970 send data: {:?}",
-                        str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
-                    );
-                    socket.send_slice(&data[..]).unwrap();
-                }
-            } else if socket.may_send() {
-                debug!("tcp:6970 close");
-                socket.close();
+        if socket.is_active() && !tcp_6970_active {
+            debug!("tcp:6970 connected");
+        } else if !socket.is_active() && tcp_6970_active {
+            debug!("tcp:6970 disconnected");
+        }
+        tcp_6970_active = socket.is_active();
+
+        if socket.may_recv() {
+            let data = socket
+                .recv(|buffer| {
+                    let recvd_len = buffer.len();
+                    let mut data = buffer.to_owned();
+                    if !data.is_empty() {
+                        debug!(
+                            "tcp:6970 recv data: {:?}",
+                            str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
+                        );
+                        data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat();
+                        data.reverse();
+                        data.extend(b"\n");
+                    }
+                    (recvd_len, data)
+                })
+                .unwrap();
+            if socket.can_send() && !data.is_empty() {
+                debug!(
+                    "tcp:6970 send data: {:?}",
+                    str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")
+                );
+                socket.send_slice(&data[..]).unwrap();
             }
+        } else if socket.may_send() {
+            debug!("tcp:6970 close");
+            socket.close();
         }
 
         // tcp:6971: sinkhole
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp3_handle);
-            if !socket.is_open() {
-                socket.listen(6971).unwrap();
-                socket.set_keep_alive(Some(Duration::from_millis(1000)));
-                socket.set_timeout(Some(Duration::from_millis(2000)));
-            }
+        let socket = iface.get_socket::<TcpSocket>(tcp3_handle);
+        if !socket.is_open() {
+            socket.listen(6971).unwrap();
+            socket.set_keep_alive(Some(Duration::from_millis(1000)));
+            socket.set_timeout(Some(Duration::from_millis(2000)));
+        }
 
-            if socket.may_recv() {
-                socket
-                    .recv(|buffer| {
-                        if !buffer.is_empty() {
-                            debug!("tcp:6971 recv {:?} octets", buffer.len());
-                        }
-                        (buffer.len(), ())
-                    })
-                    .unwrap();
-            } else if socket.may_send() {
-                socket.close();
-            }
+        if socket.may_recv() {
+            socket
+                .recv(|buffer| {
+                    if !buffer.is_empty() {
+                        debug!("tcp:6971 recv {:?} octets", buffer.len());
+                    }
+                    (buffer.len(), ())
+                })
+                .unwrap();
+        } else if socket.may_send() {
+            socket.close();
         }
 
         // tcp:6972: fountain
-        {
-            let mut socket = sockets.get::<TcpSocket>(tcp4_handle);
-            if !socket.is_open() {
-                socket.listen(6972).unwrap()
-            }
+        let socket = iface.get_socket::<TcpSocket>(tcp4_handle);
+        if !socket.is_open() {
+            socket.listen(6972).unwrap()
+        }
 
-            if socket.may_send() {
-                socket
-                    .send(|data| {
-                        if !data.is_empty() {
-                            debug!("tcp:6972 send {:?} octets", data.len());
-                            for (i, b) in data.iter_mut().enumerate() {
-                                *b = (i % 256) as u8;
-                            }
+        if socket.may_send() {
+            socket
+                .send(|data| {
+                    if !data.is_empty() {
+                        debug!("tcp:6972 send {:?} octets", data.len());
+                        for (i, b) in data.iter_mut().enumerate() {
+                            *b = (i % 256) as u8;
                         }
-                        (data.len(), ())
-                    })
-                    .unwrap();
-            }
+                    }
+                    (data.len(), ())
+                })
+                .unwrap();
         }
 
-        phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error");
+        phy_wait(fd, iface.poll_delay(timestamp)).expect("wait error");
     }
 }

+ 23 - 27
examples/sixlowpan.rs

@@ -49,7 +49,6 @@ use std::str;
 
 use smoltcp::iface::{InterfaceBuilder, NeighborCache};
 use smoltcp::phy::{wait as phy_wait, Medium, RawSocket};
-use smoltcp::socket::SocketSet;
 use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
 use smoltcp::time::Instant;
 use smoltcp::wire::{Ieee802154Pan, IpAddress, IpCidr};
@@ -81,7 +80,7 @@ fn main() {
         64,
     )];
 
-    let mut builder = InterfaceBuilder::new(device)
+    let mut builder = InterfaceBuilder::new(device, vec![])
         .ip_addrs(ip_addrs)
         .pan_id(Ieee802154Pan(0xbeef));
     builder = builder
@@ -89,12 +88,11 @@ fn main() {
         .neighbor_cache(neighbor_cache);
     let mut iface = builder.finalize();
 
-    let mut sockets = SocketSet::new(vec![]);
-    let udp_handle = sockets.add(udp_socket);
+    let udp_handle = iface.add_socket(udp_socket);
 
     loop {
         let timestamp = Instant::now();
-        match iface.poll(&mut sockets, timestamp) {
+        match iface.poll(timestamp) {
             Ok(_) => {}
             Err(e) => {
                 debug!("poll error: {}", e);
@@ -102,33 +100,31 @@ fn main() {
         }
 
         // udp:6969: respond "hello"
-        {
-            let mut socket = sockets.get::<UdpSocket>(udp_handle);
-            if !socket.is_open() {
-                socket.bind(6969).unwrap()
-            }
+        let socket = iface.get_socket::<UdpSocket>(udp_handle);
+        if !socket.is_open() {
+            socket.bind(6969).unwrap()
+        }
 
-            let client = match socket.recv() {
-                Ok((data, endpoint)) => {
-                    debug!(
-                        "udp:6969 recv data: {:?} from {}",
-                        str::from_utf8(data).unwrap(),
-                        endpoint
-                    );
-                    Some(endpoint)
-                }
-                Err(_) => None,
-            };
-            if let Some(endpoint) = client {
-                let data = b"hello\n";
+        let client = match socket.recv() {
+            Ok((data, endpoint)) => {
                 debug!(
-                    "udp:6969 send data: {:?}",
-                    str::from_utf8(data.as_ref()).unwrap()
+                    "udp:6969 recv data: {:?} from {}",
+                    str::from_utf8(data).unwrap(),
+                    endpoint
                 );
-                socket.send_slice(data, endpoint).unwrap();
+                Some(endpoint)
             }
+            Err(_) => None,
+        };
+        if let Some(endpoint) = client {
+            let data = b"hello\n";
+            debug!(
+                "udp:6969 send data: {:?}",
+                str::from_utf8(data.as_ref()).unwrap()
+            );
+            socket.send_slice(data, endpoint).unwrap();
         }
 
-        phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error");
+        phy_wait(fd, iface.poll_delay(timestamp)).expect("wait error");
     }
 }

+ 170 - 156
src/iface/interface.rs

@@ -21,6 +21,7 @@ use crate::{Error, Result};
 /// a `&mut [T]`, or `Vec<T>` if a heap is available.
 pub struct Interface<'a, DeviceT: for<'d> Device<'d>> {
     device: DeviceT,
+    sockets: SocketSet<'a>,
     inner: InterfaceInner<'a>,
 }
 
@@ -63,6 +64,7 @@ pub struct InterfaceBuilder<'a, DeviceT: for<'d> Device<'d>> {
     #[cfg(feature = "medium-ieee802154")]
     pan_id: Option<Ieee802154Pan>,
     ip_addrs: ManagedSlice<'a, IpCidr>,
+    sockets: SocketSet<'a>,
     #[cfg(feature = "proto-ipv4")]
     any_ip: bool,
     routes: Routes<'a>,
@@ -96,7 +98,7 @@ let neighbor_cache = // ...
 # NeighborCache::new(BTreeMap::new());
 let ip_addrs = // ...
 # [];
-let iface = InterfaceBuilder::new(device)
+let iface = InterfaceBuilder::new(device, vec![])
         .hardware_addr(hw_addr.into())
         .neighbor_cache(neighbor_cache)
         .ip_addrs(ip_addrs)
@@ -104,9 +106,14 @@ let iface = InterfaceBuilder::new(device)
 ```
     "##
     )]
-    pub fn new(device: DeviceT) -> Self {
+    pub fn new<SocketsT>(device: DeviceT, sockets: SocketsT) -> Self
+    where
+        SocketsT: Into<ManagedSlice<'a, Option<SocketSetItem<'a>>>>,
+    {
         InterfaceBuilder {
             device: device,
+            sockets: SocketSet::new(sockets),
+
             #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))]
             hardware_addr: None,
             #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))]
@@ -285,6 +292,7 @@ let iface = InterfaceBuilder::new(device)
 
         Interface {
             device: self.device,
+            sockets: self.sockets,
             inner: InterfaceInner {
                 #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))]
                 hardware_addr,
@@ -461,6 +469,34 @@ impl<'a, DeviceT> Interface<'a, DeviceT>
 where
     DeviceT: for<'d> Device<'d>,
 {
+    /// Add a socket to the interface, and return its handle.
+    ///
+    /// # Panics
+    /// This function panics if the storage is fixed-size (not a `Vec`) and is full.
+    pub fn add_socket<T>(&mut self, socket: T) -> SocketHandle
+    where
+        T: Into<Socket<'a>>,
+    {
+        self.sockets.add(socket)
+    }
+
+    /// Get a socket from the interface by its handle, as mutable.
+    ///
+    /// # Panics
+    /// This function may panic if the handle does not belong to this socket set
+    /// or the socket has the wrong type.
+    pub fn get_socket<T: AnySocket<'a>>(&mut self, handle: SocketHandle) -> &mut T {
+        self.sockets.get(handle)
+    }
+
+    /// Remove a socket from the set, without changing its state.
+    ///
+    /// # Panics
+    /// This function may panic if the handle does not belong to this socket set.
+    pub fn remove_socket(&mut self, handle: SocketHandle) -> Socket<'a> {
+        self.sockets.remove(handle)
+    }
+
     /// Get the HardwareAddress address of the interface.
     ///
     /// # Panics
@@ -653,13 +689,13 @@ where
     /// packets containing any unsupported protocol, option, or form, which is
     /// a very common occurrence and on a production system it should not even
     /// be logged.
-    pub fn poll(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> Result<bool> {
+    pub fn poll(&mut self, timestamp: Instant) -> Result<bool> {
         let cx = self.context(timestamp);
 
         let mut readiness_may_have_changed = false;
         loop {
-            let processed_any = self.socket_ingress(&cx, sockets);
-            let emitted_any = self.socket_egress(&cx, sockets)?;
+            let processed_any = self.socket_ingress(&cx);
+            let emitted_any = self.socket_egress(&cx)?;
 
             #[cfg(feature = "proto-igmp")]
             self.igmp_egress(&cx, timestamp)?;
@@ -681,10 +717,10 @@ where
     ///
     /// [poll]: #method.poll
     /// [Instant]: struct.Instant.html
-    pub fn poll_at(&self, sockets: &SocketSet, timestamp: Instant) -> Option<Instant> {
+    pub fn poll_at(&self, timestamp: Instant) -> Option<Instant> {
         let cx = self.context(timestamp);
 
-        sockets
+        self.sockets
             .iter()
             .filter_map(|socket| {
                 let socket_poll_at = socket.poll_at(&cx);
@@ -707,19 +743,20 @@ where
     ///
     /// [poll]: #method.poll
     /// [Duration]: struct.Duration.html
-    pub fn poll_delay(&self, sockets: &SocketSet, timestamp: Instant) -> Option<Duration> {
-        match self.poll_at(sockets, timestamp) {
+    pub fn poll_delay(&self, timestamp: Instant) -> Option<Duration> {
+        match self.poll_at(timestamp) {
             Some(poll_at) if timestamp < poll_at => Some(poll_at - timestamp),
             Some(_) => Some(Duration::from_millis(0)),
             _ => None,
         }
     }
 
-    fn socket_ingress(&mut self, cx: &Context, sockets: &mut SocketSet) -> bool {
+    fn socket_ingress(&mut self, cx: &Context) -> bool {
         let mut processed_any = false;
-        let &mut Self {
-            ref mut device,
-            ref mut inner,
+        let Self {
+            device,
+            inner,
+            sockets,
         } = self;
         while let Some((rx_token, tx_token)) = device.receive() {
             if let Err(err) = rx_token.consume(cx.now, |frame| match cx.caps.medium {
@@ -784,24 +821,25 @@ where
         processed_any
     }
 
-    fn socket_egress(&mut self, cx: &Context, sockets: &mut SocketSet) -> Result<bool> {
-        let _caps = self.device.capabilities();
+    fn socket_egress(&mut self, cx: &Context) -> Result<bool> {
+        let Self {
+            device,
+            inner,
+            sockets,
+        } = self;
+        let _caps = device.capabilities();
 
         let mut emitted_any = false;
-        for mut socket in sockets.iter_mut() {
+        for socket in sockets.iter_mut() {
             if !socket
                 .meta_mut()
-                .egress_permitted(cx.now, |ip_addr| self.inner.has_neighbor(cx, &ip_addr))
+                .egress_permitted(cx.now, |ip_addr| inner.has_neighbor(cx, &ip_addr))
             {
                 continue;
             }
 
             let mut neighbor_addr = None;
             let mut device_result = Ok(());
-            let &mut Self {
-                ref mut device,
-                ref mut inner,
-            } = self;
 
             macro_rules! respond {
                 ($response:expr) => {{
@@ -1164,7 +1202,7 @@ impl<'a> InterfaceInner<'a> {
 
                         // Look for UDP sockets that will accept the UDP packet.
                         // If it does not accept the packet, then send an ICMP message.
-                        for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
+                        for udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
                             if !udp_socket.accepts(&IpRepr::Ipv6(ipv6_repr), &udp_repr) {
                                 continue;
                             }
@@ -1290,7 +1328,7 @@ impl<'a> InterfaceInner<'a> {
         let mut handled_by_raw_socket = false;
 
         // Pass every IP packet to all raw sockets we have registered.
-        for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) {
+        for raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) {
             if !raw_socket.accepts(ip_repr) {
                 continue;
             }
@@ -1422,7 +1460,7 @@ impl<'a> InterfaceInner<'a> {
                 if udp_packet.src_port() == DHCP_SERVER_PORT
                     && udp_packet.dst_port() == DHCP_CLIENT_PORT
                 {
-                    if let Some(mut dhcp_socket) =
+                    if let Some(dhcp_socket) =
                         sockets.iter_mut().filter_map(Dhcpv4Socket::downcast).next()
                     {
                         let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr());
@@ -1601,7 +1639,7 @@ impl<'a> InterfaceInner<'a> {
         let mut handled_by_icmp_socket = false;
 
         #[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))]
-        for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
+        for icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
             if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) {
                 continue;
             }
@@ -1787,7 +1825,7 @@ impl<'a> InterfaceInner<'a> {
         let mut handled_by_icmp_socket = false;
 
         #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))]
-        for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
+        for icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
             if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) {
                 continue;
             }
@@ -1911,7 +1949,7 @@ impl<'a> InterfaceInner<'a> {
         let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?;
         let udp_payload = udp_packet.payload();
 
-        for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
+        for udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
             if !udp_socket.accepts(&ip_repr, &udp_repr) {
                 continue;
             }
@@ -1968,7 +2006,7 @@ impl<'a> InterfaceInner<'a> {
         let tcp_packet = TcpPacket::new_checked(ip_payload)?;
         let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?;
 
-        for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) {
+        for tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) {
             if !tcp_socket.accepts(&ip_repr, &tcp_repr) {
                 continue;
             }
@@ -2488,7 +2526,6 @@ mod test {
     #[cfg(feature = "medium-ethernet")]
     use crate::iface::NeighborCache;
     use crate::phy::{ChecksumCapabilities, Loopback};
-    use crate::socket::SocketSet;
     #[cfg(feature = "proto-igmp")]
     use crate::time::Instant;
     use crate::{Error, Result};
@@ -2500,7 +2537,7 @@ mod test {
         }
     }
 
-    fn create_loopback<'a>() -> (Interface<'a, Loopback>, SocketSet<'a>) {
+    fn create_loopback<'a>() -> Interface<'a, Loopback> {
         #[cfg(feature = "medium-ethernet")]
         return create_loopback_ethernet();
         #[cfg(not(feature = "medium-ethernet"))]
@@ -2509,7 +2546,7 @@ mod test {
 
     #[cfg(all(feature = "medium-ip"))]
     #[allow(unused)]
-    fn create_loopback_ip<'a>() -> (Interface<'a, Loopback>, SocketSet<'a>) {
+    fn create_loopback_ip<'a>() -> Interface<'a, Loopback> {
         // Create a basic device
         let device = Loopback::new(Medium::Ip);
         let ip_addrs = [
@@ -2521,16 +2558,14 @@ mod test {
             IpCidr::new(IpAddress::v6(0xfdbe, 0, 0, 0, 0, 0, 0, 1), 64),
         ];
 
-        let iface_builder = InterfaceBuilder::new(device).ip_addrs(ip_addrs);
+        let iface_builder = InterfaceBuilder::new(device, vec![]).ip_addrs(ip_addrs);
         #[cfg(feature = "proto-igmp")]
         let iface_builder = iface_builder.ipv4_multicast_groups(BTreeMap::new());
-        let iface = iface_builder.finalize();
-
-        (iface, SocketSet::new(vec![]))
+        iface_builder.finalize()
     }
 
     #[cfg(all(feature = "medium-ethernet"))]
-    fn create_loopback_ethernet<'a>() -> (Interface<'a, Loopback>, SocketSet<'a>) {
+    fn create_loopback_ethernet<'a>() -> Interface<'a, Loopback> {
         // Create a basic device
         let device = Loopback::new(Medium::Ethernet);
         let ip_addrs = [
@@ -2542,15 +2577,13 @@ mod test {
             IpCidr::new(IpAddress::v6(0xfdbe, 0, 0, 0, 0, 0, 0, 1), 64),
         ];
 
-        let iface_builder = InterfaceBuilder::new(device)
+        let iface_builder = InterfaceBuilder::new(device, vec![])
             .hardware_addr(EthernetAddress::default().into())
             .neighbor_cache(NeighborCache::new(BTreeMap::new()))
             .ip_addrs(ip_addrs);
         #[cfg(feature = "proto-igmp")]
         let iface_builder = iface_builder.ipv4_multicast_groups(BTreeMap::new());
-        let iface = iface_builder.finalize();
-
-        (iface, SocketSet::new(vec![]))
+        iface_builder.finalize()
     }
 
     #[cfg(feature = "proto-igmp")]
@@ -2583,13 +2616,13 @@ mod test {
     #[should_panic(expected = "hardware_addr required option was not set")]
     #[cfg(all(feature = "medium-ethernet"))]
     fn test_builder_initialization_panic() {
-        InterfaceBuilder::new(Loopback::new(Medium::Ethernet)).finalize();
+        InterfaceBuilder::new(Loopback::new(Medium::Ethernet), vec![]).finalize();
     }
 
     #[test]
     #[cfg(feature = "proto-ipv4")]
     fn test_no_icmp_no_unicast_ipv4() {
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         // Unknown Ipv4 Protocol
         //
@@ -2613,7 +2646,7 @@ mod test {
         // broadcast address
         let cx = iface.context(Instant::from_secs(0));
         assert_eq!(
-            iface.inner.process_ipv4(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv4(&cx, &mut iface.sockets, &frame),
             Ok(None)
         );
     }
@@ -2621,7 +2654,7 @@ mod test {
     #[test]
     #[cfg(feature = "proto-ipv6")]
     fn test_no_icmp_no_unicast_ipv6() {
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         // Unknown Ipv6 Protocol
         //
@@ -2645,7 +2678,7 @@ mod test {
         // broadcast address
         let cx = iface.context(Instant::from_secs(0));
         assert_eq!(
-            iface.inner.process_ipv6(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv6(&cx, &mut iface.sockets, &frame),
             Ok(None)
         );
     }
@@ -2654,7 +2687,7 @@ mod test {
     #[cfg(feature = "proto-ipv4")]
     fn test_icmp_error_no_payload() {
         static NO_BYTES: [u8; 0] = [];
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         // Unknown Ipv4 Protocol with no payload
         let repr = IpRepr::Ipv4(Ipv4Repr {
@@ -2698,7 +2731,7 @@ mod test {
         // And we correctly handle no payload.
         let cx = iface.context(Instant::from_secs(0));
         assert_eq!(
-            iface.inner.process_ipv4(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv4(&cx, &mut iface.sockets, &frame),
             Ok(Some(expected_repr))
         );
     }
@@ -2706,7 +2739,7 @@ mod test {
     #[test]
     #[cfg(feature = "proto-ipv4")]
     fn test_local_subnet_broadcasts() {
-        let (mut iface, _) = create_loopback();
+        let mut iface = create_loopback();
         iface.update_ip_addrs(|addrs| {
             addrs.iter_mut().next().map(|addr| {
                 *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 1, 23]), 24));
@@ -2763,7 +2796,7 @@ mod test {
         static UDP_PAYLOAD: [u8; 12] = [
             0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x57, 0x6f, 0x6c, 0x64, 0x21,
         ];
-        let (iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let mut udp_bytes_unicast = vec![0u8; 20];
         let mut udp_bytes_broadcast = vec![0u8; 20];
@@ -2825,7 +2858,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_udp(&cx, &mut socket_set, ip_repr, false, data),
+                .process_udp(&cx, &mut iface.sockets, ip_repr, false, data),
             Ok(Some(expected_repr))
         );
 
@@ -2853,7 +2886,7 @@ mod test {
         assert_eq!(
             iface.inner.process_udp(
                 &cx,
-                &mut socket_set,
+                &mut iface.sockets,
                 ip_repr,
                 false,
                 packet_broadcast.into_inner()
@@ -2870,7 +2903,7 @@ mod test {
 
         static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f];
 
-        let (iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]);
         let tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]);
@@ -2880,7 +2913,7 @@ mod test {
         let mut udp_bytes = vec![0u8; 13];
         let mut packet = UdpPacket::new_unchecked(&mut udp_bytes);
 
-        let socket_handle = socket_set.add(udp_socket);
+        let socket_handle = iface.add_socket(udp_socket);
 
         #[cfg(feature = "proto-ipv6")]
         let src_ip = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
@@ -2909,13 +2942,11 @@ mod test {
             hop_limit: 0x40,
         });
 
-        {
-            // Bind the socket to port 68
-            let mut socket = socket_set.get::<UdpSocket>(socket_handle);
-            assert_eq!(socket.bind(68), Ok(()));
-            assert!(!socket.can_recv());
-            assert!(socket.can_send());
-        }
+        // Bind the socket to port 68
+        let socket = iface.get_socket::<UdpSocket>(socket_handle);
+        assert_eq!(socket.bind(68), Ok(()));
+        assert!(!socket.can_recv());
+        assert!(socket.can_send());
 
         udp_repr.emit(
             &mut packet,
@@ -2931,20 +2962,18 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_udp(&cx, &mut socket_set, ip_repr, false, packet.into_inner()),
+                .process_udp(&cx, &mut iface.sockets, ip_repr, false, packet.into_inner()),
             Ok(None)
         );
 
-        {
-            // Make sure the payload to the UDP packet processed by process_udp is
-            // appended to the bound sockets rx_buffer
-            let mut socket = socket_set.get::<UdpSocket>(socket_handle);
-            assert!(socket.can_recv());
-            assert_eq!(
-                socket.recv(),
-                Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_ip.into(), 67)))
-            );
-        }
+        // Make sure the payload to the UDP packet processed by process_udp is
+        // appended to the bound sockets rx_buffer
+        let socket = iface.get_socket::<UdpSocket>(socket_handle);
+        assert!(socket.can_recv());
+        assert_eq!(
+            socket.recv(),
+            Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_ip.into(), 67)))
+        );
     }
 
     #[test]
@@ -2952,7 +2981,7 @@ mod test {
     fn test_handle_ipv4_broadcast() {
         use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Packet};
 
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let our_ipv4_addr = iface.ipv4_address().unwrap();
         let src_ipv4_addr = Ipv4Address([127, 0, 0, 2]);
@@ -3005,7 +3034,7 @@ mod test {
 
         let cx = iface.context(Instant::from_secs(0));
         assert_eq!(
-            iface.inner.process_ipv4(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv4(&cx, &mut iface.sockets, &frame),
             Ok(Some(expected_packet))
         );
     }
@@ -3025,7 +3054,7 @@ mod test {
         #[cfg(feature = "proto-ipv6")]
         const MAX_PAYLOAD_LEN: usize = 1192;
 
-        let (iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))]
         let src_addr = Ipv4Address([192, 168, 1, 1]);
@@ -3119,7 +3148,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_udp(&cx, &mut socket_set, ip_repr.into(), false, payload),
+                .process_udp(&cx, &mut iface.sockets, ip_repr.into(), false, payload),
             Ok(Some(IpPacket::Icmpv4((
                 expected_ip_repr,
                 expected_icmp_repr
@@ -3129,7 +3158,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_udp(&cx, &mut socket_set, ip_repr.into(), false, payload),
+                .process_udp(&cx, &mut iface.sockets, ip_repr.into(), false, payload),
             Ok(Some(IpPacket::Icmpv6((
                 expected_ip_repr,
                 expected_icmp_repr
@@ -3140,7 +3169,7 @@ mod test {
     #[test]
     #[cfg(all(feature = "medium-ethernet", feature = "proto-ipv4"))]
     fn test_handle_valid_arp_request() {
-        let (mut iface, mut socket_set) = create_loopback_ethernet();
+        let mut iface = create_loopback_ethernet();
 
         let mut eth_bytes = vec![0u8; 42];
 
@@ -3161,10 +3190,8 @@ mod test {
         frame.set_dst_addr(EthernetAddress::BROADCAST);
         frame.set_src_addr(remote_hw_addr);
         frame.set_ethertype(EthernetProtocol::Arp);
-        {
-            let mut packet = ArpPacket::new_unchecked(frame.payload_mut());
-            repr.emit(&mut packet);
-        }
+        let mut packet = ArpPacket::new_unchecked(frame.payload_mut());
+        repr.emit(&mut packet);
 
         let cx = iface.context(Instant::from_secs(0));
 
@@ -3172,7 +3199,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_ethernet(&cx, &mut socket_set, frame.into_inner()),
+                .process_ethernet(&cx, &mut iface.sockets, frame.into_inner()),
             Ok(Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 {
                 operation: ArpOperation::Reply,
                 source_hardware_addr: local_hw_addr,
@@ -3197,7 +3224,7 @@ mod test {
     #[test]
     #[cfg(all(feature = "medium-ethernet", feature = "proto-ipv6"))]
     fn test_handle_valid_ndisc_request() {
-        let (mut iface, mut socket_set) = create_loopback_ethernet();
+        let mut iface = create_loopback_ethernet();
 
         let mut eth_bytes = vec![0u8; 86];
 
@@ -3222,15 +3249,13 @@ mod test {
         frame.set_dst_addr(EthernetAddress([0x33, 0x33, 0x00, 0x00, 0x00, 0x00]));
         frame.set_src_addr(remote_hw_addr);
         frame.set_ethertype(EthernetProtocol::Ipv6);
-        {
-            ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default());
-            solicit.emit(
-                &remote_ip_addr.into(),
-                &local_ip_addr.solicited_node().into(),
-                &mut Icmpv6Packet::new_unchecked(&mut frame.payload_mut()[ip_repr.buffer_len()..]),
-                &ChecksumCapabilities::default(),
-            );
-        }
+        ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default());
+        solicit.emit(
+            &remote_ip_addr.into(),
+            &local_ip_addr.solicited_node().into(),
+            &mut Icmpv6Packet::new_unchecked(&mut frame.payload_mut()[ip_repr.buffer_len()..]),
+            &ChecksumCapabilities::default(),
+        );
 
         let icmpv6_expected = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert {
             flags: NdiscNeighborFlags::SOLICITED,
@@ -3252,7 +3277,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_ethernet(&cx, &mut socket_set, frame.into_inner()),
+                .process_ethernet(&cx, &mut iface.sockets, frame.into_inner()),
             Ok(Some(EthernetPacket::Ip(IpPacket::Icmpv6((
                 ipv6_expected,
                 icmpv6_expected
@@ -3274,7 +3299,7 @@ mod test {
     #[test]
     #[cfg(all(feature = "medium-ethernet", feature = "proto-ipv4"))]
     fn test_handle_other_arp_request() {
-        let (mut iface, mut socket_set) = create_loopback_ethernet();
+        let mut iface = create_loopback_ethernet();
 
         let mut eth_bytes = vec![0u8; 42];
 
@@ -3293,10 +3318,8 @@ mod test {
         frame.set_dst_addr(EthernetAddress::BROADCAST);
         frame.set_src_addr(remote_hw_addr);
         frame.set_ethertype(EthernetProtocol::Arp);
-        {
-            let mut packet = ArpPacket::new_unchecked(frame.payload_mut());
-            repr.emit(&mut packet);
-        }
+        let mut packet = ArpPacket::new_unchecked(frame.payload_mut());
+        repr.emit(&mut packet);
 
         let cx = iface.context(Instant::from_secs(0));
 
@@ -3304,7 +3327,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_ethernet(&cx, &mut socket_set, frame.into_inner()),
+                .process_ethernet(&cx, &mut iface.sockets, frame.into_inner()),
             Ok(None)
         );
 
@@ -3323,7 +3346,7 @@ mod test {
     #[test]
     #[cfg(all(feature = "medium-ethernet", feature = "proto-ipv4"))]
     fn test_arp_flush_after_update_ip() {
-        let (mut iface, mut socket_set) = create_loopback_ethernet();
+        let mut iface = create_loopback_ethernet();
 
         let mut eth_bytes = vec![0u8; 42];
 
@@ -3355,7 +3378,7 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_ethernet(&cx, &mut socket_set, frame.into_inner()),
+                .process_ethernet(&cx, &mut iface.sockets, frame.into_inner()),
             Ok(Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 {
                 operation: ArpOperation::Reply,
                 source_hardware_addr: local_hw_addr,
@@ -3396,24 +3419,22 @@ mod test {
         use crate::socket::{IcmpEndpoint, IcmpPacketMetadata, IcmpSocket, IcmpSocketBuffer};
         use crate::wire::Icmpv4Packet;
 
-        let (iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let rx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 24]);
         let tx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 24]);
 
         let icmpv4_socket = IcmpSocket::new(rx_buffer, tx_buffer);
 
-        let socket_handle = socket_set.add(icmpv4_socket);
+        let socket_handle = iface.add_socket(icmpv4_socket);
 
         let ident = 0x1234;
         let seq_no = 0x5432;
         let echo_data = &[0xff; 16];
 
-        {
-            let mut socket = socket_set.get::<IcmpSocket>(socket_handle);
-            // Bind to the ID 0x1234
-            assert_eq!(socket.bind(IcmpEndpoint::Ident(ident)), Ok(()));
-        }
+        let socket = iface.get_socket::<IcmpSocket>(socket_handle);
+        // Bind to the ID 0x1234
+        assert_eq!(socket.bind(IcmpEndpoint::Ident(ident)), Ok(()));
 
         // Ensure the ident we bound to and the ident of the packet are the same.
         let mut bytes = [0xff; 24];
@@ -3437,9 +3458,7 @@ mod test {
 
         // Open a socket and ensure the packet is handled due to the listening
         // socket.
-        {
-            assert!(!socket_set.get::<IcmpSocket>(socket_handle).can_recv());
-        }
+        assert!(!iface.get_socket::<IcmpSocket>(socket_handle).can_recv());
 
         // Confirm we still get EchoReply from `smoltcp` even with the ICMP socket listening
         let echo_reply = Icmpv4Repr::EchoReply {
@@ -3456,27 +3475,25 @@ mod test {
         assert_eq!(
             iface
                 .inner
-                .process_icmpv4(&cx, &mut socket_set, ip_repr, icmp_data),
+                .process_icmpv4(&cx, &mut iface.sockets, ip_repr, icmp_data),
             Ok(Some(IpPacket::Icmpv4((ipv4_reply, echo_reply))))
         );
 
-        {
-            let mut socket = socket_set.get::<IcmpSocket>(socket_handle);
-            assert!(socket.can_recv());
-            assert_eq!(
-                socket.recv(),
-                Ok((
-                    icmp_data,
-                    IpAddress::Ipv4(Ipv4Address::new(0x7f, 0x00, 0x00, 0x02))
-                ))
-            );
-        }
+        let socket = iface.get_socket::<IcmpSocket>(socket_handle);
+        assert!(socket.can_recv());
+        assert_eq!(
+            socket.recv(),
+            Ok((
+                icmp_data,
+                IpAddress::Ipv4(Ipv4Address::new(0x7f, 0x00, 0x00, 0x02))
+            ))
+        );
     }
 
     #[test]
     #[cfg(feature = "proto-ipv6")]
     fn test_solicited_node_addrs() {
-        let (mut iface, _) = create_loopback();
+        let mut iface = create_loopback();
         let mut new_addrs = vec![
             IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 1, 2, 0, 2), 64),
             IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 3, 4, 0, 0xffff), 64),
@@ -3499,7 +3516,7 @@ mod test {
     #[test]
     #[cfg(feature = "proto-ipv6")]
     fn test_icmpv6_nxthdr_unknown() {
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let remote_ip_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
 
@@ -3556,7 +3573,7 @@ mod test {
         // Ensure the unknown next header causes a ICMPv6 Parameter Problem
         // error message to be sent to the sender.
         assert_eq!(
-            iface.inner.process_ipv6(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv6(&cx, &mut iface.sockets, &frame),
             Ok(Some(IpPacket::Icmpv6((reply_ipv6_repr, reply_icmp_repr))))
         );
     }
@@ -3565,12 +3582,12 @@ mod test {
     #[cfg(feature = "proto-igmp")]
     fn test_handle_igmp() {
         fn recv_igmp(
-            mut iface: &mut Interface<'_, Loopback>,
+            iface: &mut Interface<'_, Loopback>,
             timestamp: Instant,
         ) -> Vec<(Ipv4Repr, IgmpRepr)> {
             let caps = iface.device.capabilities();
             let checksum_caps = &caps.checksum;
-            recv_all(&mut iface, timestamp)
+            recv_all(iface, timestamp)
                 .iter()
                 .filter_map(|frame| {
                     let ipv4_packet = match caps.medium {
@@ -3598,7 +3615,7 @@ mod test {
             Ipv4Address::new(224, 0, 0, 56),
         ];
 
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         // Join multicast groups
         let timestamp = Instant::now();
@@ -3643,7 +3660,7 @@ mod test {
         // GENERAL_QUERY_BYTES. Therefore `recv_all()` would return 0
         // pkts that could be checked.
         let cx = iface.context(timestamp);
-        iface.socket_ingress(&cx, &mut socket_set);
+        iface.socket_ingress(&cx);
 
         // Leave multicast groups
         let timestamp = Instant::now();
@@ -3666,7 +3683,7 @@ mod test {
         use crate::socket::{RawPacketMetadata, RawSocket, RawSocketBuffer};
         use crate::wire::{IpVersion, Ipv4Packet, UdpPacket, UdpRepr};
 
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let packets = 1;
         let rx_buffer =
@@ -3676,7 +3693,7 @@ mod test {
             vec![0; 48 * packets],
         );
         let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer);
-        socket_set.add(raw_socket);
+        iface.add_socket(raw_socket);
 
         let src_addr = Ipv4Address([127, 0, 0, 2]);
         let dst_addr = Ipv4Address([127, 0, 0, 1]);
@@ -3725,7 +3742,7 @@ mod test {
 
         let cx = iface.context(Instant::from_millis(0));
         assert_eq!(
-            iface.inner.process_ipv4(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv4(&cx, &mut iface.sockets, &frame),
             Ok(None)
         );
     }
@@ -3736,7 +3753,7 @@ mod test {
         use crate::socket::{RawPacketMetadata, RawSocket, RawSocketBuffer};
         use crate::wire::{IpVersion, Ipv4Packet, UdpPacket, UdpRepr};
 
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let packets = 1;
         let rx_buffer =
@@ -3746,7 +3763,7 @@ mod test {
             vec![0; 48 * packets],
         );
         let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer);
-        socket_set.add(raw_socket);
+        iface.add_socket(raw_socket);
 
         let src_addr = Ipv4Address([127, 0, 0, 2]);
         let dst_addr = Ipv4Address([127, 0, 0, 1]);
@@ -3795,7 +3812,7 @@ mod test {
         };
 
         let cx = iface.context(Instant::from_millis(0));
-        let frame = iface.inner.process_ipv4(&cx, &mut socket_set, &frame);
+        let frame = iface.inner.process_ipv4(&cx, &mut iface.sockets, &frame);
 
         // because the packet could not be handled we should send an Icmp message
         assert!(match frame {
@@ -3815,19 +3832,18 @@ mod test {
 
         static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f];
 
-        let (mut iface, mut socket_set) = create_loopback();
+        let mut iface = create_loopback();
 
         let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]);
         let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]);
         let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
-        let udp_socket_handle = socket_set.add(udp_socket);
-        {
-            // Bind the socket to port 68
-            let mut socket = socket_set.get::<UdpSocket>(udp_socket_handle);
-            assert_eq!(socket.bind(68), Ok(()));
-            assert!(!socket.can_recv());
-            assert!(socket.can_send());
-        }
+        let udp_socket_handle = iface.add_socket(udp_socket);
+
+        // Bind the socket to port 68
+        let socket = iface.get_socket::<UdpSocket>(udp_socket_handle);
+        assert_eq!(socket.bind(68), Ok(()));
+        assert!(!socket.can_recv());
+        assert!(socket.can_send());
 
         let packets = 1;
         let raw_rx_buffer =
@@ -3842,7 +3858,7 @@ mod test {
             raw_rx_buffer,
             raw_tx_buffer,
         );
-        socket_set.add(raw_socket);
+        iface.add_socket(raw_socket);
 
         let src_addr = Ipv4Address([127, 0, 0, 2]);
         let dst_addr = Ipv4Address([127, 0, 0, 1]);
@@ -3890,18 +3906,16 @@ mod test {
 
         let cx = iface.context(Instant::from_millis(0));
         assert_eq!(
-            iface.inner.process_ipv4(&cx, &mut socket_set, &frame),
+            iface.inner.process_ipv4(&cx, &mut iface.sockets, &frame),
             Ok(None)
         );
 
-        {
-            // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP
-            let mut socket = socket_set.get::<UdpSocket>(udp_socket_handle);
-            assert!(socket.can_recv());
-            assert_eq!(
-                socket.recv(),
-                Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_addr.into(), 67)))
-            );
-        }
+        // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP
+        let socket = iface.get_socket::<UdpSocket>(udp_socket_handle);
+        assert!(socket.can_recv());
+        assert_eq!(
+            socket.recv(),
+            Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_addr.into(), 67)))
+        );
     }
 }

+ 7 - 9
src/socket/dhcpv4.rs

@@ -29,7 +29,7 @@ const PARAMETER_REQUEST_LIST: &[u8] = &[
 ];
 
 /// IPv4 configuration data provided by the DHCP server.
-#[derive(Debug, Eq, PartialEq)]
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub struct Config {
     /// IP address
@@ -103,11 +103,11 @@ enum ClientState {
 /// Return value for the `Dhcpv4Socket::poll` function
 #[derive(Debug, PartialEq, Eq)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
-pub enum Event<'a> {
+pub enum Event {
     /// Configuration has been lost (for example, the lease has expired)
     Deconfigured,
     /// Configuration has been newly acquired, or modified.
-    Configured(&'a Config),
+    Configured(Config),
 }
 
 #[derive(Debug)]
@@ -544,12 +544,12 @@ impl Dhcpv4Socket {
     ///
     /// The socket has an internal "configuration changed" flag. If
     /// set, this function returns the configuration and resets the flag.
-    pub fn poll(&mut self) -> Option<Event<'_>> {
+    pub fn poll(&mut self) -> Option<Event> {
         if !self.config_changed {
             None
         } else if let ClientState::Renewing(state) = &self.state {
             self.config_changed = false;
-            Some(Event::Configured(&state.config))
+            Some(Event::Configured(state.config))
         } else {
             self.config_changed = false;
             Some(Event::Deconfigured)
@@ -626,9 +626,7 @@ mod test {
             });
         }
 
-        if i != reprs.len() {
-            panic!("Too few reprs emitted. Wanted {}, got {}", reprs.len(), i);
-        }
+        assert_eq!(i, reprs.len());
     }
 
     macro_rules! send {
@@ -836,7 +834,7 @@ mod test {
 
         assert_eq!(
             s.poll(),
-            Some(Event::Configured(&Config {
+            Some(Event::Configured(Config {
                 address: Ipv4Cidr::new(MY_IP, 24),
                 dns_servers: DNS_IPS,
                 router: Some(SERVER_IP),

+ 7 - 17
src/socket/mod.rs

@@ -24,7 +24,6 @@ mod icmp;
 mod meta;
 #[cfg(feature = "socket-raw")]
 mod raw;
-mod ref_;
 mod set;
 #[cfg(feature = "socket-tcp")]
 mod tcp;
@@ -59,9 +58,6 @@ pub use self::dhcpv4::{Config as Dhcpv4Config, Dhcpv4Socket, Event as Dhcpv4Even
 pub use self::set::{Handle as SocketHandle, Item as SocketSetItem, Set as SocketSet};
 pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut};
 
-pub use self::ref_::Ref as SocketRef;
-pub(crate) use self::ref_::Session as SocketSession;
-
 /// Gives an indication on the next time the socket should be polled.
 #[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
@@ -144,25 +140,19 @@ impl<'a> Socket<'a> {
     }
 }
 
-impl<'a> SocketSession for Socket<'a> {
-    fn finish(&mut self) {
-        dispatch_socket!(mut self, |socket| socket.finish())
-    }
-}
-
 /// A conversion trait for network sockets.
-pub trait AnySocket<'a>: SocketSession + Sized {
-    fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a>>) -> Option<SocketRef<'c, Self>>;
+pub trait AnySocket<'a>: Sized {
+    fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self>;
 }
 
 macro_rules! from_socket {
     ($socket:ty, $variant:ident) => {
         impl<'a> AnySocket<'a> for $socket {
-            fn downcast<'c>(ref_: SocketRef<'c, Socket<'a>>) -> Option<SocketRef<'c, Self>> {
-                if let Socket::$variant(ref mut socket) = SocketRef::into_inner(ref_) {
-                    Some(SocketRef::new(socket))
-                } else {
-                    None
+            fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> {
+                #[allow(unreachable_patterns)]
+                match socket {
+                    Socket::$variant(socket) => Some(socket),
+                    _ => None,
                 }
             }
         }

+ 0 - 87
src/socket/ref_.rs

@@ -1,87 +0,0 @@
-use core::ops::{Deref, DerefMut};
-
-/// A trait for tracking a socket usage session.
-///
-/// Allows implementation of custom drop logic that runs only if the socket was changed
-/// in specific ways. For example, drop logic for UDP would check if the local endpoint
-/// has changed, and if yes, notify the socket set.
-#[doc(hidden)]
-pub trait Session {
-    fn finish(&mut self) {}
-}
-
-#[cfg(feature = "socket-raw")]
-impl<'a> Session for crate::socket::RawSocket<'a> {}
-#[cfg(all(
-    feature = "socket-icmp",
-    any(feature = "proto-ipv4", feature = "proto-ipv6")
-))]
-impl<'a> Session for crate::socket::IcmpSocket<'a> {}
-#[cfg(feature = "socket-udp")]
-impl<'a> Session for crate::socket::UdpSocket<'a> {}
-#[cfg(feature = "socket-tcp")]
-impl<'a> Session for crate::socket::TcpSocket<'a> {}
-#[cfg(feature = "socket-dhcpv4")]
-impl Session for crate::socket::Dhcpv4Socket {}
-
-/// A smart pointer to a socket.
-///
-/// Allows the network stack to efficiently determine if the socket state was changed in any way.
-pub struct Ref<'a, T: Session + 'a> {
-    /// Reference to the socket.
-    ///
-    /// This is almost always `Some` except when dropped in `into_inner` which removes the socket
-    /// reference. This properly tracks the initialization state without any additional bytes as
-    /// the `None` variant occupies the `0` pattern which is invalid for the reference.
-    socket: Option<&'a mut T>,
-}
-
-impl<'a, T: Session + 'a> Ref<'a, T> {
-    /// Wrap a pointer to a socket to make a smart pointer.
-    ///
-    /// Calling this function is only necessary if your code is using [into_inner].
-    ///
-    /// [into_inner]: #method.into_inner
-    pub fn new(socket: &'a mut T) -> Self {
-        Ref {
-            socket: Some(socket),
-        }
-    }
-
-    /// Unwrap a smart pointer to a socket.
-    ///
-    /// The finalization code is not run. Prompt operation of the network stack depends
-    /// on wrapping the returned pointer back and dropping it.
-    ///
-    /// Calling this function is only necessary to achieve composability if you *must*
-    /// map a `&mut SocketRef<'a, XSocket>` to a `&'a mut XSocket` (note the lifetimes);
-    /// be sure to call [new] afterwards.
-    ///
-    /// [new]: #method.new
-    pub fn into_inner(mut ref_: Self) -> &'a mut T {
-        ref_.socket.take().unwrap()
-    }
-}
-
-impl<'a, T: Session> Deref for Ref<'a, T> {
-    type Target = T;
-
-    fn deref(&self) -> &Self::Target {
-        // Deref is only used while the socket is still in place (into inner has not been called).
-        self.socket.as_ref().unwrap()
-    }
-}
-
-impl<'a, T: Session> DerefMut for Ref<'a, T> {
-    fn deref_mut(&mut self) -> &mut Self::Target {
-        self.socket.as_mut().unwrap()
-    }
-}
-
-impl<'a, T: Session> Drop for Ref<'a, T> {
-    fn drop(&mut self) {
-        if let Some(socket) = self.socket.take() {
-            Session::finish(socket);
-        }
-    }
-}

+ 10 - 80
src/socket/set.rs

@@ -1,9 +1,7 @@
 use core::{fmt, slice};
 use managed::ManagedSlice;
 
-#[cfg(feature = "socket-tcp")]
-use crate::socket::TcpState;
-use crate::socket::{AnySocket, Socket, SocketRef};
+use crate::socket::{AnySocket, Socket};
 
 /// An item of a socket set.
 ///
@@ -12,7 +10,6 @@ use crate::socket::{AnySocket, Socket, SocketRef};
 #[derive(Debug)]
 pub struct Item<'a> {
     socket: Socket<'a>,
-    refs: usize,
 }
 
 /// A handle, identifying a socket in a set.
@@ -44,7 +41,7 @@ impl<'a> Set<'a> {
         Set { sockets }
     }
 
-    /// Add a socket to the set with the reference count 1, and return its handle.
+    /// Add a socket to the set, and return its handle.
     ///
     /// # Panics
     /// This function panics if the storage is fixed-size (not a `Vec`) and is full.
@@ -56,7 +53,7 @@ impl<'a> Set<'a> {
             net_trace!("[{}]: adding", index);
             let handle = Handle(index);
             socket.meta_mut().handle = handle;
-            *slot = Some(Item { socket, refs: 1 });
+            *slot = Some(Item { socket });
             handle
         }
 
@@ -84,10 +81,11 @@ impl<'a> Set<'a> {
     /// # Panics
     /// This function may panic if the handle does not belong to this socket set
     /// or the socket has the wrong type.
-    pub fn get<T: AnySocket<'a>>(&mut self, handle: Handle) -> SocketRef<T> {
+    pub fn get<T: AnySocket<'a>>(&mut self, handle: Handle) -> &mut T {
         match self.sockets[handle.0].as_mut() {
-            Some(item) => T::downcast(SocketRef::new(&mut item.socket))
-                .expect("handle refers to a socket of a wrong type"),
+            Some(item) => {
+                T::downcast(&mut item.socket).expect("handle refers to a socket of a wrong type")
+            }
             None => panic!("handle does not refer to a valid socket"),
         }
     }
@@ -104,74 +102,6 @@ impl<'a> Set<'a> {
         }
     }
 
-    /// Increase reference count by 1.
-    ///
-    /// # Panics
-    /// This function may panic if the handle does not belong to this socket set.
-    pub fn retain(&mut self, handle: Handle) {
-        self.sockets[handle.0]
-            .as_mut()
-            .expect("handle does not refer to a valid socket")
-            .refs += 1
-    }
-
-    /// Decrease reference count by 1.
-    ///
-    /// # Panics
-    /// This function may panic if the handle does not belong to this socket set,
-    /// or if the reference count is already zero.
-    pub fn release(&mut self, handle: Handle) {
-        let refs = &mut self.sockets[handle.0]
-            .as_mut()
-            .expect("handle does not refer to a valid socket")
-            .refs;
-        if *refs == 0 {
-            panic!("decreasing reference count past zero")
-        }
-        *refs -= 1
-    }
-
-    /// Prune the sockets in this set.
-    ///
-    /// Pruning affects sockets with reference count 0. Open sockets are closed.
-    /// Closed sockets are removed and dropped.
-    pub fn prune(&mut self) {
-        for (index, item) in self.sockets.iter_mut().enumerate() {
-            let mut may_remove = false;
-            if let Some(Item {
-                refs: 0,
-                ref mut socket,
-            }) = *item
-            {
-                match *socket {
-                    #[cfg(feature = "socket-raw")]
-                    Socket::Raw(_) => may_remove = true,
-                    #[cfg(all(
-                        feature = "socket-icmp",
-                        any(feature = "proto-ipv4", feature = "proto-ipv6")
-                    ))]
-                    Socket::Icmp(_) => may_remove = true,
-                    #[cfg(feature = "socket-udp")]
-                    Socket::Udp(_) => may_remove = true,
-                    #[cfg(feature = "socket-tcp")]
-                    Socket::Tcp(ref mut socket) => {
-                        if socket.state() == TcpState::Closed {
-                            may_remove = true
-                        } else {
-                            socket.close()
-                        }
-                    }
-                    #[cfg(feature = "socket-dhcpv4")]
-                    Socket::Dhcpv4(_) => may_remove = true,
-                }
-            }
-            if may_remove {
-                net_trace!("[{}]: pruning", index);
-                *item = None
-            }
-        }
-    }
-
     /// Iterate every socket in this set.
     pub fn iter<'d>(&'d self) -> Iter<'d, 'a> {
         Iter {
@@ -179,7 +109,7 @@ impl<'a> Set<'a> {
         }
     }
 
-    /// Iterate every socket in this set, as SocketRef.
+    /// Iterate every socket in this set.
     pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'a> {
         IterMut {
             lower: self.sockets.iter_mut(),
@@ -217,12 +147,12 @@ pub struct IterMut<'a, 'b: 'a> {
 }
 
 impl<'a, 'b: 'a> Iterator for IterMut<'a, 'b> {
-    type Item = SocketRef<'a, Socket<'b>>;
+    type Item = &'a mut Socket<'b>;
 
     fn next(&mut self) -> Option<Self::Item> {
         for item_opt in &mut self.lower {
             if let Some(item) = item_opt.as_mut() {
-                return Some(SocketRef::new(&mut item.socket));
+                return Some(&mut item.socket);
             }
         }
         None