Bläddra i källkod

Merge #669

669: Adds one-shot mDNS resolution r=Dirbaio a=benbrittain

RFC 6762 Section 5.1 specifies a one-shot multicast DNS query. This
query has minimal differences from a standard DNS query, mostly just
using a multicast address and a different port (5353 vs 53).

A fully standards compliant mDNS implementation would use UDP source
port 5353 as well to issue queries, however we MUST NOT use that port
and continue using an ephemeral port until features such as service
discovery are implemented.

This change also allows specifying what kind of DNS query we wish to
perform.

https://www.rfc-editor.org/rfc/rfc6762#section-5.1

Co-authored-by: Benjamin Brittain <ben@brittain.org>
Co-authored-by: Thibaut Vandervelden <thvdveld@vub.be>
bors[bot] 2 år sedan
förälder
incheckning
c1274210bc
6 ändrade filer med 73 tillägg och 27 borttagningar
  1. 2 1
      Cargo.toml
  2. 4 2
      examples/dns.rs
  3. 4 12
      src/iface/interface.rs
  4. 59 11
      src/socket/dns.rs
  5. 1 1
      src/wire/dhcpv4.rs
  6. 3 0
      src/wire/mod.rs

+ 2 - 1
Cargo.toml

@@ -60,6 +60,7 @@ defmt = [ "dep:defmt", "heapless/defmt", "heapless/defmt-impl" ]
 "socket-icmp" = ["socket"]
 "socket-dhcpv4" = ["socket", "medium-ethernet", "proto-dhcpv4"]
 "socket-dns" = ["socket", "proto-dns"]
+"socket-mdns" = ["socket-dns"]
 
 "async" = []
 
@@ -69,7 +70,7 @@ default = [
   "phy-raw_socket", "phy-tuntap_interface",
   "proto-ipv4", "proto-igmp", "proto-dhcpv4", "proto-ipv6", "proto-dns",
   "proto-ipv4-fragmentation", "proto-sixlowpan-fragmentation",
-  "socket-raw", "socket-icmp", "socket-udp", "socket-tcp", "socket-dhcpv4", "socket-dns",
+  "socket-raw", "socket-icmp", "socket-udp", "socket-tcp", "socket-dhcpv4", "socket-dns", "socket-mdns",
   "async"
 ]
 

+ 4 - 2
examples/dns.rs

@@ -13,7 +13,7 @@ use smoltcp::phy::{wait as phy_wait, Medium};
 use smoltcp::socket::dns::{self, GetQueryResultError};
 use smoltcp::time::Instant;
 use smoltcp::wire::{
-    EthernetAddress, HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address,
+    DnsQueryType, EthernetAddress, HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address,
 };
 use std::collections::BTreeMap;
 use std::os::unix::io::AsRawFd;
