瀏覽代碼

multicast: use a single map for both ipv4 and ipv6.

Dario Nieuwenhuis 7 月之前
父節點
當前提交
4990fb979d
共有 2 個文件被更改,包括 47 次插入50 次删除
  1. 39 38
      src/iface/interface/igmp.rs
  2. 8 12
      src/iface/interface/mod.rs

+ 39 - 38
src/iface/interface/igmp.rs

@@ -39,20 +39,22 @@ impl Interface {
     where
         D: Device + ?Sized,
     {
+        let addr = addr.into();
         self.inner.now = timestamp;
 
-        match addr.into() {
+        let is_not_new = self
+            .inner
+            .multicast_groups
+            .insert(addr, ())
+            .map_err(|_| MulticastError::GroupTableFull)?
+            .is_some();
+        if is_not_new {
+            return Ok(false);
+        }
+
+        match addr {
             IpAddress::Ipv4(addr) => {
-                let is_not_new = self
-                    .inner
-                    .ipv4_multicast_groups
-                    .insert(addr, ())
-                    .map_err(|_| MulticastError::GroupTableFull)?
-                    .is_some();
-                if is_not_new {
-                    Ok(false)
-                } else if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr)
-                {
+                if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) {
                     // Send initial membership report
                     let tx_token = device
                         .transmit(timestamp)
@@ -71,19 +73,10 @@ impl Interface {
             #[cfg(feature = "proto-ipv6")]
             IpAddress::Ipv6(addr) => {
                 // Build report packet containing this new address
-                let report_record = &[MldAddressRecordRepr::new(
+                if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new(
                     MldRecordType::ChangeToInclude,
                     addr,
-                )];
-                let is_not_new = self
-                    .inner
-                    .ipv6_multicast_groups
-                    .insert(addr, ())
-                    .map_err(|_| MulticastError::GroupTableFull)?
-                    .is_some();
-                if is_not_new {
-                    Ok(false)
-                } else if let Some(pkt) = self.inner.mldv2_report_packet(report_record) {
+                )]) {
                     // Send initial membership report
                     let tx_token = device
                         .transmit(timestamp)
@@ -117,14 +110,16 @@ impl Interface {
     where
         D: Device + ?Sized,
     {
+        let addr = addr.into();
         self.inner.now = timestamp;
+        let was_not_present = self.inner.multicast_groups.remove(&addr).is_none();
+        if was_not_present {
+            return Ok(false);
+        }
 
-        match addr.into() {
+        match addr {
             IpAddress::Ipv4(addr) => {
-                let was_not_present = self.inner.ipv4_multicast_groups.remove(&addr).is_none();
-                if was_not_present {
-                    Ok(false)
-                } else if let Some(pkt) = self.inner.igmp_leave_packet(addr) {
+                if let Some(pkt) = self.inner.igmp_leave_packet(addr) {
                     // Send group leave packet
                     let tx_token = device
                         .transmit(timestamp)
@@ -142,14 +137,10 @@ impl Interface {
             }
             #[cfg(feature = "proto-ipv6")]
             IpAddress::Ipv6(addr) => {
-                let report_record = &[MldAddressRecordRepr::new(
+                if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new(
                     MldRecordType::ChangeToExclude,
                     addr,
-                )];
-                let was_not_present = self.inner.ipv6_multicast_groups.remove(&addr).is_none();
-                if was_not_present {
-                    Ok(false)
-                } else if let Some(pkt) = self.inner.mldv2_report_packet(report_record) {
+                )]) {
                     // Send group leave packet
                     let tx_token = device
                         .transmit(timestamp)
@@ -210,10 +201,14 @@ impl Interface {
             } if self.inner.now >= timeout => {
                 let addr = self
                     .inner
-                    .ipv4_multicast_groups
+                    .multicast_groups
                     .iter()
-                    .nth(next_index)
-                    .map(|(addr, ())| *addr);
+                    .filter_map(|(addr, _)| match addr {
+                        IpAddress::Ipv4(addr) => Some(*addr),
+                        #[allow(unreachable_patterns)]
+                        _ => None,
+                    })
+                    .nth(next_index);
 
                 match addr {
                     Some(addr) => {
@@ -280,15 +275,21 @@ impl InterfaceInner {
                 if group_addr.is_unspecified()
                     && ipv4_repr.dst_addr == Ipv4Address::MULTICAST_ALL_SYSTEMS
                 {
+                    let ipv4_multicast_group_count = self
+                        .multicast_groups
+                        .keys()
+                        .filter(|a| matches!(a, IpAddress::Ipv4(_)))
+                        .count();
+
                     // Are we member in any groups?
-                    if self.ipv4_multicast_groups.iter().next().is_some() {
+                    if ipv4_multicast_group_count != 0 {
                         let interval = match version {
                             IgmpVersion::Version1 => Duration::from_millis(100),
                             IgmpVersion::Version2 => {
                                 // No dependence on a random generator
                                 // (see [#24](https://github.com/m-labs/smoltcp/issues/24))
                                 // but at least spread reports evenly across max_resp_time.
-                                let intervals = self.ipv4_multicast_groups.len() as u32 + 1;
+                                let intervals = ipv4_multicast_group_count as u32 + 1;
                                 max_resp_time / intervals
                             }
                         };

+ 8 - 12
src/iface/interface/mod.rs

@@ -111,10 +111,8 @@ pub struct InterfaceInner {
     ip_addrs: Vec<IpCidr, IFACE_MAX_ADDR_COUNT>,
     any_ip: bool,
     routes: Routes,
-    #[cfg(feature = "proto-igmp")]
-    ipv4_multicast_groups: LinearMap<Ipv4Address, (), IFACE_MAX_MULTICAST_GROUP_COUNT>,
-    #[cfg(feature = "proto-ipv6")]
-    ipv6_multicast_groups: LinearMap<Ipv6Address, (), IFACE_MAX_MULTICAST_GROUP_COUNT>,
+    #[cfg(any(feature = "proto-igmp", feature = "proto-ipv6"))]
+    multicast_groups: LinearMap<IpAddress, (), IFACE_MAX_MULTICAST_GROUP_COUNT>,
     /// When to report for (all or) the next multicast group membership via IGMP
     #[cfg(feature = "proto-igmp")]
     igmp_report_state: IgmpReportState,
@@ -226,10 +224,8 @@ impl Interface {
                 routes: Routes::new(),
                 #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))]
                 neighbor_cache: NeighborCache::new(),
-                #[cfg(feature = "proto-igmp")]
-                ipv4_multicast_groups: LinearMap::new(),
-                #[cfg(feature = "proto-ipv6")]
-                ipv6_multicast_groups: LinearMap::new(),
+                #[cfg(any(feature = "proto-igmp", feature = "proto-ipv6"))]
+                multicast_groups: LinearMap::new(),
                 #[cfg(feature = "proto-igmp")]
                 igmp_report_state: IgmpReportState::Inactive,
                 #[cfg(feature = "medium-ieee802154")]
@@ -753,17 +749,18 @@ impl InterfaceInner {
     /// If built without feature `proto-igmp` this function will
     /// always return `false` when using IPv4.
     fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
-        match addr.into() {
+        let addr = addr.into();
+        match addr {
             #[cfg(feature = "proto-igmp")]
             IpAddress::Ipv4(key) => {
                 key == Ipv4Address::MULTICAST_ALL_SYSTEMS
-                    || self.ipv4_multicast_groups.get(&key).is_some()
+                    || self.multicast_groups.get(&addr).is_some()
             }
             #[cfg(feature = "proto-ipv6")]
             IpAddress::Ipv6(key) => {
                 key == Ipv6Address::LINK_LOCAL_ALL_NODES
                     || self.has_solicited_node(key)
-                    || self.ipv6_multicast_groups.get(&key).is_some()
+                    || self.multicast_groups.get(&addr).is_some()
             }
             #[cfg(feature = "proto-rpl")]
             IpAddress::Ipv6(Ipv6Address::LINK_LOCAL_ALL_RPL_NODES) => true,
@@ -784,7 +781,6 @@ impl InterfaceInner {
             #[cfg(feature = "proto-ipv4")]
             Ok(IpVersion::Ipv4) => {
                 let ipv4_packet = check!(Ipv4Packet::new_checked(ip_payload));
-
                 self.process_ipv4(sockets, meta, HardwareAddress::Ip, &ipv4_packet, frag)
             }
             #[cfg(feature = "proto-ipv6")]