Преглед изворни кода

ip: pass address by ref for get_source_address_ip

Thibaut Vandervelden пре 1 година
родитељ
комит
434f7eb37f
6 измењених фајлова са 51 додато и 22 уклоњено
  1. 10 10
      src/iface/interface/mod.rs
  2. 36 7
      src/iface/interface/tests/ipv6.rs
  3. 1 1
      src/socket/dns.rs
  4. 2 2
      src/socket/icmp.rs
  5. 1 1
      src/socket/tcp.rs
  6. 1 1
      src/socket/udp.rs

+ 10 - 10
src/iface/interface/mod.rs

@@ -468,21 +468,21 @@ impl Interface {
     /// the first IPv4 address from the list of addresses. For IPv6, the address is based on the
     /// destination address and uses RFC6724 for selecting the source address.
     pub fn get_source_address(&self, dst_addr: &IpAddress) -> Option<IpAddress> {
-        self.inner.get_source_address(*dst_addr)
+        self.inner.get_source_address(dst_addr)
     }
 
     /// Get an address from the interface that could be used as source address. This is the first
     /// IPv4 address from the list of addresses in the interface.
     #[cfg(feature = "proto-ipv4")]
     pub fn get_source_address_ipv4(&self, dst_addr: &Ipv4Address) -> Option<Ipv4Address> {
-        self.inner.get_source_address_ipv4(*dst_addr)
+        self.inner.get_source_address_ipv4(dst_addr)
     }
 
     /// Get an address from the interface that could be used as source address. The selection is
     /// based on RFC6724.
     #[cfg(feature = "proto-ipv6")]
     pub fn get_source_address_ipv6(&self, dst_addr: &Ipv6Address) -> Option<Ipv6Address> {
-        self.inner.get_source_address_ipv6(*dst_addr)
+        self.inner.get_source_address_ipv6(dst_addr)
     }
 
     /// Update the IP addresses of the interface.
@@ -948,7 +948,7 @@ impl InterfaceInner {
     }
 
     #[allow(unused)] // unused depending on which sockets are enabled
-    pub(crate) fn get_source_address(&self, dst_addr: IpAddress) -> Option<IpAddress> {
+    pub(crate) fn get_source_address(&self, dst_addr: &IpAddress) -> Option<IpAddress> {
         match dst_addr {
             #[cfg(feature = "proto-ipv4")]
             IpAddress::Ipv4(addr) => self.get_source_address_ipv4(addr).map(|a| a.into()),
@@ -959,7 +959,7 @@ impl InterfaceInner {
 
     #[cfg(feature = "proto-ipv4")]
     #[allow(unused)]
-    pub(crate) fn get_source_address_ipv4(&self, _dst_addr: Ipv4Address) -> Option<Ipv4Address> {
+    pub(crate) fn get_source_address_ipv4(&self, _dst_addr: &Ipv4Address) -> Option<Ipv4Address> {
         for cidr in self.ip_addrs.iter() {
             #[allow(irrefutable_let_patterns)] // if only ipv4 is enabled
             if let IpCidr::Ipv4(cidr) = cidr {
@@ -971,7 +971,7 @@ impl InterfaceInner {
 
     #[cfg(feature = "proto-ipv6")]
     #[allow(unused)]
-    pub(crate) fn get_source_address_ipv6(&self, dst_addr: Ipv6Address) -> Option<Ipv6Address> {
+    pub(crate) fn get_source_address_ipv6(&self, dst_addr: &Ipv6Address) -> Option<Ipv6Address> {
         // RFC 6724 describes how to select the correct source address depending on the destination
         // address.
 
@@ -1029,7 +1029,7 @@ impl InterfaceInner {
                 #[cfg(feature = "proto-ipv6")]
                 IpCidr::Ipv6(a) => Some(a),
             })
-            .find(|a| is_candidate_source_address(&dst_addr, &a.address()))
+            .find(|a| is_candidate_source_address(dst_addr, &a.address()))
             .unwrap();
 
         for addr in self.ip_addrs.iter().filter_map(|a| match a {
@@ -1038,12 +1038,12 @@ impl InterfaceInner {
             #[cfg(feature = "proto-ipv6")]
             IpCidr::Ipv6(a) => Some(a),
         }) {
-            if !is_candidate_source_address(&dst_addr, &addr.address()) {
+            if !is_candidate_source_address(dst_addr, &addr.address()) {
                 continue;
             }
 
             // Rule 1: prefer the address that is the same as the output destination address.
-            if candidate.address() != dst_addr && addr.address() == dst_addr {
+            if candidate.address() != *dst_addr && addr.address() == *dst_addr {
                 candidate = addr;
             }
 
@@ -1063,7 +1063,7 @@ impl InterfaceInner {
             // Rule 6: prefer matching label (TODO)
             // Rule 7: prefer temporary addresses (TODO)
             // Rule 8: use longest matching prefix
-            if common_prefix_length(candidate, &dst_addr) < common_prefix_length(addr, &dst_addr) {
+            if common_prefix_length(candidate, dst_addr) < common_prefix_length(addr, dst_addr) {
                 candidate = addr;
             }
         }

+ 36 - 7
src/iface/interface/tests/ipv6.rs

@@ -797,33 +797,62 @@ fn get_source_address() {
         Ipv6Address::new(0x2001, 0x0db9, 0x0003, 0, 0, 0, 0, 2);
 
     assert_eq!(
-        iface.inner.get_source_address_ipv6(LINK_LOCAL_ADDR),
+        iface.inner.get_source_address_ipv6(&LINK_LOCAL_ADDR),
         Some(OWN_LINK_LOCAL_ADDR)
     );
     assert_eq!(
-        iface.inner.get_source_address_ipv6(UNIQUE_LOCAL_ADDR1),
+        iface.inner.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR1),
         Some(OWN_UNIQUE_LOCAL_ADDR1)
     );
     assert_eq!(
-        iface.inner.get_source_address_ipv6(UNIQUE_LOCAL_ADDR2),
+        iface.inner.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR2),
         Some(OWN_UNIQUE_LOCAL_ADDR2)
     );
     assert_eq!(
-        iface.inner.get_source_address_ipv6(UNIQUE_LOCAL_ADDR3),
+        iface.inner.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR3),
         Some(OWN_UNIQUE_LOCAL_ADDR1)
     );
     assert_eq!(
         iface
             .inner
-            .get_source_address_ipv6(Ipv6Address::LINK_LOCAL_ALL_NODES),
+            .get_source_address_ipv6(&Ipv6Address::LINK_LOCAL_ALL_NODES),
         Some(OWN_LINK_LOCAL_ADDR)
     );
     assert_eq!(
-        iface.inner.get_source_address_ipv6(GLOBAL_UNICAST_ADDR1),
+        iface.inner.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR1),
         Some(OWN_GLOBAL_UNICAST_ADDR1)
     );
     assert_eq!(
-        iface.inner.get_source_address_ipv6(GLOBAL_UNICAST_ADDR2),
+        iface.inner.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR2),
+        Some(OWN_GLOBAL_UNICAST_ADDR1)
+    );
+
+    assert_eq!(
+        iface.get_source_address_ipv6(&LINK_LOCAL_ADDR),
+        Some(OWN_LINK_LOCAL_ADDR)
+    );
+    assert_eq!(
+        iface.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR1),
+        Some(OWN_UNIQUE_LOCAL_ADDR1)
+    );
+    assert_eq!(
+        iface.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR2),
+        Some(OWN_UNIQUE_LOCAL_ADDR2)
+    );
+    assert_eq!(
+        iface.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR3),
+        Some(OWN_UNIQUE_LOCAL_ADDR1)
+    );
+    assert_eq!(
+        iface.get_source_address_ipv6(&Ipv6Address::LINK_LOCAL_ALL_NODES),
+        Some(OWN_LINK_LOCAL_ADDR)
+    );
+    assert_eq!(
+        iface.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR1),
+        Some(OWN_GLOBAL_UNICAST_ADDR1)
+    );
+    assert_eq!(
+        iface.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR2),
         Some(OWN_GLOBAL_UNICAST_ADDR1)
     );
 }

+ 1 - 1
src/socket/dns.rs

@@ -610,7 +610,7 @@ impl<'a> Socket<'a> {
                 };
 
                 let dst_addr = servers[pq.server_idx];
-                let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
+                let src_addr = cx.get_source_address(&dst_addr).unwrap(); // TODO remove unwrap
                 let ip_repr = IpRepr::new(
                     src_addr,
                     dst_addr,

+ 2 - 2
src/socket/icmp.rs

@@ -539,7 +539,7 @@ impl<'a> Socket<'a> {
             match *remote_endpoint {
                 #[cfg(feature = "proto-ipv4")]
                 IpAddress::Ipv4(dst_addr) => {
-                    let src_addr = match cx.get_source_address_ipv4(dst_addr) {
+                    let src_addr = match cx.get_source_address_ipv4(&dst_addr) {
                         Some(addr) => addr,
                         None => {
                             net_trace!(
@@ -571,7 +571,7 @@ impl<'a> Socket<'a> {
                 }
                 #[cfg(feature = "proto-ipv6")]
                 IpAddress::Ipv6(dst_addr) => {
-                    let src_addr = match cx.get_source_address_ipv6(dst_addr) {
+                    let src_addr = match cx.get_source_address_ipv6(&dst_addr) {
                         Some(addr) => addr,
                         None => {
                             net_trace!(

+ 1 - 1
src/socket/tcp.rs

@@ -851,7 +851,7 @@ impl<'a> Socket<'a> {
                     addr
                 }
                 None => cx
-                    .get_source_address(remote_endpoint.addr)
+                    .get_source_address(&remote_endpoint.addr)
                     .ok_or(ConnectError::Unaddressable)?,
             },
             port: local_endpoint.port,

+ 1 - 1
src/socket/udp.rs

@@ -519,7 +519,7 @@ impl<'a> Socket<'a> {
         let res = self.tx_buffer.dequeue_with(|packet_meta, payload_buf| {
             let src_addr = match endpoint.addr {
                 Some(addr) => addr,
-                None => match cx.get_source_address(packet_meta.endpoint.addr) {
+                None => match cx.get_source_address(&packet_meta.endpoint.addr) {
                     Some(addr) => addr,
                     None => {
                         net_trace!(