Browse Source

Return specific sockets from `new` functions instead of `Socket`.

* Add Into<Socket> implementations for sockets
* Make SocketSet::add generic over Into<Socket>
Philipp Oppermann 7 years ago
parent
commit
c2d18ec071
5 changed files with 54 additions and 55 deletions
  1. 12 15
      src/socket/icmp.rs
  2. 17 25
      src/socket/raw.rs
  3. 5 1
      src/socket/set.rs
  4. 10 7
      src/socket/tcp.rs
  5. 10 7
      src/socket/udp.rs

+ 12 - 15
src/socket/icmp.rs

@@ -104,14 +104,14 @@ pub struct IcmpSocket<'a, 'b: 'a> {
 impl<'a, 'b> IcmpSocket<'a, 'b> {
     /// Create an ICMPv4 socket with the given buffers.
     pub fn new(rx_buffer: SocketBuffer<'a, 'b>,
-               tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> {
-        Socket::Icmp(IcmpSocket {
+               tx_buffer: SocketBuffer<'a, 'b>) -> IcmpSocket<'a, 'b> {
+        IcmpSocket {
             meta:      SocketMeta::default(),
             rx_buffer: rx_buffer,
             tx_buffer: tx_buffer,
             endpoint:  Endpoint::default(),
             hop_limit: None
-        })
+        }
     }
 
     /// Return the socket handle.
@@ -170,10 +170,7 @@ impl<'a, 'b> IcmpSocket<'a, 'b> {
     /// use smoltcp::socket::IcmpEndpoint;
     ///
     /// let mut icmp_socket = // ...
-    /// # match IcmpSocket::new(rx_buffer, tx_buffer) {
-    /// #     Socket::Icmp(socket) => socket,
-    /// #     _ => unreachable!()
-    /// # };
+    /// # IcmpSocket::new(rx_buffer, tx_buffer);
     ///
     /// // Bind to ICMP error responses for UDP packets sent from port 53.
     /// let endpoint = IpEndpoint::from(53);
@@ -194,10 +191,7 @@ impl<'a, 'b> IcmpSocket<'a, 'b> {
     /// use smoltcp::socket::IcmpEndpoint;
     ///
     /// let mut icmp_socket = // ...
-    /// # match IcmpSocket::new(rx_buffer, tx_buffer) {
-    /// #     Socket::Icmp(socket) => socket,
-    /// #     _ => unreachable!()
-    /// # };
+    /// # IcmpSocket::new(rx_buffer, tx_buffer);
     ///
     /// // Bind to ICMP messages with the ICMP identifier 0x1234
     /// icmp_socket.bind(IcmpEndpoint::Ident(0x1234)).unwrap();
@@ -360,6 +354,12 @@ impl<'a, 'b> IcmpSocket<'a, 'b> {
     }
 }
 
+impl<'a, 'b> Into<Socket<'a, 'b>> for IcmpSocket<'a, 'b> {
+    fn into(self) -> Socket<'a, 'b> {
+        Socket::Icmp(self)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use phy::DeviceCapabilities;
@@ -376,10 +376,7 @@ mod test {
 
     fn socket(rx_buffer: SocketBuffer<'static, 'static>,
               tx_buffer: SocketBuffer<'static, 'static>) -> IcmpSocket<'static, 'static> {
-        match IcmpSocket::new(rx_buffer, tx_buffer) {
-            Socket::Icmp(socket) => socket,
-            _ => unreachable!()
-        }
+        IcmpSocket::new(rx_buffer, tx_buffer)
     }
 
     const REMOTE_IPV4: Ipv4Address = Ipv4Address([0x7f, 0x00, 0x00, 0x02]);

+ 17 - 25
src/socket/raw.rs

@@ -73,14 +73,14 @@ impl<'a, 'b> RawSocket<'a, 'b> {
     /// with the given buffers.
     pub fn new(ip_version: IpVersion, ip_protocol: IpProtocol,
                rx_buffer: SocketBuffer<'a, 'b>,
-               tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> {
-        Socket::Raw(RawSocket {
+               tx_buffer: SocketBuffer<'a, 'b>) -> RawSocket<'a, 'b> {
+        RawSocket {
             meta: SocketMeta::default(),
             ip_version,
             ip_protocol,
             rx_buffer,
             tx_buffer,
-        })
+        }
     }
 
     /// Return the socket handle.
@@ -251,6 +251,12 @@ impl<'a, 'b> RawSocket<'a, 'b> {
     }
 }
 
+impl<'a, 'b> Into<Socket<'a, 'b>> for RawSocket<'a, 'b> {
+    fn into(self) -> Socket<'a, 'b> {
+        Socket::Raw(self)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use wire::IpRepr;
@@ -275,11 +281,8 @@ mod test {
         pub fn socket(rx_buffer: SocketBuffer<'static, 'static>,
                   tx_buffer: SocketBuffer<'static, 'static>)
                 -> RawSocket<'static, 'static> {
-            match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(IP_PROTO),
-                                 rx_buffer, tx_buffer) {
-                Socket::Raw(socket) => socket,
-                _ => unreachable!()
-            }
+            RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(IP_PROTO),
+                rx_buffer, tx_buffer)
         }
 
         pub const IP_PROTO: u8 = 63;
@@ -310,11 +313,8 @@ mod test {
         pub fn socket(rx_buffer: SocketBuffer<'static, 'static>,
                   tx_buffer: SocketBuffer<'static, 'static>)
                 -> RawSocket<'static, 'static> {
-            match RawSocket::new(IpVersion::Ipv6, IpProtocol::Unknown(IP_PROTO),
-                                 rx_buffer, tx_buffer) {
-                Socket::Raw(socket) => socket,
-                _ => unreachable!()
-            }
+            RawSocket::new(IpVersion::Ipv6, IpProtocol::Unknown(IP_PROTO),
+                                 rx_buffer, tx_buffer)
         }
 
         pub const IP_PROTO: u8 = 63;
@@ -514,24 +514,16 @@ mod test {
     fn test_doesnt_accept_wrong_proto() {
         #[cfg(feature = "proto-ipv4")]
         {
-            let socket = match RawSocket::new(IpVersion::Ipv4,
-                                              IpProtocol::Unknown(ipv4_locals::IP_PROTO+1),
-                                              buffer(1), buffer(1)) {
-                Socket::Raw(socket) => socket,
-                _ => unreachable!()
-            };
+            let socket = RawSocket::new(IpVersion::Ipv4,
+                IpProtocol::Unknown(ipv4_locals::IP_PROTO+1), buffer(1), buffer(1));
             assert!(!socket.accepts(&ipv4_locals::HEADER_REPR));
             #[cfg(feature = "proto-ipv6")]
             assert!(!socket.accepts(&ipv6_locals::HEADER_REPR));
         }
         #[cfg(feature = "proto-ipv6")]
         {
-            let socket = match RawSocket::new(IpVersion::Ipv6,
-                                              IpProtocol::Unknown(ipv6_locals::IP_PROTO+1),
-                                              buffer(1), buffer(1)) {
-                Socket::Raw(socket) => socket,
-                _ => unreachable!()
-            };
+            let socket = RawSocket::new(IpVersion::Ipv6,
+                IpProtocol::Unknown(ipv6_locals::IP_PROTO+1), buffer(1), buffer(1));
             assert!(!socket.accepts(&ipv6_locals::HEADER_REPR));
             #[cfg(feature = "proto-ipv4")]
             assert!(!socket.accepts(&ipv4_locals::HEADER_REPR));

+ 5 - 1
src/socket/set.rs

@@ -47,7 +47,9 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
     ///
     /// # Panics
     /// This function panics if the storage is fixed-size (not a `Vec`) and is full.
-    pub fn add(&mut self, socket: Socket<'b, 'c>) -> Handle {
+    pub fn add<T>(&mut self, socket: T) -> Handle
+        where T: Into<Socket<'b, 'c>>
+    {
         fn put<'b, 'c>(index: usize, slot: &mut Option<Item<'b, 'c>>,
                        mut socket: Socket<'b, 'c>) -> Handle {
             net_trace!("[{}]: adding", index);
@@ -57,6 +59,8 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
             handle
         }
 
+        let socket = socket.into();
+
         for (index, slot) in self.sockets.iter_mut().enumerate() {
             if slot.is_none() {
                 return put(index, slot, socket)

+ 10 - 7
src/socket/tcp.rs

@@ -230,7 +230,7 @@ const DEFAULT_MSS: usize = 536;
 
 impl<'a> TcpSocket<'a> {
     /// Create a socket using the given buffers.
-    pub fn new<T>(rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
+    pub fn new<T>(rx_buffer: T, tx_buffer: T) -> TcpSocket<'a>
             where T: Into<SocketBuffer<'a>> {
         let (rx_buffer, tx_buffer) = (rx_buffer.into(), tx_buffer.into());
         if rx_buffer.capacity() > <u16>::max_value() as usize {
@@ -238,7 +238,7 @@ impl<'a> TcpSocket<'a> {
                    <u16>::max_value())
         }
 
-        Socket::Tcp(TcpSocket {
+        TcpSocket {
             meta:            SocketMeta::default(),
             state:           State::Closed,
             timer:           Timer::default(),
@@ -259,7 +259,7 @@ impl<'a> TcpSocket<'a> {
             remote_win_len:  0,
             remote_mss:      DEFAULT_MSS,
             remote_last_ts:  None,
-        })
+        }
     }
 
     /// Return the socket handle.
@@ -1495,6 +1495,12 @@ impl<'a> TcpSocket<'a> {
     }
 }
 
+impl<'a> Into<Socket<'a, 'static>> for TcpSocket<'a> {
+    fn into(self) -> Socket<'a, 'static> {
+        Socket::Tcp(self)
+    }
+}
+
 impl<'a> fmt::Write for TcpSocket<'a> {
     fn write_str(&mut self, slice: &str) -> fmt::Result {
         let slice = slice.as_bytes();
@@ -1680,10 +1686,7 @@ mod test {
 
         let rx_buffer = SocketBuffer::new(vec![0; 64]);
         let tx_buffer = SocketBuffer::new(vec![0; 64]);
-        match TcpSocket::new(rx_buffer, tx_buffer) {
-            Socket::Tcp(socket) => socket,
-            _ => unreachable!()
-        }
+        TcpSocket::new(rx_buffer, tx_buffer)
     }
 
     fn socket_syn_received() -> TcpSocket<'static> {

+ 10 - 7
src/socket/udp.rs

@@ -70,14 +70,14 @@ pub struct UdpSocket<'a, 'b: 'a> {
 impl<'a, 'b> UdpSocket<'a, 'b> {
     /// Create an UDP socket with the given buffers.
     pub fn new(rx_buffer: SocketBuffer<'a, 'b>,
-               tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> {
-        Socket::Udp(UdpSocket {
+               tx_buffer: SocketBuffer<'a, 'b>) -> UdpSocket<'a, 'b> {
+        UdpSocket {
             meta:      SocketMeta::default(),
             endpoint:  IpEndpoint::default(),
             rx_buffer: rx_buffer,
             tx_buffer: tx_buffer,
             hop_limit: None
-        })
+        }
     }
 
     /// Return the socket handle.
@@ -256,6 +256,12 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 }
 
+impl<'a, 'b> Into<Socket<'a, 'b>> for UdpSocket<'a, 'b> {
+    fn into(self) -> Socket<'a, 'b> {
+        Socket::Udp(self)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use wire::{IpAddress, IpRepr, UdpRepr};
@@ -277,10 +283,7 @@ mod test {
     fn socket(rx_buffer: SocketBuffer<'static, 'static>,
               tx_buffer: SocketBuffer<'static, 'static>)
             -> UdpSocket<'static, 'static> {
-        match UdpSocket::new(rx_buffer, tx_buffer) {
-            Socket::Udp(socket) => socket,
-            _ => unreachable!()
-        }
+        UdpSocket::new(rx_buffer, tx_buffer)
     }
 
     const LOCAL_PORT:  u16        = 53;