Răsfoiți Sursa

Implement a SocketRef smart pointer to detect state changes.

Egor Karavaev 7 ani în urmă
părinte
comite
096ce02ac4
8 a modificat fișierele cu 145 adăugiri și 87 ștergeri
  1. 3 4
      examples/client.rs
  2. 3 4
      examples/loopback.rs
  3. 2 3
      examples/ping.rs
  4. 6 6
      examples/server.rs
  5. 10 13
      src/iface/ethernet.rs
  6. 32 34
      src/socket/mod.rs
  7. 73 0
      src/socket/ref_.rs
  8. 16 23
      src/socket/set.rs

+ 3 - 4
examples/client.rs

@@ -12,8 +12,7 @@ use std::os::unix::io::AsRawFd;
 use smoltcp::phy::wait as phy_wait;
 use smoltcp::wire::{EthernetAddress, Ipv4Address, IpAddress, IpCidr};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
-use smoltcp::socket::{AsSocket, SocketSet};
-use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
+use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
 
 fn main() {
     utils::setup_logging("");
@@ -50,14 +49,14 @@ fn main() {
     let tcp_handle = sockets.add(tcp_socket);
 
     {
-        let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket();
+        let mut socket = sockets.get::<TcpSocket>(tcp_handle);
         socket.connect((address, port), 49500).unwrap();
     }
 
     let mut tcp_active = false;
     loop {
         {
-            let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket();
+            let mut socket = sockets.get::<TcpSocket>(tcp_handle);
             if socket.is_active() && !tcp_active {
                 debug!("connected");
             } else if !socket.is_active() && tcp_active {

+ 3 - 4
examples/loopback.rs

@@ -19,8 +19,7 @@ use core::str;
 use smoltcp::phy::Loopback;
 use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
-use smoltcp::socket::{AsSocket, SocketSet};
-use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
+use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
 
 #[cfg(not(feature = "std"))]
 mod mock {
@@ -124,7 +123,7 @@ fn main() {
     let mut done = false;
     while !done && clock.elapsed() < 10_000 {
         {
-            let socket: &mut TcpSocket = socket_set.get_mut(server_handle).as_socket();
+            let mut socket = socket_set.get::<TcpSocket>(server_handle);
             if !socket.is_active() && !socket.is_listening() {
                 if !did_listen {
                     debug!("listening");
@@ -141,7 +140,7 @@ fn main() {
         }
 
         {
-            let socket: &mut TcpSocket = socket_set.get_mut(client_handle).as_socket();
+            let mut socket = socket_set.get::<TcpSocket>(client_handle);
             if !socket.is_open() {
                 if !did_connect {
                     debug!("connecting");

+ 2 - 3
examples/ping.rs

@@ -16,8 +16,7 @@ use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress, IpCidr,
                     Ipv4Address, Ipv4Packet, Ipv4Repr,
                     Icmpv4Repr, Icmpv4Packet};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
-use smoltcp::socket::{AsSocket, SocketSet};
-use smoltcp::socket::{RawSocket, RawSocketBuffer, RawPacketBuffer};
+use smoltcp::socket::{SocketSet, RawSocket, RawSocketBuffer, RawPacketBuffer};
 use std::collections::HashMap;
 use byteorder::{ByteOrder, NetworkEndian};
 
@@ -75,7 +74,7 @@ fn main() {
 
     loop {
         {
-            let socket: &mut RawSocket = sockets.get_mut(raw_handle).as_socket();
+            let mut socket = sockets.get::<RawSocket>(raw_handle);
 
             let timestamp = Instant::now().duration_since(startup_time);
             let timestamp_us = (timestamp.as_secs() * 1000000) +

+ 6 - 6
examples/server.rs

@@ -13,7 +13,7 @@ use std::os::unix::io::AsRawFd;
 use smoltcp::phy::wait as phy_wait;
 use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
-use smoltcp::socket::{AsSocket, SocketSet};
+use smoltcp::socket::SocketSet;
 use smoltcp::socket::{UdpSocket, UdpSocketBuffer, UdpPacketBuffer};
 use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 
@@ -70,7 +70,7 @@ fn main() {
     loop {
         // udp:6969: respond "hello"
         {
-            let socket: &mut UdpSocket = sockets.get_mut(udp_handle).as_socket();
+            let mut socket = sockets.get::<UdpSocket>(udp_handle);
             if !socket.is_open() {
                 socket.bind(6969).unwrap()
             }
@@ -93,7 +93,7 @@ fn main() {
 
         // tcp:6969: respond "hello"
         {
-            let socket: &mut TcpSocket = sockets.get_mut(tcp1_handle).as_socket();
+            let mut socket = sockets.get::<TcpSocket>(tcp1_handle);
             if !socket.is_open() {
                 socket.listen(6969).unwrap();
             }
@@ -108,7 +108,7 @@ fn main() {
 
         // tcp:6970: echo with reverse
         {
-            let socket: &mut TcpSocket = sockets.get_mut(tcp2_handle).as_socket();
+            let mut socket = sockets.get::<TcpSocket>(tcp2_handle);
             if !socket.is_open() {
                 socket.listen(6970).unwrap()
             }
@@ -145,7 +145,7 @@ fn main() {
 
         // tcp:6971: sinkhole
         {
-            let socket: &mut TcpSocket = sockets.get_mut(tcp3_handle).as_socket();
+            let mut socket = sockets.get::<TcpSocket>(tcp3_handle);
             if !socket.is_open() {
                 socket.listen(6971).unwrap();
                 socket.set_keep_alive(Some(1000));
@@ -165,7 +165,7 @@ fn main() {
 
         // tcp:6972: fountain
         {
-            let socket: &mut TcpSocket = sockets.get_mut(tcp4_handle).as_socket();
+            let mut socket = sockets.get::<TcpSocket>(tcp4_handle);
             if !socket.is_open() {
                 socket.listen(6972).unwrap()
             }

+ 10 - 13
src/iface/ethernet.rs

@@ -13,7 +13,7 @@ use wire::{Ipv4Packet, Ipv4Repr};
 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
 #[cfg(feature = "socket-udp")] use wire::{UdpPacket, UdpRepr};
 #[cfg(feature = "socket-tcp")] use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::{Socket, SocketSet, AsSocket};
+use socket::{Socket, SocketSet, AnySocket};
 #[cfg(feature = "socket-raw")] use socket::RawSocket;
 #[cfg(feature = "socket-udp")] use socket::UdpSocket;
 #[cfg(feature = "socket-tcp")] use socket::TcpSocket;
@@ -195,29 +195,29 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         let mut caps = self.device.capabilities();
         caps.max_transmission_unit -= EthernetFrame::<&[u8]>::header_len();
 
-        for socket in sockets.iter_mut() {
+        for mut socket in sockets.iter_mut() {
             let mut device_result = Ok(());
             let socket_result =
-                match socket {
+                match *socket {
                     #[cfg(feature = "socket-raw")]
-                    &mut Socket::Raw(ref mut socket) =>
+                    Socket::Raw(ref mut socket) =>
                         socket.dispatch(|response| {
                             device_result = self.dispatch(timestamp, Packet::Raw(response));
                             device_result
                         }, &caps.checksum),
                     #[cfg(feature = "socket-udp")]
-                    &mut Socket::Udp(ref mut socket) =>
+                    Socket::Udp(ref mut socket) =>
                         socket.dispatch(|response| {
                             device_result = self.dispatch(timestamp, Packet::Udp(response));
                             device_result
                         }),
                     #[cfg(feature = "socket-tcp")]
-                    &mut Socket::Tcp(ref mut socket) =>
+                    Socket::Tcp(ref mut socket) =>
                         socket.dispatch(timestamp, &caps, |response| {
                             device_result = self.dispatch(timestamp, Packet::Tcp(response));
                             device_result
                         }),
-                    &mut Socket::__Nonexhaustive(_) => unreachable!()
+                    Socket::__Nonexhaustive(_) => unreachable!()
                 };
             match (device_result, socket_result) {
                 (Err(Error::Unaddressable), _) => break, // no one to transmit to
@@ -323,8 +323,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
         // Pass every IP packet to all raw sockets we have registered.
         #[cfg(feature = "socket-raw")]
-        for raw_socket in sockets.iter_mut().filter_map(
-                <Socket as AsSocket<RawSocket>>::try_as_socket) {
+        for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) {
             if !raw_socket.accepts(&ip_repr) { continue }
 
             match raw_socket.process(&ip_repr, ip_payload, &checksum_caps) {
@@ -415,8 +414,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         let checksum_caps = self.device.capabilities().checksum;
         let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?;
 
-        for udp_socket in sockets.iter_mut().filter_map(
-                <Socket as AsSocket<UdpSocket>>::try_as_socket) {
+        for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
             if !udp_socket.accepts(&ip_repr, &udp_repr) { continue }
 
             match udp_socket.process(&ip_repr, &udp_repr) {
@@ -458,8 +456,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         let checksum_caps = self.device.capabilities().checksum;
         let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &checksum_caps)?;
 
-        for tcp_socket in sockets.iter_mut().filter_map(
-                <Socket as AsSocket<TcpSocket>>::try_as_socket) {
+        for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) {
             if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue }
 
             match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) {

+ 32 - 34
src/socket/mod.rs

@@ -17,6 +17,7 @@ use wire::IpRepr;
 #[cfg(feature = "socket-udp")] mod udp;
 #[cfg(feature = "socket-tcp")] mod tcp;
 mod set;
+mod ref_;
 
 #[cfg(feature = "socket-raw")]
 pub use self::raw::{PacketBuffer as RawPacketBuffer,
@@ -36,19 +37,19 @@ pub use self::tcp::{SocketBuffer as TcpSocketBuffer,
 pub use self::set::{Set as SocketSet, Item as SocketSetItem, Handle as SocketHandle};
 pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut};
 
+pub use self::ref_::Ref as SocketRef;
+pub(crate) use self::ref_::Session as SocketSession;
+
 /// A network socket.
 ///
 /// This enumeration abstracts the various types of sockets based on the IP protocol.
-/// To downcast a `Socket` value down to a concrete socket, use
-/// the [AsSocket](trait.AsSocket.html) trait, and call e.g. `socket.as_socket::<UdpSocket<_>>()`.
+/// To downcast a `Socket` value to a concrete socket, use the [AnySocket] trait,
+/// e.g. to get `UdpSocket`, call `UdpSocket::downcast(socket)`.
+///
+/// It is usually more convenient to use [SocketSet::get] instead.
 ///
-/// The `process` and `dispatch` functions are fundamentally asymmetric and thus differ in
-/// their use of the [trait PacketRepr](trait.PacketRepr.html). When `process` is called,
-/// the packet length is already known and no allocation is required; on the other hand,
-/// `process` would have to downcast a `&PacketRepr` to e.g. an `&UdpRepr` through `Any`,
-/// which is rather inelegant. Conversely, when `dispatch` is called, the packet length is
-/// not yet known and the packet storage has to be allocated; but the `&PacketRepr` is sufficient
-/// since the lower layers treat the packet as an opaque octet sequence.
+/// [AnySocket]: trait.AnySocket.html
+/// [SocketSet::get]: struct.SocketSet.html#method.get
 #[derive(Debug)]
 pub enum Socket<'a, 'b: 'a> {
     #[cfg(feature = "socket-raw")]
@@ -90,40 +91,37 @@ impl<'a, 'b> Socket<'a, 'b> {
     }
 }
 
+impl<'a, 'b> SocketSession for Socket<'a, 'b> {
+    fn finish(&mut self) {
+        dispatch_socket!(self, |socket [mut]| socket.finish())
+    }
+}
+
 /// A conversion trait for network sockets.
-///
-/// This trait is used to concisely downcast [Socket](trait.Socket.html) values to their
-/// concrete types.
-pub trait AsSocket<T> {
-    fn as_socket(&mut self) -> &mut T;
-    fn try_as_socket(&mut self) -> Option<&mut T>;
+pub trait AnySocket<'a, 'b>: SocketSession + Sized {
+    fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a, 'b>>) ->
+                   Option<SocketRef<'c, Self>>;
 }
 
-macro_rules! as_socket {
+macro_rules! from_socket {
     ($socket:ty, $variant:ident) => {
-        impl<'a, 'b> AsSocket<$socket> for Socket<'a, 'b> {
-            fn as_socket(&mut self) -> &mut $socket {
-                match self {
-                    &mut Socket::$variant(ref mut socket) => socket,
-                    _ => panic!(concat!(".as_socket::<",
-                                        stringify!($socket),
-                                        "> called on wrong socket type"))
-                }
-            }
-
-            fn try_as_socket(&mut self) -> Option<&mut $socket> {
-                match self {
-                    &mut Socket::$variant(ref mut socket) => Some(socket),
-                    _ => None,
-                }
+        impl<'a, 'b> AnySocket<'a, 'b> for $socket {
+            fn downcast<'c>(ref_: SocketRef<'c, Socket<'a, 'b>>) ->
+                           Option<SocketRef<'c, Self>> {
+                SocketRef::map(ref_, |socket| {
+                    match *socket {
+                        Socket::$variant(ref mut socket) => Some(socket),
+                        _ => None,
+                    }
+                })
             }
         }
     }
 }
 
 #[cfg(feature = "socket-raw")]
-as_socket!(RawSocket<'a, 'b>, Raw);
+from_socket!(RawSocket<'a, 'b>, Raw);
 #[cfg(feature = "socket-udp")]
-as_socket!(UdpSocket<'a, 'b>, Udp);
+from_socket!(UdpSocket<'a, 'b>, Udp);
 #[cfg(feature = "socket-tcp")]
-as_socket!(TcpSocket<'a>, Tcp);
+from_socket!(TcpSocket<'a>, Tcp);

+ 73 - 0
src/socket/ref_.rs

@@ -0,0 +1,73 @@
+use core::ops::{Deref, DerefMut};
+
+#[cfg(feature = "socket-raw")]
+use socket::RawSocket;
+#[cfg(feature = "socket-udp")]
+use socket::UdpSocket;
+#[cfg(feature = "socket-tcp")]
+use socket::TcpSocket;
+
+/// A trait for tracking a socket usage session.
+///
+/// Allows implementation of custom drop logic that runs only if the socket was changed
+/// in specific ways. For example, drop logic for UDP would check if the local endpoint
+/// has changed, and if yes, notify the socket set.
+#[doc(hidden)]
+pub trait Session {
+    fn finish(&mut self) {}
+}
+
+#[cfg(feature = "socket-raw")]
+impl<'a, 'b> Session for RawSocket<'a, 'b> {}
+#[cfg(feature = "socket-udp")]
+impl<'a, 'b> Session for UdpSocket<'a, 'b> {}
+#[cfg(feature = "socket-tcp")]
+impl<'a> Session for TcpSocket<'a> {}
+
+/// A smart pointer to a socket.
+///
+/// Allows the network stack to efficiently determine if the socket state was changed in any way.
+pub struct Ref<'a, T: Session + 'a> {
+    socket:   &'a mut T,
+    consumed: bool,
+}
+
+impl<'a, T: Session> Ref<'a, T> {
+    pub(crate) fn new(socket: &'a mut T) -> Self {
+        Ref { socket, consumed: false }
+    }
+}
+
+impl<'a, T: Session + 'a> Ref<'a, T> {
+    pub(crate) fn map<U, F>(mut ref_: Self, f: F) -> Option<Ref<'a, U>>
+            where U: Session + 'a, F: FnOnce(&'a mut T) -> Option<&'a mut U> {
+        if let Some(socket) = f(ref_.socket) {
+            ref_.consumed = true;
+            Some(Ref::new(socket))
+        } else {
+            None
+        }
+    }
+}
+
+impl<'a, T: Session> Deref for Ref<'a, T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.socket
+    }
+}
+
+impl<'a, T: Session> DerefMut for Ref<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.socket
+    }
+}
+
+impl<'a, T: Session> Drop for Ref<'a, T> {
+    fn drop(&mut self) {
+        if !self.consumed {
+            Session::finish(self.socket);
+        }
+    }
+}

+ 16 - 23
src/socket/set.rs

@@ -1,7 +1,7 @@
 use core::{fmt, slice};
 use managed::ManagedSlice;
 
-use super::Socket;
+use super::{Socket, SocketRef, AnySocket};
 #[cfg(feature = "socket-tcp")] use super::TcpState;
 
 /// An item of a socket set.
@@ -28,7 +28,7 @@ impl fmt::Display for Handle {
     }
 }
 
-/// An extensible set of sockets, with stable numeric identifiers.
+/// An extensible set of sockets.
 ///
 /// The lifetimes `'b` and `'c` are used when storing a `Socket<'b, 'c>`.
 #[derive(Debug)]
@@ -79,26 +79,19 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
         }
     }
 
-    /// Get a socket from the set by its handle.
-    ///
-    /// # Panics
-    /// This function may panic if the handle does not belong to this socket set.
-    pub fn get(&self, handle: Handle) -> &Socket<'b, 'c> {
-        &self.sockets[handle.0]
-             .as_ref()
-             .expect("handle does not refer to a valid socket")
-             .socket
-    }
-
     /// Get a socket from the set by its handle, as mutable.
     ///
     /// # Panics
-    /// This function may panic if the handle does not belong to this socket set.
-    pub fn get_mut(&mut self, handle: Handle) -> &mut Socket<'b, 'c> {
-        &mut self.sockets[handle.0]
-                 .as_mut()
-                 .expect("handle does not refer to a valid socket")
-                 .socket
+    /// This function may panic if the handle does not belong to this socket set
+    /// or the socket has the wrong type.
+    pub fn get<T: AnySocket<'b, 'c>>(&mut self, handle: Handle) -> SocketRef<T> {
+        match self.sockets[handle.0].as_mut() {
+            Some(item) => {
+                T::downcast(SocketRef::new(&mut item.socket))
+                  .expect("handle refers to a socket of a wrong type")
+            }
+            None => panic!("handle does not refer to a valid socket")
+        }
     }
 
     /// Remove a socket from the set, without changing its state.
@@ -175,7 +168,7 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
         Iter { lower: self.sockets.iter() }
     }
 
-    /// Iterate every socket in this set, as mutable.
+    /// Iterate every socket in this set, as SocketRef.
     pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'b, 'c> {
         IterMut { lower: self.sockets.iter_mut() }
     }
@@ -207,16 +200,16 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for Iter<'a, 'b, 'c> {
 /// This struct is created by the [iter_mut](struct.SocketSet.html#method.iter_mut)
 /// on [socket sets](struct.SocketSet.html).
 pub struct IterMut<'a, 'b: 'a, 'c: 'a + 'b> {
-    lower: slice::IterMut<'a, Option<Item<'b, 'c>>>
+    lower: slice::IterMut<'a, Option<Item<'b, 'c>>>,
 }
 
 impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for IterMut<'a, 'b, 'c> {
-    type Item = &'a mut Socket<'b, 'c>;
+    type Item = SocketRef<'a, Socket<'b, 'c>>;
 
     fn next(&mut self) -> Option<Self::Item> {
         while let Some(item_opt) = self.lower.next() {
             if let Some(item) = item_opt.as_mut() {
-                return Some(&mut item.socket)
+                return Some(SocketRef::new(&mut item.socket))
             }
         }
         None