Selaa lähdekoodia

Fix determination of local address from incoming packets.

We've advertised this capability before in examples, but it did not
actually work.
whitequark 7 vuotta sitten
vanhempi
commit
0904645c1b
3 muutettua tiedostoa jossa 48 lisäystä ja 6 poistoa
  1. 1 1
      examples/loopback.rs
  2. 38 5
      src/socket/tcp.rs
  3. 9 0
      src/wire/ip.rs

+ 1 - 1
examples/loopback.rs

@@ -149,7 +149,7 @@ fn main() {
             if !socket.is_open() {
                 if !did_connect {
                     socket.connect((IpAddress::v4(127, 0, 0, 1), 1234),
-                                   (IpAddress::v4(127, 0, 0, 1), 65000)).unwrap();
+                                   (IpAddress::Unspecified, 65000)).unwrap();
                     did_connect = true;
                 }
             }

+ 38 - 5
src/socket/tcp.rs

@@ -368,7 +368,7 @@ impl<'a> TcpSocket<'a> {
     /// Connect to a given endpoint.
     ///
     /// The local port must be provided explicitly. Assuming `fn get_ephemeral_port() -> u16`
-    /// allocates a port from the 49152 to 65535 range, a connection may be established as follows:
+    /// allocates a port between 49152 and 65535, a connection may be established as follows:
     ///
     /// ```rust,ignore
     /// socket.connect((IpAddress::v4(10, 0, 0, 1), 80), get_ephemeral_port())
@@ -386,9 +386,16 @@ impl<'a> TcpSocket<'a> {
 
         if self.is_open() { return Err(()) }
         if remote_endpoint.port == 0 { return Err(()) }
-        if remote_endpoint.addr.is_unspecified() { return Err(()) }
         if local_endpoint.port == 0 { return Err(()) }
-        if local_endpoint.addr.is_unspecified() { return Err(()) }
+
+        // If local address is not provided, use an unspecified address but a specified protocol.
+        // This lets us lower IpRepr later to determine IP header size and calculate MSS,
+        // but without committing to a specific address right away.
+        let local_addr = match remote_endpoint.addr {
+            IpAddress::Unspecified => return Err(()),
+            _ => remote_endpoint.addr.as_unspecified(),
+        };
+        let local_endpoint = IpEndpoint { addr: local_addr, ..local_endpoint };
 
         // Carry over the local sequence number.
         let local_seq_no = self.local_seq_no;
@@ -852,6 +859,7 @@ impl<'a> TcpSocket<'a> {
             }) => {
                 net_trace!("[{}]{}:{}: received SYN|ACK",
                            self.debug_id, self.local_endpoint, self.remote_endpoint);
+                self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
                 self.remote_last_seq = self.local_seq_no + 1;
                 self.remote_seq_no   = seq_number + 1;
                 self.remote_last_ack = seq_number;
@@ -1269,6 +1277,8 @@ mod test {
         let mut limits = DeviceLimits::default();
         limits.max_transmission_unit = 1520;
         let result = socket.dispatch(timestamp, &limits, &mut |ip_repr, payload| {
+            let ip_repr = ip_repr.lower(&[LOCAL_END.addr.into()]).unwrap();
+
             assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
             assert_eq!(ip_repr.src_addr(), LOCAL_IP);
             assert_eq!(ip_repr.dst_addr(), REMOTE_IP);
@@ -1533,7 +1543,7 @@ mod test {
     fn socket_syn_sent() -> TcpSocket<'static> {
         let mut s = socket();
         s.state           = State::SynSent;
-        s.local_endpoint  = LOCAL_END;
+        s.local_endpoint  = IpEndpoint::new(IpAddress::v4(0, 0, 0, 0), LOCAL_PORT);
         s.remote_endpoint = REMOTE_END;
         s.local_seq_no    = LOCAL_SEQ;
         s
@@ -1542,7 +1552,7 @@ mod test {
     #[test]
     fn test_connect_validation() {
         let mut s = socket();
-        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), Err(()));
+        assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), Ok(()));
         assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(0, 0, 0, 0), 80)), Err(()));
         assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 0), 0)), Err(()));
         assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), Err(()));
@@ -1550,6 +1560,29 @@ mod test {
         assert_eq!(s.connect((IpAddress::v4(10, 0, 0, 0), 0), LOCAL_END), Err(()));
     }
 
+    #[test]
+    fn test_connect() {
+        let mut s = socket();
+        s.local_seq_no = LOCAL_SEQ;
+        s.connect(REMOTE_END, LOCAL_END.port).unwrap();
+        assert_eq!(s.local_endpoint, IpEndpoint::new(IpAddress::v4(0, 0, 0, 0), LOCAL_END.port));
+        recv!(s, [TcpRepr {
+            control:    TcpControl::Syn,
+            seq_number: LOCAL_SEQ,
+            ack_number: None,
+            max_seg_size: Some(1480),
+            ..RECV_TEMPL
+        }]);
+        send!(s, TcpRepr {
+            control:    TcpControl::Syn,
+            seq_number: REMOTE_SEQ,
+            ack_number: Some(LOCAL_SEQ + 1),
+            max_seg_size: Some(1400),
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.local_endpoint, LOCAL_END);
+    }
+
     #[test]
     fn test_syn_sent_sanity() {
         let mut s = socket();

+ 9 - 0
src/wire/ip.rs

@@ -71,6 +71,15 @@ impl Address {
             &Address::Ipv4(addr)  => addr.is_unspecified()
         }
     }
+
+    /// Return an unspecified address that has the same IP version as `self`.
+    pub fn as_unspecified(&self) -> Address {
+        match self {
+            &Address::Unspecified => Address::Unspecified,
+            // &Address::Ipv4 => Address::Ipv4(Ipv4Address::UNSPECIFIED),
+            &Address::Ipv4(_) => Address::Ipv4(Ipv4Address(/*FIXME*/[0x00; 4])),
+        }
+    }
 }
 
 impl Default for Address {