Browse Source

tcp: return own error enums for public API.

Dario Nieuwenhuis 2 years ago
parent
commit
591b789d1e
1 changed files with 71 additions and 40 deletions
  1. 71 40
      src/socket/tcp.rs

+ 71 - 40
src/socket/tcp.rs

@@ -16,12 +16,43 @@ use crate::wire::{
     IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, TcpControl, TcpRepr, TcpSeqNumber,
     TCP_HEADER_LEN,
 };
-use crate::{Error, Result};
+use crate::Error;
 
 macro_rules! tcp_trace {
     ($($arg:expr),*) => (net_log!(trace, $($arg),*));
 }
 
+/// Error returned by [`Socket::listen`]
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum ListenError {
+    InvalidState,
+    Unaddressable,
+}
+
+/// Error returned by [`Socket::connect`]
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum ConnectError {
+    InvalidState,
+    Unaddressable,
+}
+
+/// Error returned by [`Socket::send`]
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum SendError {
+    InvalidState,
+}
+
+/// Error returned by [`Socket::recv`]
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum RecvError {
+    InvalidState,
+    Finished,
+}
+
 /// A TCP socket ring buffer.
 pub type SocketBuffer<'a> = RingBuffer<'a, u8>;
 
@@ -676,17 +707,17 @@ impl<'a> Socket<'a> {
     /// This function returns `Err(Error::Illegal)` if the socket was already open
     /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
     /// if the port in the given endpoint is zero.
-    pub fn listen<T>(&mut self, local_endpoint: T) -> Result<()>
+    pub fn listen<T>(&mut self, local_endpoint: T) -> Result<(), ListenError>
     where
         T: Into<IpListenEndpoint>,
     {
         let local_endpoint = local_endpoint.into();
         if local_endpoint.port == 0 {
-            return Err(Error::Unaddressable);
+            return Err(ListenError::Unaddressable);
         }
 
         if self.is_open() {
-            return Err(Error::Illegal);
+            return Err(ListenError::InvalidState);
         }
 
         self.reset();
@@ -715,7 +746,7 @@ impl<'a> Socket<'a> {
         cx: &mut Context,
         remote_endpoint: T,
         local_endpoint: U,
-    ) -> Result<()>
+    ) -> Result<(), ConnectError>
     where
         T: Into<IpEndpoint>,
         U: Into<IpListenEndpoint>,
@@ -724,13 +755,13 @@ impl<'a> Socket<'a> {
         let local_endpoint: IpListenEndpoint = local_endpoint.into();
 
         if self.is_open() {
-            return Err(Error::Illegal);
+            return Err(ConnectError::InvalidState);
         }
         if remote_endpoint.port == 0 || remote_endpoint.addr.is_unspecified() {
-            return Err(Error::Unaddressable);
+            return Err(ConnectError::Unaddressable);
         }
         if local_endpoint.port == 0 {
-            return Err(Error::Unaddressable);
+            return Err(ConnectError::Unaddressable);
         }
 
         // If local address is not provided, choose it automatically.
@@ -738,19 +769,19 @@ impl<'a> Socket<'a> {
             addr: match local_endpoint.addr {
                 Some(addr) => {
                     if addr.is_unspecified() {
-                        return Err(Error::Unaddressable);
+                        return Err(ConnectError::Unaddressable);
                     }
                     addr
                 }
                 None => cx
                     .get_source_address(remote_endpoint.addr)
-                    .ok_or(Error::Unaddressable)?,
+                    .ok_or(ConnectError::Unaddressable)?,
             },
             port: local_endpoint.port,
         };
 
         if local_endpoint.addr.version() != remote_endpoint.addr.version() {
-            return Err(Error::Illegal);
+            return Err(ConnectError::Unaddressable);
         }
 
         self.reset();
@@ -940,12 +971,12 @@ impl<'a> Socket<'a> {
         !self.rx_buffer.is_empty()
     }
 
-    fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
+    fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result<R, SendError>
     where
         F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R),
     {
         if !self.may_send() {
-            return Err(Error::Illegal);
+            return Err(SendError::InvalidState);
         }
 
         // The connection might have been idle for a long time, and so remote_last_ts
@@ -973,7 +1004,7 @@ impl<'a> Socket<'a> {
     ///
     /// This function returns `Err(Error::Illegal)` if the transmit half of
     /// the connection is not open; see [may_send](#method.may_send).
-    pub fn send<'b, F, R>(&'b mut self, f: F) -> Result<R>
+    pub fn send<'b, F, R>(&'b mut self, f: F) -> Result<R, SendError>
     where
         F: FnOnce(&'b mut [u8]) -> (usize, R),
     {
@@ -986,28 +1017,28 @@ impl<'a> Socket<'a> {
     /// by the amount of free space in the transmit buffer; down to zero.
     ///
     /// See also [send](#method.send).
-    pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
+    pub fn send_slice(&mut self, data: &[u8]) -> Result<usize, SendError> {
         self.send_impl(|tx_buffer| {
             let size = tx_buffer.enqueue_slice(data);
             (size, size)
         })
     }
 
-    fn recv_error_check(&mut self) -> Result<()> {
+    fn recv_error_check(&mut self) -> Result<(), RecvError> {
         // We may have received some data inside the initial SYN, but until the connection
         // is fully open we must not dequeue any data, as it may be overwritten by e.g.
         // another (stale) SYN. (We do not support TCP Fast Open.)
         if !self.may_recv() {
             if self.rx_fin_received {
-                return Err(Error::Finished);
+                return Err(RecvError::Finished);
             }
-            return Err(Error::Illegal);
+            return Err(RecvError::InvalidState);
         }
 
         Ok(())
     }
 
-    fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
+    fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R, RecvError>
     where
         F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R),
     {
@@ -1037,7 +1068,7 @@ impl<'a> Socket<'a> {
     ///
     /// In all other cases, `Err(Error::Illegal)` is returned and previously received data (if any)
     /// may be incomplete (truncated).
-    pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R>
+    pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R, RecvError>
     where
         F: FnOnce(&'b mut [u8]) -> (usize, R),
     {
@@ -1050,7 +1081,7 @@ impl<'a> Socket<'a> {
     /// by the amount of occupied space in the receive buffer; down to zero.
     ///
     /// See also [recv](#method.recv).
-    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
+    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
         self.recv_impl(|rx_buffer| {
             let size = rx_buffer.dequeue_slice(data);
             (size, size)
@@ -1061,7 +1092,7 @@ impl<'a> Socket<'a> {
     /// the receive buffer, and return a pointer to it.
     ///
     /// This function otherwise behaves identically to [recv](#method.recv).
-    pub fn peek(&mut self, size: usize) -> Result<&[u8]> {
+    pub fn peek(&mut self, size: usize) -> Result<&[u8], RecvError> {
         self.recv_error_check()?;
 
         let buffer = self.rx_buffer.get_allocated(0, size);
@@ -1076,7 +1107,7 @@ impl<'a> Socket<'a> {
     /// the receive buffer, and fill a slice from it.
     ///
     /// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
-    pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize> {
+    pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
         let buffer = self.peek(data.len())?;
         let data = &mut data[..buffer.len()];
         data.copy_from_slice(buffer);
@@ -1262,7 +1293,7 @@ impl<'a> Socket<'a> {
         cx: &mut Context,
         ip_repr: &IpRepr,
         repr: &TcpRepr,
-    ) -> Result<Option<(IpRepr, TcpRepr<'static>)>> {
+    ) -> Result<Option<(IpRepr, TcpRepr<'static>)>, Error> {
         debug_assert!(self.accepts(cx, ip_repr, repr));
 
         // Consider how much the sequence number space differs from the transmit buffer space.
@@ -1903,9 +1934,9 @@ impl<'a> Socket<'a> {
         }
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
+    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
     where
-        F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<()>,
+        F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), Error>,
     {
         if self.tuple.is_none() {
             return Err(Error::Exhausted);
@@ -2365,7 +2396,7 @@ mod test {
         socket: &mut TestSocket,
         timestamp: Instant,
         repr: &TcpRepr,
-    ) -> Result<Option<TcpRepr<'static>>> {
+    ) -> Result<Option<TcpRepr<'static>>, Error> {
         socket.cx.set_now(timestamp);
 
         let ip_repr = IpReprIpvX(IpvXRepr {
@@ -2391,7 +2422,7 @@ mod test {
 
     fn recv<F>(socket: &mut TestSocket, timestamp: Instant, mut f: F)
     where
-        F: FnMut(Result<TcpRepr>),
+        F: FnMut(Result<TcpRepr, Error>),
     {
         socket.cx.set_now(timestamp);
 
@@ -2736,14 +2767,14 @@ mod test {
     #[test]
     fn test_listen_validation() {
         let mut s = socket();
-        assert_eq!(s.listen(0), Err(Error::Unaddressable));
+        assert_eq!(s.listen(0), Err(ListenError::Unaddressable));
     }
 
     #[test]
     fn test_listen_twice() {
         let mut s = socket();
         assert_eq!(s.listen(80), Ok(()));
-        assert_eq!(s.listen(80), Err(Error::Illegal));
+        assert_eq!(s.listen(80), Err(ListenError::InvalidState));
     }
 
     #[test]
@@ -3052,17 +3083,17 @@ mod test {
         assert_eq!(
             s.socket
                 .connect(&mut s.cx, REMOTE_END, (IpvXAddress::UNSPECIFIED, 0)),
-            Err(Error::Unaddressable)
+            Err(ConnectError::Unaddressable)
         );
         assert_eq!(
             s.socket
                 .connect(&mut s.cx, REMOTE_END, (IpvXAddress::UNSPECIFIED, 1024)),
-            Err(Error::Unaddressable)
+            Err(ConnectError::Unaddressable)
         );
         assert_eq!(
             s.socket
                 .connect(&mut s.cx, (IpvXAddress::UNSPECIFIED, 0), LOCAL_END),
-            Err(Error::Unaddressable)
+            Err(ConnectError::Unaddressable)
         );
         s.socket
             .connect(&mut s.cx, REMOTE_END, LOCAL_END)
@@ -3125,7 +3156,7 @@ mod test {
         assert_eq!(s.socket.connect(&mut s.cx, REMOTE_END, 80), Ok(()));
         assert_eq!(
             s.socket.connect(&mut s.cx, REMOTE_END, 80),
-            Err(Error::Illegal)
+            Err(ConnectError::InvalidState)
         );
     }
 
@@ -6297,7 +6328,7 @@ mod test {
             (3, ())
         })
         .unwrap();
-        assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished));
+        assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished));
     }
 
     #[test]
@@ -6319,7 +6350,7 @@ mod test {
             (3, ())
         })
         .unwrap();
-        assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished));
+        assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished));
     }
 
     #[test]
@@ -6341,7 +6372,7 @@ mod test {
             (3, ())
         })
         .unwrap();
-        assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished));
+        assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished));
     }
 
     #[test]
@@ -6393,7 +6424,7 @@ mod test {
         );
         // Error must be `Illegal` even if we've received a FIN,
         // because we are missing data.
-        assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal));
+        assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState));
     }
 
     #[test]
@@ -6422,7 +6453,7 @@ mod test {
             (3, ())
         })
         .unwrap();
-        assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal));
+        assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState));
     }
 
     #[test]
@@ -6466,7 +6497,7 @@ mod test {
             (3, ())
         })
         .unwrap();
-        assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal));
+        assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState));
     }
 
     // =========================================================================================//