@@ -68,7 +68,9 @@ fn main() {
     let dns_handle = sockets.add(dns_socket);
 
     let socket = sockets.get_mut::<dns::Socket>(dns_handle);
-    let query = socket.start_query(iface.context(), name).unwrap();
+    let query = socket
+        .start_query(iface.context(), name, DnsQueryType::A)
+        .unwrap();
 
     loop {
         let timestamp = Instant::now();

+ 4 - 12
src/iface/interface.rs

@@ -1015,22 +1015,14 @@ impl<'a> Interface<'a> {
         self.inner.now = timestamp;
 
         #[cfg(feature = "proto-ipv4-fragmentation")]
-        if let Err(e) = self
-            .fragments
+        self.fragments
             .ipv4_fragments
-            .remove_when(|frag| Ok(timestamp >= frag.expires_at()?))
-        {
-            return Err(e);
-        }
+            .remove_when(|frag| Ok(timestamp >= frag.expires_at()?))?;
 
         #[cfg(feature = "proto-sixlowpan-fragmentation")]
-        if let Err(e) = self
-            .fragments
+        self.fragments
             .sixlowpan_fragments
-            .remove_when(|frag| Ok(timestamp >= frag.expires_at()?))
-        {
-            return Err(e);
-        }
+            .remove_when(|frag| Ok(timestamp >= frag.expires_at()?))?;
 
         #[cfg(feature = "proto-ipv4-fragmentation")]
         match self.ipv4_egress(device) {

+ 59 - 11
src/socket/dns.rs

@@ -16,11 +16,20 @@ pub const MAX_ADDRESS_COUNT: usize = 4;
 pub const MAX_SERVER_COUNT: usize = 4;
 
 const DNS_PORT: u16 = 53;
+const MDNS_DNS_PORT: u16 = 5353;
 const MAX_NAME_LEN: usize = 255;
 const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
 const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
 const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
 
+#[cfg(feature = "proto-ipv6")]
+const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
+    0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
+]));
+
+#[cfg(feature = "proto-ipv4")]
+const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
+
 /// Error returned by [`Socket::start_query`]
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
@@ -81,6 +90,14 @@ struct PendingQuery {
     delay: Duration,
 
     server_idx: usize,
+    mdns: MulticastDns,
+}
+
+#[derive(Debug)]
+pub enum MulticastDns {
+    Disabled,
+    #[cfg(feature = "socket-mdns")]
+    Enabled,
 }
 
 #[derive(Debug)]
@@ -185,6 +202,7 @@ impl<'a> Socket<'a> {
         &mut self,
         cx: &mut Context,
         name: &str,
+        query_type: Type,
     ) -> Result<QueryHandle, StartQueryError> {
         let mut name = name.as_bytes();
 
@@ -200,6 +218,13 @@ impl<'a> Socket<'a> {
 
         let mut raw_name: Vec<u8, MAX_NAME_LEN> = Vec::new();
 
+        let mut mdns = MulticastDns::Disabled;
+        #[cfg(feature = "socket-mdns")]
+        if name.split(|&c| c == b'.').last().unwrap() == b"local" {
+            net_trace!("Starting a mDNS query");
+            mdns = MulticastDns::Enabled;
+        }
+
         for s in name.split(|&c| c == b'.') {
             if s.len() > 63 {
                 net_trace!("invalid name: too long label");
@@ -224,7 +249,7 @@ impl<'a> Socket<'a> {
             .push(0x00)
             .map_err(|_| StartQueryError::NameTooLong)?;
 
-        self.start_query_raw(cx, &raw_name)
+        self.start_query_raw(cx, &raw_name, query_type, mdns)
     }
 
     /// Start a query with a raw (wire-format) DNS name.
@@ -235,19 +260,22 @@ impl<'a> Socket<'a> {
         &mut self,
         cx: &mut Context,
         raw_name: &[u8],
+        query_type: Type,
+        mdns: MulticastDns,
     ) -> Result<QueryHandle, StartQueryError> {
         let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
 
         self.queries[handle.0] = Some(DnsQuery {
             state: State::Pending(PendingQuery {
                 name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
-                type_: Type::A,
+                type_: query_type,
                 txid: cx.rand().rand_u16(),
                 port: cx.rand().rand_source_port(),
                 delay: RETRANSMIT_DELAY,
                 timeout_at: None,
                 retransmit_at: Instant::ZERO,
                 server_idx: 0,
+                mdns,
             }),
             #[cfg(feature = "async")]
             waker: WakerRegistration::new(),
@@ -313,11 +341,12 @@ impl<'a> Socket<'a> {
     }
 
     pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
-        udp_repr.src_port == DNS_PORT
+        (udp_repr.src_port == DNS_PORT
             && self
                 .servers
                 .iter()
-                .any(|server| *server == ip_repr.src_addr())
+                .any(|server| *server == ip_repr.src_addr()))
+            || (udp_repr.src_port == MDNS_DNS_PORT)
     }
 
     pub(crate) fn process(
@@ -482,6 +511,20 @@ impl<'a> Socket<'a> {
 
         for q in self.queries.iter_mut().flatten() {
             if let State::Pending(pq) = &mut q.state {
+                // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
+                // so we internally overwrite the servers for any of those queries
+                // in this function.
+                let servers = match pq.mdns {
+                    #[cfg(feature = "socket-mdns")]
+                    MulticastDns::Enabled => &[
+                        #[cfg(feature = "proto-ipv6")]
+                        MDNS_IPV6_ADDR,
+                        #[cfg(feature = "proto-ipv4")]
+                        MDNS_IPV4_ADDR,
+                    ],
+                    MulticastDns::Disabled => self.servers.as_slice(),
+                };
+
                 let timeout = if let Some(timeout) = pq.timeout_at {
                     timeout
                 } else {
@@ -500,16 +543,15 @@ impl<'a> Socket<'a> {
                     // Try next server. We check below whether we've tried all servers.
                     pq.server_idx += 1;
                 }
-
                 // Check if we've run out of servers to try.
-                if pq.server_idx >= self.servers.len() {
+                if pq.server_idx >= servers.len() {
                     net_trace!("already tried all servers.");
                     q.set_state(State::Failure);
                     continue;
                 }
 
                 // Check so the IP address is valid
-                if self.servers[pq.server_idx].is_unspecified() {
+                if servers[pq.server_idx].is_unspecified() {
                     net_trace!("invalid unspecified DNS server addr.");
                     q.set_state(State::Failure);
                     continue;
@@ -526,7 +568,7 @@ impl<'a> Socket<'a> {
                     opcode: Opcode::Query,
                     question: Question {
                         name: &pq.name,
-                        type_: Type::A,
+                        type_: pq.type_,
                     },
                 };
 
@@ -534,12 +576,18 @@ impl<'a> Socket<'a> {
                 let payload = &mut payload[..repr.buffer_len()];
                 repr.emit(&mut Packet::new_unchecked(payload));
 
+                let dst_port = match pq.mdns {
+                    #[cfg(feature = "socket-mdns")]
+                    MulticastDns::Enabled => MDNS_DNS_PORT,
+                    MulticastDns::Disabled => DNS_PORT,
+                };
+
                 let udp_repr = UdpRepr {
                     src_port: pq.port,
-                    dst_port: 53,
+                    dst_port,
                 };
 
-                let dst_addr = self.servers[pq.server_idx];
+                let dst_addr = servers[pq.server_idx];
                 let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
                 let ip_repr = IpRepr::new(
                     src_addr,
@@ -550,7 +598,7 @@ impl<'a> Socket<'a> {
                 );
 
                 net_trace!(
-                    "sending {} octets to {:?}:{}",
+                    "sending {} octets to {} from port {}",
                     payload.len(),
                     ip_repr.dst_addr(),
                     udp_repr.src_port

+ 1 - 1
src/wire/dhcpv4.rs

@@ -364,7 +364,7 @@ impl<T: AsRef<[u8]>> Packet<T> {
         let mut buf = &self.buffer.as_ref()[field::OPTIONS];
         iter::from_fn(move || {
             loop {
-                match buf.get(0).copied() {
+                match buf.first().copied() {
                     // No more options, return.
                     None => return None,
                     Some(field::OPT_END) => return None,

+ 3 - 0
src/wire/mod.rs

@@ -248,6 +248,9 @@ pub use self::dhcpv4::{
     MAX_DNS_SERVER_COUNT as DHCP_MAX_DNS_SERVER_COUNT, SERVER_PORT as DHCP_SERVER_PORT,
 };
 
+#[cfg(feature = "proto-dns")]
+pub use self::dns::{Packet as DnsPacket, Repr as DnsRepr, Type as DnsQueryType};
+
 /// Parsing a packet failed.
 ///
 /// Either it is malformed, or it is not supported by smoltcp.