Эх сурвалжийг харах

Merge pull request #1009 from bergzand/pr/mldv2queryresp

feat: Add MLDv2 query response support
Dario Nieuwenhuis 6 сар өмнө
parent
commit
33bf798438

+ 10 - 0
src/iface/interface/ipv6.rs

@@ -422,6 +422,16 @@ impl InterfaceInner {
                 #[cfg(feature = "medium-ip")]
                 Medium::Ip => None,
             },
+            #[cfg(feature = "multicast")]
+            Icmpv6Repr::Mld(repr) => match repr {
+                // [RFC 3810 § 6.2], reception checks
+                MldRepr::Query { .. }
+                    if ip_repr.hop_limit == 1 && ip_repr.src_addr.is_link_local() =>
+                {
+                    self.process_mldv2(ip_repr, repr)
+                }
+                _ => None,
+            },
 
             // Don't report an error if a packet with unknown type
             // has been handled by an ICMP socket

+ 108 - 1
src/iface/interface/multicast.rs

@@ -1,7 +1,7 @@
 use core::result::Result;
 use heapless::LinearMap;
 
-#[cfg(feature = "proto-ipv4")]
+#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))]
 use super::{check, IpPayload, Packet};
 use super::{Interface, InterfaceInner};
 use crate::config::IFACE_MAX_MULTICAST_GROUP_COUNT;
@@ -34,6 +34,18 @@ pub(crate) enum IgmpReportState {
     },
 }
 
+#[cfg(feature = "proto-ipv6")]
+pub(crate) enum MldReportState {
+    Inactive,
+    ToGeneralQuery {
+        timeout: crate::time::Instant,
+    },
+    ToSpecificQuery {
+        group: Ipv6Address,
+        timeout: crate::time::Instant,
+    },
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum GroupState {
     /// Joining group, we have to send the join packet.
@@ -49,6 +61,8 @@ pub(crate) struct State {
     /// When to report for (all or) the next multicast group membership via IGMP
     #[cfg(feature = "proto-ipv4")]
     igmp_report_state: IgmpReportState,
+    #[cfg(feature = "proto-ipv6")]
+    mld_report_state: MldReportState,
 }
 
 impl State {
@@ -57,6 +71,8 @@ impl State {
             groups: LinearMap::new(),
             #[cfg(feature = "proto-ipv4")]
             igmp_report_state: IgmpReportState::Inactive,
+            #[cfg(feature = "proto-ipv6")]
+            mld_report_state: MldReportState::Inactive,
         }
     }
 
@@ -306,6 +322,46 @@ impl Interface {
             }
             _ => {}
         }
+        #[cfg(feature = "proto-ipv6")]
+        match self.inner.multicast.mld_report_state {
+            MldReportState::ToGeneralQuery { timeout } if self.inner.now >= timeout => {
+                let records = self
+                    .inner
+                    .multicast
+                    .groups
+                    .iter()
+                    .filter_map(|(addr, _)| match addr {
+                        IpAddress::Ipv6(addr) => Some(MldAddressRecordRepr::new(
+                            MldRecordType::ModeIsExclude,
+                            *addr,
+                        )),
+                        #[allow(unreachable_patterns)]
+                        _ => None,
+                    })
+                    .collect::<heapless::Vec<_, IFACE_MAX_MULTICAST_GROUP_COUNT>>();
+                if let Some(pkt) = self.inner.mldv2_report_packet(&records) {
+                    if let Some(tx_token) = device.transmit(self.inner.now) {
+                        self.inner
+                            .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
+                            .unwrap();
+                    };
+                };
+                self.inner.multicast.mld_report_state = MldReportState::Inactive;
+            }
+            MldReportState::ToSpecificQuery { group, timeout } if self.inner.now >= timeout => {
+                let record = MldAddressRecordRepr::new(MldRecordType::ModeIsExclude, group);
+                if let Some(pkt) = self.inner.mldv2_report_packet(&[record]) {
+                    if let Some(tx_token) = device.transmit(self.inner.now) {
+                        // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery.
+                        self.inner
+                            .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter)
+                            .unwrap();
+                    }
+                }
+                self.inner.multicast.mld_report_state = MldReportState::Inactive;
+            }
+            _ => {}
+        }
     }
 }
 
@@ -425,4 +481,55 @@ impl InterfaceInner {
             )
         })
     }
