Browse Source

use provided ip for TcpSocket::connect instead of 0.0.0.0

fdb-hiroshima 5 years ago
parent
commit
c03bb50dc1
1 changed files with 17 additions and 4 deletions
  1. 17 4
      src/socket/tcp.rs

+ 17 - 4
src/socket/tcp.rs

@@ -483,9 +483,9 @@ impl<'a> TcpSocket<'a> {
         // If local address is not provided, use an unspecified address but a specified protocol.
         // This lets us lower IpRepr later to determine IP header size and calculate MSS,
         // but without committing to a specific address right away.
-        let local_addr = match remote_endpoint.addr {
-            IpAddress::Unspecified => return Err(Error::Unaddressable),
-            _ => remote_endpoint.addr.to_unspecified(),
+        let local_addr = match local_endpoint.addr {
+            IpAddress::Unspecified => remote_endpoint.addr.to_unspecified(),
+            ip => ip,
         };
         let local_endpoint = IpEndpoint { addr: local_addr, ..local_endpoint };
 
@@ -1894,6 +1894,16 @@ mod test {
         s
     }
 
+    fn socket_syn_sent_with_local_ipendpoint(local: IpEndpoint) -> TcpSocket<'static> {
+        let mut s = socket();
+        s.state           = State::SynSent;
+        s.local_endpoint  = local;
+        s.remote_endpoint = REMOTE_END;
+        s.local_seq_no    = LOCAL_SEQ;
+        s.remote_last_seq = LOCAL_SEQ;
+        s
+    }
+
     fn socket_established_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TcpSocket<'static> {
         let mut s = socket_syn_received_with_buffer_sizes(tx_len, rx_len);
         s.state           = State::Established;
@@ -2318,6 +2328,9 @@ mod test {
                    Err(Error::Unaddressable));
         assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END),
                    Err(Error::Unaddressable));
+        s.connect(REMOTE_END, LOCAL_END).expect("Connect failed with valid parameters");
+        assert_eq!(s.local_endpoint(), LOCAL_END);
+        assert_eq!(s.remote_endpoint(), REMOTE_END);
     }
 
     #[test]
@@ -2378,7 +2391,7 @@ mod test {
         let mut s = socket();
         s.local_seq_no    = LOCAL_SEQ;
         s.connect(REMOTE_END, LOCAL_END).unwrap();
-        sanity!(s, socket_syn_sent());
+        sanity!(s, socket_syn_sent_with_local_ipendpoint(LOCAL_END));
     }
 
     #[test]