Browse Source

socket_set: add get_mut, make get immutable.

Dario Nieuwenhuis 2 years ago
parent
commit
9e18ca127e

+ 2 - 2
examples/benchmark.rs

@@ -120,7 +120,7 @@ fn main() {
         }
 
         // tcp:1234: emit data
-        let socket = sockets.get::<tcp::Socket>(tcp1_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp1_handle);
         if !socket.is_open() {
             socket.listen(1234).unwrap();
         }
@@ -138,7 +138,7 @@ fn main() {
         }
 
         // tcp:1235: sink data
-        let socket = sockets.get::<tcp::Socket>(tcp2_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp2_handle);
         if !socket.is_open() {
             socket.listen(1235).unwrap();
         }

+ 2 - 2
examples/client.rs

@@ -54,7 +54,7 @@ fn main() {
     let mut sockets = SocketSet::new(vec![]);
     let tcp_handle = sockets.add(tcp_socket);
 
-    let socket = sockets.get::<tcp::Socket>(tcp_handle);
+    let socket = sockets.get_mut::<tcp::Socket>(tcp_handle);
     socket
         .connect(iface.context(), (address, port), 49500)
         .unwrap();
@@ -69,7 +69,7 @@ fn main() {
             }
         }
 
-        let socket = sockets.get::<tcp::Socket>(tcp_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp_handle);
         if socket.is_active() && !tcp_active {
             debug!("connected");
         } else if !socket.is_active() && tcp_active {

+ 1 - 1
examples/dhcp_client.rs

@@ -60,7 +60,7 @@ fn main() {
             debug!("poll error: {}", e);
         }
 
-        let event = sockets.get::<dhcpv4::Socket>(dhcp_handle).poll();
+        let event = sockets.get_mut::<dhcpv4::Socket>(dhcp_handle).poll();
         match event {
             None => {}
             Some(dhcpv4::Event::Configured(config)) => {

+ 2 - 2
examples/dns.rs

@@ -67,7 +67,7 @@ fn main() {
     let mut sockets = SocketSet::new(vec![]);
     let dns_handle = sockets.add(dns_socket);
 
-    let socket = sockets.get::<dns::Socket>(dns_handle);
+    let socket = sockets.get_mut::<dns::Socket>(dns_handle);
     let query = socket.start_query(iface.context(), name).unwrap();
 
     loop {
@@ -82,7 +82,7 @@ fn main() {
         }
 
         match sockets
-            .get::<dns::Socket>(dns_handle)
+            .get_mut::<dns::Socket>(dns_handle)
             .get_query_result(query)
         {
             Ok(addrs) => {

+ 1 - 1
examples/httpclient.rs

@@ -76,7 +76,7 @@ fn main() {
             }
         }
 
-        let socket = sockets.get::<tcp::Socket>(tcp_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp_handle);
         let cx = iface.context();
 
         state = match state {

+ 2 - 2
examples/loopback.rs

@@ -128,7 +128,7 @@ fn main() {
             }
         }
 
-        let mut socket = sockets.get::<tcp::Socket>(server_handle);
+        let mut socket = sockets.get_mut::<tcp::Socket>(server_handle);
         if !socket.is_active() && !socket.is_listening() {
             if !did_listen {
                 debug!("listening");
@@ -146,7 +146,7 @@ fn main() {
             done = true;
         }
 
-        let mut socket = sockets.get::<tcp::Socket>(client_handle);
+        let mut socket = sockets.get_mut::<tcp::Socket>(client_handle);
         let cx = iface.context();
         if !socket.is_open() {
             if !did_connect {

+ 2 - 2
examples/multicast.rs

@@ -78,7 +78,7 @@ fn main() {
             }
         }
 
-        let socket = sockets.get::<raw::Socket>(raw_handle);
+        let socket = sockets.get_mut::<raw::Socket>(raw_handle);
 
         if socket.can_recv() {
             // For display purposes only - normally we wouldn't process incoming IGMP packets
@@ -95,7 +95,7 @@ fn main() {
             }
         }
 
-        let socket = sockets.get::<udp::Socket>(udp_handle);
+        let socket = sockets.get_mut::<udp::Socket>(udp_handle);
         if !socket.is_open() {
             socket.bind(MDNS_PORT).unwrap()
         }

+ 1 - 1
examples/ping.rs

@@ -157,7 +157,7 @@ fn main() {
         }
 
         let timestamp = Instant::now();
-        let socket = sockets.get::<icmp::Socket>(icmp_handle);
+        let socket = sockets.get_mut::<icmp::Socket>(icmp_handle);
         if !socket.is_open() {
             socket.bind(icmp::Endpoint::Ident(ident)).unwrap();
             send_at = timestamp;

+ 5 - 5
examples/server.rs

@@ -81,7 +81,7 @@ fn main() {
         }
 
         // udp:6969: respond "hello"
-        let socket = sockets.get::<udp::Socket>(udp_handle);
+        let socket = sockets.get_mut::<udp::Socket>(udp_handle);
         if !socket.is_open() {
             socket.bind(6969).unwrap()
         }
@@ -107,7 +107,7 @@ fn main() {
         }
 
         // tcp:6969: respond "hello"
-        let socket = sockets.get::<tcp::Socket>(tcp1_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp1_handle);
         if !socket.is_open() {
             socket.listen(6969).unwrap();
         }
@@ -120,7 +120,7 @@ fn main() {
         }
 
         // tcp:6970: echo with reverse
-        let socket = sockets.get::<tcp::Socket>(tcp2_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp2_handle);
         if !socket.is_open() {
             socket.listen(6970).unwrap()
         }
@@ -162,7 +162,7 @@ fn main() {
         }
 
         // tcp:6971: sinkhole
-        let socket = sockets.get::<tcp::Socket>(tcp3_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp3_handle);
         if !socket.is_open() {
             socket.listen(6971).unwrap();
             socket.set_keep_alive(Some(Duration::from_millis(1000)));
@@ -183,7 +183,7 @@ fn main() {
         }
 
         // tcp:6972: fountain
-        let socket = sockets.get::<tcp::Socket>(tcp4_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp4_handle);
         if !socket.is_open() {
             socket.listen(6972).unwrap()
         }

+ 3 - 3
examples/sixlowpan.rs

@@ -104,7 +104,7 @@ fn main() {
     let udp_handle = sockets.add(udp_socket);
     let tcp_handle = sockets.add(tcp_socket);
 
-    let socket = sockets.get::<tcp::Socket>(tcp_handle);
+    let socket = sockets.get_mut::<tcp::Socket>(tcp_handle);
     socket.listen(50000).unwrap();
 
     let mut tcp_active = false;
@@ -124,7 +124,7 @@ fn main() {
         }
 
         // udp:6969: respond "hello"
-        let socket = sockets.get::<udp::Socket>(udp_handle);
+        let socket = sockets.get_mut::<udp::Socket>(udp_handle);
         if !socket.is_open() {
             socket.bind(6969).unwrap()
         }
@@ -150,7 +150,7 @@ fn main() {
             socket.send_slice(&buffer[..len], endpoint).unwrap();
         }
 
-        let socket = sockets.get::<tcp::Socket>(tcp_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp_handle);
         if socket.is_active() && !tcp_active {
             debug!("connected");
         } else if !socket.is_active() && tcp_active {

+ 2 - 2
examples/sixlowpan_benchmark.rs

@@ -197,7 +197,7 @@ fn main() {
         }
 
         // tcp:1234: emit data
-        let socket = sockets.get::<tcp::Socket>(tcp1_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp1_handle);
         if !socket.is_open() {
             socket.listen(1234).unwrap();
         }
@@ -213,7 +213,7 @@ fn main() {
         }
 
         // tcp:1235: sink data
-        let socket = sockets.get::<tcp::Socket>(tcp2_handle);
+        let socket = sockets.get_mut::<tcp::Socket>(tcp2_handle);
         if !socket.is_open() {
             socket.listen(1235).unwrap();
         }

+ 16 - 15
src/iface/interface.rs

@@ -187,6 +187,7 @@ let iface = InterfaceBuilder::new()
 ```
     "##
     )]
+    #[allow(clippy::new_without_default)]
     pub fn new() -> Self {
         InterfaceBuilder {
             #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))]
@@ -1614,7 +1615,7 @@ impl<'a> InterfaceInner<'a> {
                         // normal IPv6 UDP payload, which is not what we have here.
                         for udp_socket in sockets
                             .items_mut()
-                            .filter_map(|i| udp::Socket::downcast(&mut i.socket))
+                            .filter_map(|i| udp::Socket::downcast_mut(&mut i.socket))
                         {
                             if udp_socket.accepts(self, &IpRepr::Ipv6(ipv6_repr), &udp_repr) {
                                 udp_socket.process(
@@ -1743,7 +1744,7 @@ impl<'a> InterfaceInner<'a> {
         // Pass every IP packet to all raw sockets we have registered.
         for raw_socket in sockets
             .items_mut()
-            .filter_map(|i| raw::Socket::downcast(&mut i.socket))
+            .filter_map(|i| raw::Socket::downcast_mut(&mut i.socket))
         {
             if raw_socket.accepts(ip_repr) {
                 raw_socket.process(self, ip_repr, ip_payload);
@@ -1861,7 +1862,7 @@ impl<'a> InterfaceInner<'a> {
                 {
                     if let Some(dhcp_socket) = sockets
                         .items_mut()
-                        .filter_map(|i| dhcpv4::Socket::downcast(&mut i.socket))
+                        .filter_map(|i| dhcpv4::Socket::downcast_mut(&mut i.socket))
                         .next()
                     {
                         let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr());
@@ -2040,7 +2041,7 @@ impl<'a> InterfaceInner<'a> {
         #[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))]
         for icmp_socket in _sockets
             .items_mut()
-            .filter_map(|i| icmp::Socket::downcast(&mut i.socket))
+            .filter_map(|i| icmp::Socket::downcast_mut(&mut i.socket))
         {
             if icmp_socket.accepts(self, &ip_repr, &icmp_repr.into()) {
                 icmp_socket.process(self, &ip_repr, &icmp_repr.into());
@@ -2219,7 +2220,7 @@ impl<'a> InterfaceInner<'a> {
         #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))]
         for icmp_socket in _sockets
             .items_mut()
-            .filter_map(|i| icmp::Socket::downcast(&mut i.socket))
+            .filter_map(|i| icmp::Socket::downcast_mut(&mut i.socket))
         {
             if icmp_socket.accepts(self, &ip_repr, &icmp_repr.into()) {
                 icmp_socket.process(self, &ip_repr, &icmp_repr.into());
@@ -2344,7 +2345,7 @@ impl<'a> InterfaceInner<'a> {
         #[cfg(feature = "socket-udp")]
         for udp_socket in sockets
             .items_mut()
-            .filter_map(|i| udp::Socket::downcast(&mut i.socket))
+            .filter_map(|i| udp::Socket::downcast_mut(&mut i.socket))
         {
             if udp_socket.accepts(self, &ip_repr, &udp_repr) {
                 udp_socket.process(self, &ip_repr, &udp_repr, udp_payload);
@@ -2355,7 +2356,7 @@ impl<'a> InterfaceInner<'a> {
         #[cfg(feature = "socket-dns")]
         for dns_socket in sockets
             .items_mut()
-            .filter_map(|i| dns::Socket::downcast(&mut i.socket))
+            .filter_map(|i| dns::Socket::downcast_mut(&mut i.socket))
         {
             if dns_socket.accepts(&ip_repr, &udp_repr) {
                 dns_socket.process(self, &ip_repr, &udp_repr, udp_payload);
@@ -2412,7 +2413,7 @@ impl<'a> InterfaceInner<'a> {
 
         for tcp_socket in sockets
             .items_mut()
-            .filter_map(|i| tcp::Socket::downcast(&mut i.socket))
+            .filter_map(|i| tcp::Socket::downcast_mut(&mut i.socket))
         {
             if tcp_socket.accepts(self, &ip_repr, &tcp_repr) {
                 return tcp_socket
@@ -3543,7 +3544,7 @@ mod test {
         });
 
         // Bind the socket to port 68
-        let socket = sockets.get::<udp::Socket>(socket_handle);
+        let socket = sockets.get_mut::<udp::Socket>(socket_handle);
         assert_eq!(socket.bind(68), Ok(()));
         assert!(!socket.can_recv());
         assert!(socket.can_send());
@@ -3567,7 +3568,7 @@ mod test {
 
         // Make sure the payload to the UDP packet processed by process_udp is
         // appended to the bound sockets rx_buffer
-        let socket = sockets.get::<udp::Socket>(socket_handle);
+        let socket = sockets.get_mut::<udp::Socket>(socket_handle);
         assert!(socket.can_recv());
         assert_eq!(
             socket.recv(),
@@ -4011,7 +4012,7 @@ mod test {
         let seq_no = 0x5432;
         let echo_data = &[0xff; 16];
 
-        let socket = sockets.get::<icmp::Socket>(socket_handle);
+        let socket = sockets.get_mut::<icmp::Socket>(socket_handle);
         // Bind to the ID 0x1234
         assert_eq!(socket.bind(icmp::Endpoint::Ident(ident)), Ok(()));
 
@@ -4037,7 +4038,7 @@ mod test {
 
         // Open a socket and ensure the packet is handled due to the listening
         // socket.
-        assert!(!sockets.get::<icmp::Socket>(socket_handle).can_recv());
+        assert!(!sockets.get_mut::<icmp::Socket>(socket_handle).can_recv());
 
         // Confirm we still get EchoReply from `smoltcp` even with the ICMP socket listening
         let echo_reply = Icmpv4Repr::EchoReply {
@@ -4055,7 +4056,7 @@ mod test {
             Some(IpPacket::Icmpv4((ipv4_reply, echo_reply)))
         );
 
-        let socket = sockets.get::<icmp::Socket>(socket_handle);
+        let socket = sockets.get_mut::<icmp::Socket>(socket_handle);
         assert!(socket.can_recv());
         assert_eq!(
             socket.recv(),
@@ -4331,7 +4332,7 @@ mod test {
         let udp_socket_handle = sockets.add(udp_socket);
 
         // Bind the socket to port 68
-        let socket = sockets.get::<udp::Socket>(udp_socket_handle);
+        let socket = sockets.get_mut::<udp::Socket>(udp_socket_handle);
         assert_eq!(socket.bind(68), Ok(()));
         assert!(!socket.can_recv());
         assert!(socket.can_send());
@@ -4398,7 +4399,7 @@ mod test {
         assert_eq!(iface.inner.process_ipv4(&mut sockets, &frame), None);
 
         // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP
-        let socket = sockets.get::<udp::Socket>(udp_socket_handle);
+        let socket = sockets.get_mut::<udp::Socket>(udp_socket_handle);
         assert!(socket.can_recv());
         assert_eq!(
             socket.recv(),

+ 16 - 3
src/iface/socket_set.rs

@@ -93,15 +93,28 @@ impl<'a> SocketSet<'a> {
     /// # Panics
     /// This function may panic if the handle does not belong to this socket set
     /// or the socket has the wrong type.
-    pub fn get<T: AnySocket<'a>>(&mut self, handle: SocketHandle) -> &mut T {
-        match self.sockets[handle.0].inner.as_mut() {
+    pub fn get<T: AnySocket<'a>>(&self, handle: SocketHandle) -> &T {
+        match self.sockets[handle.0].inner.as_ref() {
             Some(item) => {
-                T::downcast(&mut item.socket).expect("handle refers to a socket of a wrong type")
+                T::downcast(&item.socket).expect("handle refers to a socket of a wrong type")
             }
             None => panic!("handle does not refer to a valid socket"),
         }
     }
 
+    /// Get a mutable socket from the set by its handle, as mutable.
+    ///
+    /// # Panics
+    /// This function may panic if the handle does not belong to this socket set
+    /// or the socket has the wrong type.
+    pub fn get_mut<T: AnySocket<'a>>(&mut self, handle: SocketHandle) -> &mut T {
+        match self.sockets[handle.0].inner.as_mut() {
+            Some(item) => T::downcast_mut(&mut item.socket)
+                .expect("handle refers to a socket of a wrong type"),
+            None => panic!("handle does not refer to a valid socket"),
+        }
+    }
+
     /// Remove a socket from the set, without changing its state.
     ///
     /// # Panics

+ 11 - 2
src/socket/mod.rs

@@ -93,7 +93,8 @@ impl<'a> Socket<'a> {
 /// A conversion trait for network sockets.
 pub trait AnySocket<'a>: Sized {
     fn upcast(self) -> Socket<'a>;
-    fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self>;
+    fn downcast<'c>(socket: &'c Socket<'a>) -> Option<&'c Self>;
+    fn downcast_mut<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self>;
 }
 
 macro_rules! from_socket {
@@ -103,7 +104,15 @@ macro_rules! from_socket {
                 Socket::$variant(self)
             }
 
-            fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> {
+            fn downcast<'c>(socket: &'c Socket<'a>) -> Option<&'c Self> {
+                #[allow(unreachable_patterns)]
+                match socket {
+                    Socket::$variant(socket) => Some(socket),
+                    _ => None,
+                }
+            }
+
+            fn downcast_mut<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> {
                 #[allow(unreachable_patterns)]
                 match socket {
                     Socket::$variant(socket) => Some(socket),