#[cfg(feature = "async")] use core::task::Waker; use heapless::Vec; use managed::ManagedSlice; use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT}; use crate::socket::{Context, PollAt}; use crate::time::{Duration, Instant}; use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type}; use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr}; #[cfg(feature = "async")] use super::WakerRegistration; const DNS_PORT: u16 = 53; const MDNS_DNS_PORT: u16 = 5353; const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000); const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000); const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs #[cfg(feature = "proto-ipv6")] const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([ 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb, ])); #[cfg(feature = "proto-ipv4")] const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251])); /// Error returned by [`Socket::start_query`] #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum StartQueryError { NoFreeSlot, InvalidName, NameTooLong, } impl core::fmt::Display for StartQueryError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { StartQueryError::NoFreeSlot => write!(f, "No free slot"), StartQueryError::InvalidName => write!(f, "Invalid name"), StartQueryError::NameTooLong => write!(f, "Name too long"), } } } #[cfg(feature = "std")] impl std::error::Error for StartQueryError {} /// Error returned by [`Socket::get_query_result`] #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum GetQueryResultError { /// Query is not done yet. Pending, /// Query failed. Failed, } impl core::fmt::Display for GetQueryResultError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { GetQueryResultError::Pending => write!(f, "Query is not done yet"), GetQueryResultError::Failed => write!(f, "Query failed"), } } } #[cfg(feature = "std")] impl std::error::Error for GetQueryResultError {} /// State for an in-progress DNS query. /// /// The only reason this struct is public is to allow the socket state /// to be allocated externally. #[derive(Debug)] pub struct DnsQuery { state: State, #[cfg(feature = "async")] waker: WakerRegistration, } impl DnsQuery { fn set_state(&mut self, state: State) { self.state = state; #[cfg(feature = "async")] self.waker.wake(); } } #[derive(Debug)] #[allow(clippy::large_enum_variant)] enum State { Pending(PendingQuery), Completed(CompletedQuery), Failure, } #[derive(Debug)] struct PendingQuery { name: Vec, type_: Type, port: u16, // UDP port (src for request, dst for response) txid: u16, // transaction ID timeout_at: Option, retransmit_at: Instant, delay: Duration, server_idx: usize, mdns: MulticastDns, } #[derive(Debug)] pub enum MulticastDns { Disabled, #[cfg(feature = "socket-mdns")] Enabled, } #[derive(Debug)] struct CompletedQuery { addresses: Vec, } /// A handle to an in-progress DNS query. #[derive(Clone, Copy)] pub struct QueryHandle(usize); /// A Domain Name System socket. /// /// A UDP socket is bound to a specific endpoint, and owns transmit and receive /// packet buffers. #[derive(Debug)] pub struct Socket<'a> { servers: Vec, queries: ManagedSlice<'a, Option>, /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. hop_limit: Option, } impl<'a> Socket<'a> { /// Create a DNS socket. /// /// # Panics /// /// Panics if `servers.len() > MAX_SERVER_COUNT` pub fn new(servers: &[IpAddress], queries: Q) -> Socket<'a> where Q: Into>>, { Socket { servers: Vec::from_slice(servers).unwrap(), queries: queries.into(), hop_limit: None, } } /// Update the list of DNS servers, will replace all existing servers /// /// # Panics /// /// Panics if `servers.len() > MAX_SERVER_COUNT` pub fn update_servers(&mut self, servers: &[IpAddress]) { self.servers = Vec::from_slice(servers).unwrap(); } /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. /// /// See also the [set_hop_limit](#method.set_hop_limit) method pub fn hop_limit(&self) -> Option { self.hop_limit } /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. /// /// A socket without an explicitly set hop limit value uses the default [IANA recommended] /// value (64). /// /// # Panics /// /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7]. /// /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7 pub fn set_hop_limit(&mut self, hop_limit: Option) { // A host MUST NOT send a datagram with a hop limit value of 0 if let Some(0) = hop_limit { panic!("the time-to-live value of a packet must not be zero") } self.hop_limit = hop_limit } fn find_free_query(&mut self) -> Option { for (i, q) in self.queries.iter().enumerate() { if q.is_none() { return Some(QueryHandle(i)); } } match &mut self.queries { ManagedSlice::Borrowed(_) => None, #[cfg(feature = "alloc")] ManagedSlice::Owned(queries) => { queries.push(None); let index = queries.len() - 1; Some(QueryHandle(index)) } } } /// Start a query. /// /// `name` is specified in human-friendly format, such as `"rust-lang.org"`. /// It accepts names both with and without trailing dot, and they're treated /// the same (there's no support for DNS search path). pub fn start_query( &mut self, cx: &mut Context, name: &str, query_type: Type, ) -> Result { let mut name = name.as_bytes(); if name.is_empty() { net_trace!("invalid name: zero length"); return Err(StartQueryError::InvalidName); } // Remove trailing dot, if any if name[name.len() - 1] == b'.' { name = &name[..name.len() - 1]; } let mut raw_name: Vec = Vec::new(); let mut mdns = MulticastDns::Disabled; #[cfg(feature = "socket-mdns")] if name.split(|&c| c == b'.').last().unwrap() == b"local" { net_trace!("Starting a mDNS query"); mdns = MulticastDns::Enabled; } for s in name.split(|&c| c == b'.') { if s.len() > 63 { net_trace!("invalid name: too long label"); return Err(StartQueryError::InvalidName); } if s.is_empty() { net_trace!("invalid name: zero length label"); return Err(StartQueryError::InvalidName); } // Push label raw_name .push(s.len() as u8) .map_err(|_| StartQueryError::NameTooLong)?; raw_name .extend_from_slice(s) .map_err(|_| StartQueryError::NameTooLong)?; } // Push terminator. raw_name .push(0x00) .map_err(|_| StartQueryError::NameTooLong)?; self.start_query_raw(cx, &raw_name, query_type, mdns) } /// Start a query with a raw (wire-format) DNS name. /// `b"\x09rust-lang\x03org\x00"` /// /// You probably want to use [`start_query`] instead. pub fn start_query_raw( &mut self, cx: &mut Context, raw_name: &[u8], query_type: Type, mdns: MulticastDns, ) -> Result { let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?; self.queries[handle.0] = Some(DnsQuery { state: State::Pending(PendingQuery { name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?, type_: query_type, txid: cx.rand().rand_u16(), port: cx.rand().rand_source_port(), delay: RETRANSMIT_DELAY, timeout_at: None, retransmit_at: Instant::ZERO, server_idx: 0, mdns, }), #[cfg(feature = "async")] waker: WakerRegistration::new(), }); Ok(handle) } /// Get the result of a query. /// /// If the query is completed, the query slot is automatically freed. /// /// # Panics /// Panics if the QueryHandle corresponds to a free slot. pub fn get_query_result( &mut self, handle: QueryHandle, ) -> Result, GetQueryResultError> { let slot = &mut self.queries[handle.0]; let q = slot.as_mut().unwrap(); match &mut q.state { // Query is not done yet. State::Pending(_) => Err(GetQueryResultError::Pending), // Query is done State::Completed(q) => { let res = q.addresses.clone(); *slot = None; // Free up the slot for recycling. Ok(res) } State::Failure => { *slot = None; // Free up the slot for recycling. Err(GetQueryResultError::Failed) } } } /// Cancels a query, freeing the slot. /// /// # Panics /// /// Panics if the QueryHandle corresponds to an already free slot. pub fn cancel_query(&mut self, handle: QueryHandle) { let slot = &mut self.queries[handle.0]; if slot.is_none() { panic!("Canceling query in a free slot.") } *slot = None; // Free up the slot for recycling. } /// Assign a waker to a query slot /// /// The waker will be woken when the query completes, either successfully or failed. /// /// # Panics /// /// Panics if the QueryHandle corresponds to an already free slot. #[cfg(feature = "async")] pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) { self.queries[handle.0] .as_mut() .unwrap() .waker .register(waker); } pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool { (udp_repr.src_port == DNS_PORT && self .servers .iter() .any(|server| *server == ip_repr.src_addr())) || (udp_repr.src_port == MDNS_DNS_PORT) } pub(crate) fn process( &mut self, _cx: &mut Context, ip_repr: &IpRepr, udp_repr: &UdpRepr, payload: &[u8], ) { debug_assert!(self.accepts(ip_repr, udp_repr)); let size = payload.len(); net_trace!( "receiving {} octets from {:?}:{}", size, ip_repr.src_addr(), udp_repr.dst_port ); let p = match Packet::new_checked(payload) { Ok(x) => x, Err(_) => { net_trace!("dns packet malformed"); return; } }; if p.opcode() != Opcode::Query { net_trace!("unwanted opcode {:?}", p.opcode()); return; } if !p.flags().contains(Flags::RESPONSE) { net_trace!("packet doesn't have response bit set"); return; } if p.question_count() != 1 { net_trace!("bad question count {:?}", p.question_count()); return; } // Find pending query for q in self.queries.iter_mut().flatten() { if let State::Pending(pq) = &mut q.state { if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid { continue; } if p.rcode() == Rcode::NXDomain { net_trace!("rcode NXDomain"); q.set_state(State::Failure); continue; } let payload = p.payload(); let (mut payload, question) = match Question::parse(payload) { Ok(x) => x, Err(_) => { net_trace!("question malformed"); return; } }; if question.type_ != pq.type_ { net_trace!("question type mismatch"); return; } match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) { Ok(true) => {} Ok(false) => { net_trace!("question name mismatch"); return; } Err(_) => { net_trace!("dns question name malformed"); return; } } let mut addresses = Vec::new(); for _ in 0..p.answer_record_count() { let (payload2, r) = match Record::parse(payload) { Ok(x) => x, Err(_) => { net_trace!("dns answer record malformed"); return; } }; payload = payload2; match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) { Ok(true) => {} Ok(false) => { net_trace!("answer name mismatch: {:?}", r); continue; } Err(_) => { net_trace!("dns answer record name malformed"); return; } } match r.data { #[cfg(feature = "proto-ipv4")] RecordData::A(addr) => { net_trace!("A: {:?}", addr); if addresses.push(addr.into()).is_err() { net_trace!("too many addresses in response, ignoring {:?}", addr); } } #[cfg(feature = "proto-ipv6")] RecordData::Aaaa(addr) => { net_trace!("AAAA: {:?}", addr); if addresses.push(addr.into()).is_err() { net_trace!("too many addresses in response, ignoring {:?}", addr); } } RecordData::Cname(name) => { net_trace!("CNAME: {:?}", name); // When faced with a CNAME, recursive resolvers are supposed to // resolve the CNAME and append the results for it. // // We update the query with the new name, so that we pick up the A/AAAA // records for the CNAME when we parse them later. // I believe it's mandatory the CNAME results MUST come *after* in the // packet, so it's enough to do one linear pass over it. if copy_name(&mut pq.name, p.parse_name(name)).is_err() { net_trace!("dns answer cname malformed"); return; } } RecordData::Other(type_, data) => { net_trace!("unknown: {:?} {:?}", type_, data) } } } q.set_state(if addresses.is_empty() { State::Failure } else { State::Completed(CompletedQuery { addresses }) }); // If we get here, packet matched the current query, stop processing. return; } } // If we get here, packet matched with no query. net_trace!("no query matched"); } pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> where F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>, { let hop_limit = self.hop_limit.unwrap_or(64); for q in self.queries.iter_mut().flatten() { if let State::Pending(pq) = &mut q.state { // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns // so we internally overwrite the servers for any of those queries // in this function. let servers = match pq.mdns { #[cfg(feature = "socket-mdns")] MulticastDns::Enabled => &[ #[cfg(feature = "proto-ipv6")] MDNS_IPV6_ADDR, #[cfg(feature = "proto-ipv4")] MDNS_IPV4_ADDR, ], MulticastDns::Disabled => self.servers.as_slice(), }; let timeout = if let Some(timeout) = pq.timeout_at { timeout } else { let v = cx.now() + RETRANSMIT_TIMEOUT; pq.timeout_at = Some(v); v }; // Check timeout if timeout < cx.now() { // DNS timeout pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT); pq.retransmit_at = Instant::ZERO; pq.delay = RETRANSMIT_DELAY; // Try next server. We check below whether we've tried all servers. pq.server_idx += 1; } // Check if we've run out of servers to try. if pq.server_idx >= servers.len() { net_trace!("already tried all servers."); q.set_state(State::Failure); continue; } // Check so the IP address is valid if servers[pq.server_idx].is_unspecified() { net_trace!("invalid unspecified DNS server addr."); q.set_state(State::Failure); continue; } if pq.retransmit_at > cx.now() { // query is waiting for retransmit continue; } let repr = Repr { transaction_id: pq.txid, flags: Flags::RECURSION_DESIRED, opcode: Opcode::Query, question: Question { name: &pq.name, type_: pq.type_, }, }; let mut payload = [0u8; 512]; let payload = &mut payload[..repr.buffer_len()]; repr.emit(&mut Packet::new_unchecked(payload)); let dst_port = match pq.mdns { #[cfg(feature = "socket-mdns")] MulticastDns::Enabled => MDNS_DNS_PORT, MulticastDns::Disabled => DNS_PORT, }; let udp_repr = UdpRepr { src_port: pq.port, dst_port, }; let dst_addr = servers[pq.server_idx]; let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap let ip_repr = IpRepr::new( src_addr, dst_addr, IpProtocol::Udp, udp_repr.header_len() + payload.len(), hop_limit, ); net_trace!( "sending {} octets to {} from port {}", payload.len(), ip_repr.dst_addr(), udp_repr.src_port ); emit(cx, (ip_repr, udp_repr, payload))?; pq.retransmit_at = cx.now() + pq.delay; pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2); return Ok(()); } } // Nothing to dispatch Ok(()) } pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { self.queries .iter() .flatten() .filter_map(|q| match &q.state { State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)), State::Completed(_) => None, State::Failure => None, }) .min() .unwrap_or(PollAt::Ingress) } } fn eq_names<'a>( mut a: impl Iterator>, mut b: impl Iterator>, ) -> wire::Result { loop { match (a.next(), b.next()) { // Handle errors (Some(Err(e)), _) => return Err(e), (_, Some(Err(e))) => return Err(e), // Both finished -> equal (None, None) => return Ok(true), // One finished before the other -> not equal (None, _) => return Ok(false), (_, None) => return Ok(false), // Got two labels, check if they're equal (Some(Ok(la)), Some(Ok(lb))) => { if la != lb { return Ok(false); } } } } } fn copy_name<'a, const N: usize>( dest: &mut Vec, name: impl Iterator>, ) -> Result<(), wire::Error> { dest.truncate(0); for label in name { let label = label?; dest.push(label.len() as u8).map_err(|_| wire::Error)?; dest.extend_from_slice(label).map_err(|_| wire::Error)?; } // Write terminator 0x00 dest.push(0).map_err(|_| wire::Error)?; Ok(()) }