+
+    /// Host duties of the **MLDv2** protocol.
+    ///
+    /// Sets up `mld_report_state` for responding to MLD general/specific membership queries.
+    /// Membership must not be reported immediately in order to avoid flooding the network
+    /// after a query is broadcasted by a router; Currently the delay is fixed and not randomized.
+    #[cfg(feature = "proto-ipv6")]
+    pub(super) fn process_mldv2<'frame>(
+        &mut self,
+        ip_repr: Ipv6Repr,
+        repr: MldRepr<'frame>,
+    ) -> Option<Packet<'frame>> {
+        match repr {
+            MldRepr::Query {
+                mcast_addr,
+                max_resp_code,
+                ..
+            } => {
+                // Do not respont immediately to the query, but wait a random time
+                let delay = crate::time::Duration::from_millis(
+                    (self.rand.rand_u16() % max_resp_code).into(),
+                );
+                // General query
+                if mcast_addr.is_unspecified()
+                    && (ip_repr.dst_addr == IPV6_LINK_LOCAL_ALL_NODES
+                        || self.has_ip_addr(ip_repr.dst_addr))
+                {
+                    let ipv6_multicast_group_count = self
+                        .multicast
+                        .groups
+                        .keys()
+                        .filter(|a| matches!(a, IpAddress::Ipv6(_)))
+                        .count();
+                    if ipv6_multicast_group_count != 0 {
+                        self.multicast.mld_report_state = MldReportState::ToGeneralQuery {
+                            timeout: self.now + delay,
+                        };
+                    }
+                }
+                if self.has_multicast_group(mcast_addr) && ip_repr.dst_addr == mcast_addr {
+                    self.multicast.mld_report_state = MldReportState::ToSpecificQuery {
+                        group: mcast_addr,
+                        timeout: self.now + delay,
+                    };
+                }
+                None
+            }
+            MldRepr::Report { .. } => None,
+            MldRepr::ReportRecordReprs { .. } => None,
+        }
+    }
 }

+ 184 - 0
src/iface/interface/tests/ipv6.rs

@@ -1378,3 +1378,187 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) {
         assert!(!iface.has_multicast_group(group_addr));
     }
 }
