Browse Source

tcp: immediately choose source address in connect().

Dario Nieuwenhuis 3 years ago
parent
commit
52628e2d4e
3 changed files with 41 additions and 19 deletions
  1. 21 1
      src/iface/interface.rs
  2. 9 18
      src/socket/tcp.rs
  3. 11 0
      src/wire/ip.rs

+ 21 - 1
src/iface/interface.rs

@@ -1055,6 +1055,18 @@ impl<'a> InterfaceInner<'a> {
         &mut self.rand
     }
 
+    #[allow(unused)] // unused depending on which sockets are enabled
+    pub(crate) fn get_source_address(&mut self, dst_addr: IpAddress) -> Option<IpAddress> {
+        let v = dst_addr.version().unwrap();
+        for cidr in self.ip_addrs.iter() {
+            let addr = cidr.address();
+            if addr.version() == Some(v) {
+                return Some(addr);
+            }
+        }
+        None
+    }
+
     #[cfg(test)]
     pub(crate) fn mock() -> Self {
         Self {
@@ -1080,7 +1092,15 @@ impl<'a> InterfaceInner<'a> {
             },
             now: Instant::from_millis_const(0),
 
-            ip_addrs: ManagedSlice::Owned(vec![]),
+            ip_addrs: ManagedSlice::Owned(vec![
+                #[cfg(feature = "proto-ipv4")]
+                IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address::new(192, 168, 1, 1), 24)),
+                #[cfg(feature = "proto-ipv6")]
+                IpCidr::Ipv6(Ipv6Cidr::new(
+                    Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
+                    64,
+                )),
+            ]),
             rand: Rand::new(1234),
             routes: Routes::new(&mut [][..]),
 

+ 9 - 18
src/socket/tcp.rs

@@ -718,8 +718,8 @@ impl<'a> TcpSocket<'a> {
         T: Into<IpEndpoint>,
         U: Into<IpEndpoint>,
     {
-        let remote_endpoint = remote_endpoint.into();
-        let local_endpoint = local_endpoint.into();
+        let remote_endpoint: IpEndpoint = remote_endpoint.into();
+        let mut local_endpoint: IpEndpoint = local_endpoint.into();
 
         if self.is_open() {
             return Err(Error::Illegal);
@@ -731,17 +731,12 @@ impl<'a> TcpSocket<'a> {
             return Err(Error::Unaddressable);
         }
 
-        // If local address is not provided, use an unspecified address but a specified protocol.
-        // This lets us lower IpRepr later to determine IP header size and calculate MSS,
-        // but without committing to a specific address right away.
-        let local_addr = match local_endpoint.addr {
-            IpAddress::Unspecified => remote_endpoint.addr.as_unspecified(),
-            ip => ip,
-        };
-        let local_endpoint = IpEndpoint {
-            addr: local_addr,
-            ..local_endpoint
-        };
+        // If local address is not provided, choose it automatically.
+        if local_endpoint.addr.is_unspecified() {
+            local_endpoint.addr = cx
+                .get_source_address(remote_endpoint.addr)
+                .ok_or(Error::Unaddressable)?;
+        }
 
         self.reset();
         self.local_endpoint = local_endpoint;
@@ -1626,7 +1621,6 @@ impl<'a> TcpSocket<'a> {
                     self.remote_mss = max_seg_size as usize;
                 }
 
-                self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
                 self.remote_seq_no = repr.seq_number + 1;
                 self.remote_last_seq = self.local_seq_no + 1;
                 self.remote_last_ack = Some(repr.seq_number);
@@ -3256,10 +3250,7 @@ mod test {
         s.socket
             .connect(&mut s.cx, REMOTE_END, LOCAL_END.port)
             .unwrap();
-        assert_eq!(
-            s.local_endpoint,
-            IpEndpoint::new(MOCK_UNSPECIFIED, LOCAL_END.port)
-        );
+        assert_eq!(s.local_endpoint, LOCAL_END);
         recv!(
             s,
             [TcpRepr {

+ 11 - 0
src/wire/ip.rs

@@ -111,6 +111,17 @@ impl Address {
         Address::Ipv6(Ipv6Address::new(a0, a1, a2, a3, a4, a5, a6, a7))
     }
 
+    /// Return the protocol version.
+    pub fn version(&self) -> Option<Version> {
+        match self {
+            Address::Unspecified => None,
+            #[cfg(feature = "proto-ipv4")]
+            Address::Ipv4(_) => Some(Version::Ipv4),
+            #[cfg(feature = "proto-ipv6")]
+            Address::Ipv6(_) => Some(Version::Ipv6),
+        }
+    }
+
     /// Return an address as a sequence of octets, in big-endian.
     pub fn as_bytes(&self) -> &[u8] {
         match *self {