|
@@ -16,11 +16,20 @@ pub const MAX_ADDRESS_COUNT: usize = 4;
|
|
|
pub const MAX_SERVER_COUNT: usize = 4;
|
|
|
|
|
|
const DNS_PORT: u16 = 53;
|
|
|
+const MDNS_DNS_PORT: u16 = 5353;
|
|
|
const MAX_NAME_LEN: usize = 255;
|
|
|
const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
|
|
|
const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
|
|
|
const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
|
|
|
|
|
|
+#[cfg(feature = "proto-ipv6")]
|
|
|
+const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
|
|
|
+ 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
|
|
|
+]));
|
|
|
+
|
|
|
+#[cfg(feature = "proto-ipv4")]
|
|
|
+const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
|
|
|
+
|
|
|
/// Error returned by [`Socket::start_query`]
|
|
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
|
|
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
|
@@ -81,6 +90,14 @@ struct PendingQuery {
|
|
|
delay: Duration,
|
|
|
|
|
|
server_idx: usize,
|
|
|
+ mdns: MulticastDns,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug)]
|
|
|
+pub enum MulticastDns {
|
|
|
+ Disabled,
|
|
|
+ #[cfg(feature = "socket-mdns")]
|
|
|
+ Enabled,
|
|
|
}
|
|
|
|
|
|
#[derive(Debug)]
|
|
@@ -185,6 +202,7 @@ impl<'a> Socket<'a> {
|
|
|
&mut self,
|
|
|
cx: &mut Context,
|
|
|
name: &str,
|
|
|
+ query_type: Type,
|
|
|
) -> Result<QueryHandle, StartQueryError> {
|
|
|
let mut name = name.as_bytes();
|
|
|
|
|
@@ -200,6 +218,13 @@ impl<'a> Socket<'a> {
|
|
|
|
|
|
let mut raw_name: Vec<u8, MAX_NAME_LEN> = Vec::new();
|
|
|
|
|
|
+ let mut mdns = MulticastDns::Disabled;
|
|
|
+ #[cfg(feature = "socket-mdns")]
|
|
|
+ if name.split(|&c| c == b'.').last().unwrap() == b"local" {
|
|
|
+ net_trace!("Starting a mDNS query");
|
|
|
+ mdns = MulticastDns::Enabled;
|
|
|
+ }
|
|
|
+
|
|
|
for s in name.split(|&c| c == b'.') {
|
|
|
if s.len() > 63 {
|
|
|
net_trace!("invalid name: too long label");
|
|
@@ -224,7 +249,7 @@ impl<'a> Socket<'a> {
|
|
|
.push(0x00)
|
|
|
.map_err(|_| StartQueryError::NameTooLong)?;
|
|
|
|
|
|
- self.start_query_raw(cx, &raw_name)
|
|
|
+ self.start_query_raw(cx, &raw_name, query_type, mdns)
|
|
|
}
|
|
|
|
|
|
/// Start a query with a raw (wire-format) DNS name.
|
|
@@ -235,19 +260,22 @@ impl<'a> Socket<'a> {
|
|
|
&mut self,
|
|
|
cx: &mut Context,
|
|
|
raw_name: &[u8],
|
|
|
+ query_type: Type,
|
|
|
+ mdns: MulticastDns,
|
|
|
) -> Result<QueryHandle, StartQueryError> {
|
|
|
let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
|
|
|
|
|
|
self.queries[handle.0] = Some(DnsQuery {
|
|
|
state: State::Pending(PendingQuery {
|
|
|
name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
|
|
|
- type_: Type::A,
|
|
|
+ type_: query_type,
|
|
|
txid: cx.rand().rand_u16(),
|
|
|
port: cx.rand().rand_source_port(),
|
|
|
delay: RETRANSMIT_DELAY,
|
|
|
timeout_at: None,
|
|
|
retransmit_at: Instant::ZERO,
|
|
|
server_idx: 0,
|
|
|
+ mdns,
|
|
|
}),
|
|
|
#[cfg(feature = "async")]
|
|
|
waker: WakerRegistration::new(),
|
|
@@ -313,11 +341,12 @@ impl<'a> Socket<'a> {
|
|
|
}
|
|
|
|
|
|
pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
|
|
|
- udp_repr.src_port == DNS_PORT
|
|
|
+ (udp_repr.src_port == DNS_PORT
|
|
|
&& self
|
|
|
.servers
|
|
|
.iter()
|
|
|
- .any(|server| *server == ip_repr.src_addr())
|
|
|
+ .any(|server| *server == ip_repr.src_addr()))
|
|
|
+ || (udp_repr.src_port == MDNS_DNS_PORT)
|
|
|
}
|
|
|
|
|
|
pub(crate) fn process(
|
|
@@ -482,6 +511,20 @@ impl<'a> Socket<'a> {
|
|
|
|
|
|
for q in self.queries.iter_mut().flatten() {
|
|
|
if let State::Pending(pq) = &mut q.state {
|
|
|
+ // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
|
|
|
+ // so we internally overwrite the servers for any of those queries
|
|
|
+ // in this function.
|
|
|
+ let servers = match pq.mdns {
|
|
|
+ #[cfg(feature = "socket-mdns")]
|
|
|
+ MulticastDns::Enabled => &[
|
|
|
+ #[cfg(feature = "proto-ipv6")]
|
|
|
+ MDNS_IPV6_ADDR,
|
|
|
+ #[cfg(feature = "proto-ipv4")]
|
|
|
+ MDNS_IPV4_ADDR,
|
|
|
+ ],
|
|
|
+ MulticastDns::Disabled => self.servers.as_slice(),
|
|
|
+ };
|
|
|
+
|
|
|
let timeout = if let Some(timeout) = pq.timeout_at {
|
|
|
timeout
|
|
|
} else {
|
|
@@ -500,16 +543,15 @@ impl<'a> Socket<'a> {
|
|
|
// Try next server. We check below whether we've tried all servers.
|
|
|
pq.server_idx += 1;
|
|
|
}
|
|
|
-
|
|
|
// Check if we've run out of servers to try.
|
|
|
- if pq.server_idx >= self.servers.len() {
|
|
|
+ if pq.server_idx >= servers.len() {
|
|
|
net_trace!("already tried all servers.");
|
|
|
q.set_state(State::Failure);
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
// Check so the IP address is valid
|
|
|
- if self.servers[pq.server_idx].is_unspecified() {
|
|
|
+ if servers[pq.server_idx].is_unspecified() {
|
|
|
net_trace!("invalid unspecified DNS server addr.");
|
|
|
q.set_state(State::Failure);
|
|
|
continue;
|
|
@@ -526,7 +568,7 @@ impl<'a> Socket<'a> {
|
|
|
opcode: Opcode::Query,
|
|
|
question: Question {
|
|
|
name: &pq.name,
|
|
|
- type_: Type::A,
|
|
|
+ type_: pq.type_,
|
|
|
},
|
|
|
};
|
|
|
|
|
@@ -534,12 +576,18 @@ impl<'a> Socket<'a> {
|
|
|
let payload = &mut payload[..repr.buffer_len()];
|
|
|
repr.emit(&mut Packet::new_unchecked(payload));
|
|
|
|
|
|
+ let dst_port = match pq.mdns {
|
|
|
+ #[cfg(feature = "socket-mdns")]
|
|
|
+ MulticastDns::Enabled => MDNS_DNS_PORT,
|
|
|
+ MulticastDns::Disabled => DNS_PORT,
|
|
|
+ };
|
|
|
+
|
|
|
let udp_repr = UdpRepr {
|
|
|
src_port: pq.port,
|
|
|
- dst_port: 53,
|
|
|
+ dst_port,
|
|
|
};
|
|
|
|
|
|
- let dst_addr = self.servers[pq.server_idx];
|
|
|
+ let dst_addr = servers[pq.server_idx];
|
|
|
let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
|
|
|
let ip_repr = IpRepr::new(
|
|
|
src_addr,
|
|
@@ -550,7 +598,7 @@ impl<'a> Socket<'a> {
|
|
|
);
|
|
|
|
|
|
net_trace!(
|
|
|
- "sending {} octets to {:?}:{}",
|
|
|
+ "sending {} octets to {} from port {}",
|
|
|
payload.len(),
|
|
|
ip_repr.dst_addr(),
|
|
|
udp_repr.src_port
|