+
+#[rstest]
+#[case(Medium::Ethernet)]
+#[cfg(all(feature = "multicast", feature = "medium-ethernet"))]
+fn test_handle_valid_multicast_query(#[case] medium: Medium) {
+    fn recv_icmpv6(
+        device: &mut crate::tests::TestingDevice,
+        timestamp: Instant,
+    ) -> std::vec::Vec<Ipv6Packet<std::vec::Vec<u8>>> {
+        let caps = device.capabilities();
+        recv_all(device, timestamp)
+            .iter()
+            .filter_map(|frame| {
+                let ipv6_packet = match caps.medium {
+                    #[cfg(feature = "medium-ethernet")]
+                    Medium::Ethernet => {
+                        let eth_frame = EthernetFrame::new_checked(frame).ok()?;
+                        Ipv6Packet::new_checked(eth_frame.payload()).ok()?
+                    }
+                    #[cfg(feature = "medium-ip")]
+                    Medium::Ip => Ipv6Packet::new_checked(&frame[..]).ok()?,
+                    #[cfg(feature = "medium-ieee802154")]
+                    Medium::Ieee802154 => todo!(),
+                };
+                let buf = ipv6_packet.into_inner().to_vec();
+                Some(Ipv6Packet::new_unchecked(buf))
+            })
+            .collect::<std::vec::Vec<_>>()
+    }
+
+    let (mut iface, mut sockets, mut device) = setup(medium);
+
+    let mut timestamp = Instant::ZERO;
+
+    let mut eth_bytes = vec![0u8; 86];
+
+    let local_ip_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 101);
+    let remote_ip_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 100);
+    let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]);
+    let query_ip_addr = Ipv6Address::new(0xff02, 0, 0, 0, 0, 0, 0, 0x1234);
+
+    iface.join_multicast_group(query_ip_addr).unwrap();
+    iface
+        .join_multicast_group(local_ip_addr.solicited_node())
+        .unwrap();
+
+    iface.poll(timestamp, &mut device, &mut sockets);
+    // flush multicast reports from the join_multicast_group calls
+    recv_icmpv6(&mut device, timestamp);
+
+    let queries = [
+        // General query, expect both multicast addresses back
+        (
+            Ipv6Address::UNSPECIFIED,
+            IPV6_LINK_LOCAL_ALL_NODES,
+            vec![query_ip_addr, local_ip_addr.solicited_node()],
+        ),
+        // Address specific query, expect only the queried address back
+        (query_ip_addr, query_ip_addr, vec![query_ip_addr]),
+    ];
+
+    for (mcast_query, address, _results) in queries.iter() {
+        let query = Icmpv6Repr::Mld(MldRepr::Query {
+            max_resp_code: 1000,
+            mcast_addr: *mcast_query,
+            s_flag: false,
+            qrv: 1,
+            qqic: 60,
+            num_srcs: 0,
+            data: &[0, 0, 0, 0],
+        });
+
+        let ip_repr = IpRepr::Ipv6(Ipv6Repr {
+            src_addr: remote_ip_addr,
+            dst_addr: *address,
+            next_header: IpProtocol::Icmpv6,
+            hop_limit: 1,
+            payload_len: query.buffer_len(),
+        });
+
+        let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes);
+        frame.set_dst_addr(EthernetAddress([0x33, 0x33, 0x00, 0x00, 0x00, 0x00]));
+        frame.set_src_addr(remote_hw_addr);
+        frame.set_ethertype(EthernetProtocol::Ipv6);
+        ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default());
+        query.emit(
+            &remote_ip_addr,
+            address,
+            &mut Icmpv6Packet::new_unchecked(&mut frame.payload_mut()[ip_repr.header_len()..]),
+            &ChecksumCapabilities::default(),
+        );
+
+        iface.inner.process_ethernet(
+            &mut sockets,
+            PacketMeta::default(),
+            frame.into_inner(),
+            &mut iface.fragments,
+        );
+
+        timestamp += crate::time::Duration::from_millis(1000);
+        iface.poll(timestamp, &mut device, &mut sockets);
+    }
+
+    let reports = recv_icmpv6(&mut device, timestamp);
+    assert_eq!(reports.len(), queries.len());
+
+    let caps = device.capabilities();
+    let checksum_caps = &caps.checksum;
+    for ((_mcast_query, _address, results), ipv6_packet) in queries.iter().zip(reports) {
+        let buf = ipv6_packet.into_inner();
+        let ipv6_packet = Ipv6Packet::new_unchecked(buf.as_slice());
+
+        let ipv6_repr = Ipv6Repr::parse(&ipv6_packet).unwrap();
+        let ip_payload = ipv6_packet.payload();
+        assert_eq!(ipv6_repr.dst_addr, IPV6_LINK_LOCAL_ALL_MLDV2_ROUTERS);
+
+        // The first 2 octets of this payload hold the next-header indicator and the
+        // Hop-by-Hop header length (in 8-octet words, minus 1). The remaining 6 octets
+        // hold the Hop-by-Hop PadN and Router Alert options.
+        let hbh_header = Ipv6HopByHopHeader::new_checked(&ip_payload[..8]).unwrap();
+        let hbh_repr = Ipv6HopByHopRepr::parse(&hbh_header).unwrap();
+
+        assert_eq!(hbh_repr.options.len(), 3);
+        assert_eq!(
+            hbh_repr.options[0],
+            Ipv6OptionRepr::Unknown {
+                type_: Ipv6OptionType::Unknown(IpProtocol::Icmpv6.into()),
+                length: 0,
+                data: &[],
+            }
+        );
+        assert_eq!(
+            hbh_repr.options[1],
+            Ipv6OptionRepr::RouterAlert(Ipv6OptionRouterAlert::MulticastListenerDiscovery)
+        );
+        assert_eq!(hbh_repr.options[2], Ipv6OptionRepr::PadN(0));
+
+        let icmpv6_packet =
+            Icmpv6Packet::new_checked(&ip_payload[hbh_repr.buffer_len()..]).unwrap();
+        let icmpv6_repr = Icmpv6Repr::parse(
+            &ipv6_packet.src_addr(),
+            &ipv6_packet.dst_addr(),
+            &icmpv6_packet,
+            checksum_caps,
+        )
+        .unwrap();
+
+        let record_data = match icmpv6_repr {
+            Icmpv6Repr::Mld(MldRepr::Report {
+                nr_mcast_addr_rcrds,
+                data,
+            }) => {
+                assert_eq!(nr_mcast_addr_rcrds, results.len() as u16);
+                data
+            }
+            other => panic!("unexpected icmpv6_repr: {:?}", other),
+        };
+
+        let mut record_reprs = Vec::new();
+        let mut payload = record_data;
+
+        // FIXME: parsing multiple address records should be done by the MLD code
+        while !payload.is_empty() {
+            let record = MldAddressRecord::new_checked(payload).unwrap();
+            let mut record_repr = MldAddressRecordRepr::parse(&record).unwrap();
+            payload = record_repr.payload;
+            record_repr.payload = &[];
+            record_reprs.push(record_repr);
+        }
+
+        let expected_records = results
+            .iter()
+            .map(|addr| MldAddressRecordRepr {
+                num_srcs: 0,
+                mcast_addr: *addr,
+                record_type: MldRecordType::ModeIsExclude,
+                aux_data_len: 0,
+                payload: &[],
+            })
+            .collect::<Vec<_>>();
+
+        assert_eq!(record_reprs, expected_records);
+    }
+}