Эх сурвалжийг харах

Rework error handling in TcpSocket::connect.

whitequark 7 жил өмнө
parent
commit
53a3875452
4 өөрчлөгдсөн 48 нэмэгдсэн , 16 устгасан
  1. 1 1
      examples/server.rs
  2. 3 0
      src/lib.rs
  3. 41 12
      src/socket/tcp.rs
  4. 3 3
      src/wire/ip.rs

+ 1 - 1
examples/server.rs

@@ -63,7 +63,7 @@ fn main() {
         // udp:6969: respond "yo dawg"
         {
             let socket: &mut UdpSocket = sockets.get_mut(udp_handle).as_socket();
-            if socket.endpoint().is_unspecified() {
+            if !socket.endpoint().is_specified() {
                 socket.bind(6969)
             }
 

+ 3 - 0
src/lib.rs

@@ -101,6 +101,8 @@ pub mod socket;
 pub enum Error {
     /// An operation cannot proceed because a buffer is empty or full.
     Exhausted,
+    /// An operation is not permitted in the current state.
+    Illegal,
     /// An endpoint or address of a remote host could not be translated to a lower level address.
     /// E.g. there was no an Ethernet address corresponding to an IPv4 address in the ARP cache,
     /// or a TCP connection attempt was made to an unspecified endpoint.
@@ -136,6 +138,7 @@ impl fmt::Display for Error {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             &Error::Exhausted     => write!(f, "buffer space exhausted"),
+            &Error::Illegal       => write!(f, "illegal operation"),
             &Error::Unaddressable => write!(f, "unaddressable destination"),
             &Error::Truncated     => write!(f, "truncated packet"),
             &Error::Checksum      => write!(f, "checksum error"),

+ 41 - 12
src/socket/tcp.rs

@@ -379,20 +379,20 @@ impl<'a> TcpSocket<'a> {
     /// This function returns an error if the socket was open; see [is_open](#method.is_open).
     /// It also returns an error if the local or remote port is zero, or if the remote address
     /// is unspecified.
-    pub fn connect<T, U>(&mut self, remote_endpoint: T, local_endpoint: U) -> Result<(), ()>
+    pub fn connect<T, U>(&mut self, remote_endpoint: T, local_endpoint: U) -> Result<(), Error>
             where T: Into<IpEndpoint>, U: Into<IpEndpoint> {
         let remote_endpoint = remote_endpoint.into();
         let local_endpoint  = local_endpoint.into();
 
-        if self.is_open() { return Err(()) }
-        if remote_endpoint.port == 0 { return Err(()) }
-        if local_endpoint.port == 0 { return Err(()) }
+        if self.is_open() { return Err(Error::Illegal) }
+        if !remote_endpoint.is_specified() { return Err(Error::Unaddressable) }
+        if local_endpoint.port == 0 { return Err(Error::Unaddressable) }
 
         // If local address is not provided, use an unspecified address but a specified protocol.
         // This lets us lower IpRepr later to determine IP header size and calculate MSS,
         // but without committing to a specific address right away.
         let local_addr = match remote_endpoint.addr {
-            IpAddress::Unspecified => return Err(()),
+            IpAddress::Unspecified => return Err(Error::Unaddressable),
             _ => remote_endpoint.addr.to_unspecified(),
         };
         let local_endpoint = IpEndpoint { addr: local_addr, ..local_endpoint };
@@ -982,7 +982,7 @@ impl<'a> TcpSocket<'a> {
     pub(crate) fn dispatch<F, R>(&mut self, timestamp: u64, limits: &DeviceLimits,
                                  emit: &mut F) -> Result<R, Error>
             where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
-        if self.remote_endpoint.is_unspecified() { return Err(Error::Exhausted) }
+        if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) }
 
         let mut repr = TcpRepr {
             src_port:     self.local_endpoint.port,
@@ -1585,12 +1585,14 @@ mod test {
     #[test]
     fn test_connect_validation() {
         let mut s = socket();
-        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), Ok(()));
-        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(0, 0, 0, 0), 80)), Err(()));
-        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 0), 0)), Err(()));
-        assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), Err(()));
-        assert_eq!(s.connect((IpAddress::v4(0, 0, 0, 0), 80), LOCAL_END), Err(()));
-        assert_eq!(s.connect((IpAddress::v4(10, 0, 0, 0), 0), LOCAL_END), Err(()));
+        assert_eq!(s.connect((IpAddress::v4(0, 0, 0, 0), 80), LOCAL_END),
+                   Err(Error::Unaddressable));
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 0), 0)),
+                   Err(Error::Unaddressable));
+        assert_eq!(s.connect((IpAddress::v4(10, 0, 0, 0), 0), LOCAL_END),
+                   Err(Error::Unaddressable));
+        assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END),
+                   Err(Error::Unaddressable));
     }
 
     #[test]
@@ -1616,6 +1618,33 @@ mod test {
         assert_eq!(s.local_endpoint, LOCAL_END);
     }
 
+    #[test]
+    fn test_connect_unspecified_local() {
+        let mut s = socket();
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(0, 0, 0, 0), 80)),
+                   Ok(()));
+        s.abort();
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)),
+                   Ok(()));
+        s.abort();
+    }
+
+    #[test]
+    fn test_connect_specified_local() {
+        let mut s = socket();
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 2), 80)),
+                   Ok(()));
+    }
+
+    #[test]
+    fn test_connect_twice() {
+        let mut s = socket();
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)),
+                   Ok(()));
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)),
+                   Err(Error::Illegal));
+    }
+
     #[test]
     fn test_syn_sent_sanity() {
         let mut s = socket();

+ 3 - 3
src/wire/ip.rs

@@ -120,9 +120,9 @@ impl Endpoint {
         Endpoint { addr: addr, port: port }
     }
 
-    /// Query whether the endpoint has an unspecified address.
-    pub fn is_unspecified(&self) -> bool {
-        self.addr.is_unspecified()
+    /// Query whether the endpoint has a specified address and port.
+    pub fn is_specified(&self) -> bool {
+        !self.addr.is_unspecified() && self.port != 0
     }
 }