瀏覽代碼

Merge pull request #985 from smoltcp-rs/multicast-no-device

iface: do not require device and timestamp for multicast join/leave.
Thibaut Vandervelden 6 月之前
父節點
當前提交
649bce4012

+ 1 - 5
examples/multicast.rs

@@ -82,11 +82,7 @@ fn main() {
 
     // Join a multicast group to receive mDNS traffic
     iface
-        .join_multicast_group(
-            &mut device,
-            Ipv4Address::from_bytes(&MDNS_GROUP),
-            Instant::now(),
-        )
+        .join_multicast_group(Ipv4Address::from_bytes(&MDNS_GROUP))
         .unwrap();
 
     loop {

+ 1 - 1
examples/multicast6.rs

@@ -66,7 +66,7 @@ fn main() {
 
     // Join a multicast group
     iface
-        .join_multicast_group(&mut device, Ipv6Address::from_parts(&GROUP), Instant::now())
+        .join_multicast_group(Ipv6Address::from_parts(&GROUP))
         .unwrap();
 
     loop {

+ 122 - 122
src/iface/interface/igmp.rs

@@ -4,8 +4,6 @@ use super::*;
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub enum MulticastError {
-    /// The hardware device transmit buffer is full. Try again later.
-    Exhausted,
     /// The table of joined multicast groups is already full.
     GroupTableFull,
     /// Cannot join/leave the given multicast group.
@@ -15,7 +13,6 @@ pub enum MulticastError {
 impl core::fmt::Display for MulticastError {
     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
         match self {
-            MulticastError::Exhausted => write!(f, "Exhausted"),
             MulticastError::GroupTableFull => write!(f, "GroupTableFull"),
             MulticastError::Unaddressable => write!(f, "Unaddressable"),
         }
@@ -27,138 +24,52 @@ impl std::error::Error for MulticastError {}
 
 impl Interface {
     /// Add an address to a list of subscribed multicast IP addresses.
-    ///
-    /// Returns `Ok(announce_sent)` if the address was added successfully, where `announce_sent`
-    /// indicates whether an initial immediate announcement has been sent.
-    pub fn join_multicast_group<D, T: Into<IpAddress>>(
+    pub fn join_multicast_group<T: Into<IpAddress>>(
         &mut self,
-        device: &mut D,
         addr: T,
-        timestamp: Instant,
-    ) -> Result<bool, MulticastError>
-    where
-        D: Device + ?Sized,
-    {
+    ) -> Result<(), MulticastError> {
         let addr = addr.into();
-        self.inner.now = timestamp;
-
-        let is_not_new = self
-            .inner
-            .multicast_groups
-            .insert(addr, ())
-            .map_err(|_| MulticastError::GroupTableFull)?
-            .is_some();
-        if is_not_new {
-            return Ok(false);
+        if !addr.is_multicast() {
+            return Err(MulticastError::Unaddressable);
         }
 
-        match addr {
-            IpAddress::Ipv4(addr) => {
-                if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) {
-                    // Send initial membership report
-                    let tx_token = device
-                        .transmit(timestamp)
-                        .ok_or(MulticastError::Exhausted)?;
-
-                    // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
-                    self.inner
-                        .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
-                        .unwrap();
-
-                    Ok(true)
-                } else {
-                    Ok(false)
-                }
-            }
-            #[cfg(feature = "proto-ipv6")]
-            IpAddress::Ipv6(addr) => {
-                // Build report packet containing this new address
-                if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new(
-                    MldRecordType::ChangeToInclude,
-                    addr,
-                )]) {
-                    // Send initial membership report
-                    let tx_token = device
-                        .transmit(timestamp)
-                        .ok_or(MulticastError::Exhausted)?;
-
-                    // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
-                    self.inner
-                        .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
-                        .unwrap();
-
-                    Ok(true)
-                } else {
-                    Ok(false)
-                }
-            }
-            #[allow(unreachable_patterns)]
-            _ => Err(MulticastError::Unaddressable),
+        if let Some(state) = self.inner.multicast_groups.get_mut(&addr) {
+            *state = match state {
+                MulticastGroupState::Joining => MulticastGroupState::Joining,
+                MulticastGroupState::Joined => MulticastGroupState::Joined,
+                MulticastGroupState::Leaving => MulticastGroupState::Joined,
+            };
+        } else {
+            self.inner
+                .multicast_groups
+                .insert(addr, MulticastGroupState::Joining)
+                .map_err(|_| MulticastError::GroupTableFull)?;
         }
+        Ok(())
     }
 
     /// Remove an address from the subscribed multicast IP addresses.
-    ///
-    /// Returns `Ok(leave_sent)` if the address was removed successfully, where `leave_sent`
-    /// indicates whether an immediate leave packet has been sent.
-    pub fn leave_multicast_group<D, T: Into<IpAddress>>(
+    pub fn leave_multicast_group<T: Into<IpAddress>>(
         &mut self,
-        device: &mut D,
         addr: T,
-        timestamp: Instant,
-    ) -> Result<bool, MulticastError>
-    where
-        D: Device + ?Sized,
-    {
+    ) -> Result<(), MulticastError> {
         let addr = addr.into();
-        self.inner.now = timestamp;
-        let was_not_present = self.inner.multicast_groups.remove(&addr).is_none();
-        if was_not_present {
-            return Ok(false);
+        if !addr.is_multicast() {
+            return Err(MulticastError::Unaddressable);
         }
 
-        match addr {
-            IpAddress::Ipv4(addr) => {
-                if let Some(pkt) = self.inner.igmp_leave_packet(addr) {
-                    // Send group leave packet
-                    let tx_token = device
-                        .transmit(timestamp)
-                        .ok_or(MulticastError::Exhausted)?;
-
-                    // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
-                    self.inner
-                        .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
-                        .unwrap();
-
-                    Ok(true)
-                } else {
-                    Ok(false)
-                }
+        if let Some(state) = self.inner.multicast_groups.get_mut(&addr) {
+            let delete;
+            (*state, delete) = match state {
+                MulticastGroupState::Joining => (MulticastGroupState::Joined, true),
+                MulticastGroupState::Joined => (MulticastGroupState::Leaving, false),
+                MulticastGroupState::Leaving => (MulticastGroupState::Leaving, false),
+            };
+            if delete {
+                self.inner.multicast_groups.remove(&addr);
             }
-            #[cfg(feature = "proto-ipv6")]
-            IpAddress::Ipv6(addr) => {
-                if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new(
-                    MldRecordType::ChangeToExclude,
-                    addr,
-                )]) {
-                    // Send group leave packet
-                    let tx_token = device
-                        .transmit(timestamp)
-                        .ok_or(MulticastError::Exhausted)?;
-
-                    // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
-                    self.inner
-                        .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
-                        .unwrap();
-
-                    Ok(true)
-                } else {
-                    Ok(false)
-                }
-            }
-            #[allow(unreachable_patterns)]
-            _ => Err(MulticastError::Unaddressable),
         }
+        Ok(())
     }
 
     /// Check whether the interface listens to given destination multicast IP address.
@@ -166,12 +77,101 @@ impl Interface {
         self.inner.has_multicast_group(addr)
     }
 
-    /// Depending on `igmp_report_state` and the therein contained
-    /// timeouts, send IGMP membership reports.
-    pub(crate) fn igmp_egress<D>(&mut self, device: &mut D) -> bool
+    /// Do multicast egress.
+    ///
+    /// - Send join/leave packets according to the multicast group state.
+    /// - Depending on `igmp_report_state` and the therein contained
+    ///   timeouts, send IGMP membership reports.
+    pub(crate) fn multicast_egress<D>(&mut self, device: &mut D) -> bool
     where
         D: Device + ?Sized,
     {
+        // Process multicast joins.
+        while let Some((&addr, _)) = self
+            .inner
+            .multicast_groups
+            .iter()
+            .find(|(_, &state)| state == MulticastGroupState::Joining)
+        {
+            match addr {
+                IpAddress::Ipv4(addr) => {
+                    if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) {
+                        let Some(tx_token) = device.transmit(self.inner.now) else {
+                            break;
+                        };
+
+                        // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
+                        self.inner
+                            .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
+                            .unwrap();
+                    }
+                }
+                #[cfg(feature = "proto-ipv6")]
+                IpAddress::Ipv6(addr) => {
+                    if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new(
+                        MldRecordType::ChangeToInclude,
+                        addr,
+                    )]) {
+                        let Some(tx_token) = device.transmit(self.inner.now) else {
+                            break;
+                        };
+
+                        // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
+                        self.inner
+                            .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
+                            .unwrap();
+                    }
+                }
+            }
+
+            // NOTE(unwrap): this is always replacing an existing entry, so it can't fail due to the map being full.
+            self.inner
+                .multicast_groups
+                .insert(addr, MulticastGroupState::Joined)
+                .unwrap();
+        }
+
+        // Process multicast leaves.
+        while let Some((&addr, _)) = self
+            .inner
+            .multicast_groups
+            .iter()
+            .find(|(_, &state)| state == MulticastGroupState::Leaving)
+        {
+            match addr {
+                IpAddress::Ipv4(addr) => {
+                    if let Some(pkt) = self.inner.igmp_leave_packet(addr) {
+                        let Some(tx_token) = device.transmit(self.inner.now) else {
+                            break;
+                        };
+
+                        // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
+                        self.inner
+                            .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
+                            .unwrap();
+                    }
+                }
+                #[cfg(feature = "proto-ipv6")]
+                IpAddress::Ipv6(addr) => {
+                    if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new(
+                        MldRecordType::ChangeToExclude,
+                        addr,
+                    )]) {
+                        let Some(tx_token) = device.transmit(self.inner.now) else {
+                            break;
+                        };
+
+                        // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
+                        self.inner
+                            .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
+                            .unwrap();
+                    }
+                }
+            }
+
+            self.inner.multicast_groups.remove(&addr);
+        }
+
         match self.inner.igmp_report_state {
             IgmpReportState::ToSpecificQuery {
                 version,

+ 25 - 4
src/iface/interface/mod.rs

@@ -82,6 +82,16 @@ pub struct Interface {
     fragmenter: Fragmenter,
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum MulticastGroupState {
+    /// Joining group, we have to send the join packet.
+    Joining,
+    /// We've already sent the join packet, we have nothing to do.
+    Joined,
+    /// We want to leave the group, we have to send a leave packet.
+    Leaving,
+}
+
 /// The device independent part of an Ethernet network interface.
 ///
 /// Separating the device from the data required for processing and dispatching makes
@@ -112,7 +122,7 @@ pub struct InterfaceInner {
     any_ip: bool,
     routes: Routes,
     #[cfg(any(feature = "proto-igmp", feature = "proto-ipv6"))]
-    multicast_groups: LinearMap<IpAddress, (), IFACE_MAX_MULTICAST_GROUP_COUNT>,
+    multicast_groups: LinearMap<IpAddress, MulticastGroupState, IFACE_MAX_MULTICAST_GROUP_COUNT>,
     /// When to report for (all or) the next multicast group membership via IGMP
     #[cfg(feature = "proto-igmp")]
     igmp_report_state: IgmpReportState,
@@ -437,7 +447,7 @@ impl Interface {
 
         #[cfg(feature = "proto-igmp")]
         {
-            readiness_may_have_changed |= self.igmp_egress(device);
+            readiness_may_have_changed |= self.multicast_egress(device);
         }
 
         readiness_may_have_changed
@@ -749,18 +759,29 @@ impl InterfaceInner {
     /// If built without feature `proto-igmp` this function will
     /// always return `false` when using IPv4.
     fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
+        /// Return false if we don't have the multicast group,
+        /// or we're leaving it.
+        fn wanted_state(x: Option<&MulticastGroupState>) -> bool {
+            match x {
+                None => false,
+                Some(MulticastGroupState::Joining) => true,
+                Some(MulticastGroupState::Joined) => true,
+                Some(MulticastGroupState::Leaving) => false,
+            }
+        }
+
         let addr = addr.into();
         match addr {
             #[cfg(feature = "proto-igmp")]
             IpAddress::Ipv4(key) => {
                 key == Ipv4Address::MULTICAST_ALL_SYSTEMS
-                    || self.multicast_groups.get(&addr).is_some()
+                    || wanted_state(self.multicast_groups.get(&addr))
             }
             #[cfg(feature = "proto-ipv6")]
             IpAddress::Ipv6(key) => {
                 key == Ipv6Address::LINK_LOCAL_ALL_NODES
                     || self.has_solicited_node(key)
-                    || self.multicast_groups.get(&addr).is_some()
+                    || wanted_state(self.multicast_groups.get(&addr))
             }
             #[cfg(feature = "proto-rpl")]
             IpAddress::Ipv6(Ipv6Address::LINK_LOCAL_ALL_RPL_NODES) => true,

+ 4 - 6
src/iface/interface/tests/ipv4.rs

@@ -702,10 +702,9 @@ fn test_handle_igmp(#[case] medium: Medium) {
     // Join multicast groups
     let timestamp = Instant::ZERO;
     for group in &groups {
-        iface
-            .join_multicast_group(&mut device, *group, timestamp)
-            .unwrap();
+        iface.join_multicast_group(*group).unwrap();
     }
+    iface.poll(timestamp, &mut device, &mut sockets);
 
     let reports = recv_igmp(&mut device, timestamp);
     assert_eq!(reports.len(), 2);
@@ -745,10 +744,9 @@ fn test_handle_igmp(#[case] medium: Medium) {
     // Leave multicast groups
     let timestamp = Instant::ZERO;
     for group in &groups {
-        iface
-            .leave_multicast_group(&mut device, *group, timestamp)
-            .unwrap();
+        iface.leave_multicast_group(*group).unwrap();
     }
+    iface.poll(timestamp, &mut device, &mut sockets);
 
     let leaves = recv_igmp(&mut device, timestamp);
     assert_eq!(leaves.len(), 2);

+ 7 - 7
src/iface/interface/tests/ipv6.rs

@@ -1289,7 +1289,7 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) {
             .collect::<std::vec::Vec<_>>()
     }
 
-    let (mut iface, _sockets, mut device) = setup(medium);
+    let (mut iface, mut sockets, mut device) = setup(medium);
 
     let groups = [
         Ipv6Address::from_parts(&[0xff05, 0, 0, 0, 0, 0, 0, 0x00fb]),
@@ -1299,12 +1299,12 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) {
     let timestamp = Instant::from_millis(0);
 
     for &group in &groups {
-        iface
-            .join_multicast_group(&mut device, group, timestamp)
-            .unwrap();
+        iface.join_multicast_group(group).unwrap();
         assert!(iface.has_multicast_group(group));
     }
     assert!(iface.has_multicast_group(Ipv6Address::LINK_LOCAL_ALL_NODES));
+    iface.poll(timestamp, &mut device, &mut sockets);
+    assert!(iface.has_multicast_group(Ipv6Address::LINK_LOCAL_ALL_NODES));
 
     let reports = recv_icmpv6(&mut device, timestamp);
     assert_eq!(reports.len(), 2);
@@ -1374,9 +1374,9 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) {
             }
         );
 
-        iface
-            .leave_multicast_group(&mut device, group_addr, timestamp)
-            .unwrap();
+        iface.leave_multicast_group(group_addr).unwrap();
+        assert!(!iface.has_multicast_group(group_addr));
+        iface.poll(timestamp, &mut device, &mut sockets);
         assert!(!iface.has_multicast_group(group_addr));
     }
 }