123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699 |
- #[cfg(feature = "async")]
- use core::task::Waker;
- use heapless::Vec;
- use managed::ManagedSlice;
- use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
- use crate::socket::{Context, PollAt};
- use crate::time::{Duration, Instant};
- use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
- use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
- #[cfg(feature = "async")]
- use super::WakerRegistration;
/// UDP port of unicast DNS servers; destination port of regular queries.
const DNS_PORT: u16 = 53;
/// UDP port used for multicast DNS; destination port of mDNS queries.
const MDNS_DNS_PORT: u16 = 5353;
/// Delay before the first retransmission of a query. It is doubled after
/// every transmission, capped at `MAX_RETRANSMIT_DELAY`.
const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
/// Upper bound of the exponential retransmission backoff.
const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
/// How long a single server is tried before the query moves on to the next
/// server in the list (or fails when none are left).
const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
/// IPv6 multicast group used as the "server" for mDNS queries (`ff02::fb`).
#[cfg(feature = "proto-ipv6")]
const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
    0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
]));
/// IPv4 multicast group used as the "server" for mDNS queries (`224.0.0.251`).
#[cfg(feature = "proto-ipv4")]
const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
/// Error returned by [`Socket::start_query`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StartQueryError {
    /// Every query slot is in use and the slot storage cannot grow.
    NoFreeSlot,
    /// The name is empty, contains an empty label, or has a label longer
    /// than 63 bytes.
    InvalidName,
    /// The wire-format name does not fit in `DNS_MAX_NAME_SIZE` bytes.
    NameTooLong,
}

impl core::fmt::Display for StartQueryError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::NoFreeSlot => "No free slot",
            Self::InvalidName => "Invalid name",
            Self::NameTooLong => "Name too long",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for StartQueryError {}
/// Error returned by [`Socket::get_query_result`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GetQueryResultError {
    /// Query is not done yet.
    Pending,
    /// Query failed.
    Failed,
}

impl core::fmt::Display for GetQueryResultError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::Pending => "Query is not done yet",
            Self::Failed => "Query failed",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for GetQueryResultError {}
- /// State for an in-progress DNS query.
- ///
- /// The only reason this struct is public is to allow the socket state
- /// to be allocated externally.
- #[derive(Debug)]
- pub struct DnsQuery {
- state: State,
- #[cfg(feature = "async")]
- waker: WakerRegistration,
- }
- impl DnsQuery {
- fn set_state(&mut self, state: State) {
- self.state = state;
- #[cfg(feature = "async")]
- self.waker.wake();
- }
- }
/// Lifecycle of a query stored in a slot.
#[derive(Debug)]
// `PendingQuery` stores the whole name buffer (a fixed-capacity Vec of
// DNS_MAX_NAME_SIZE bytes) inline, so the variants differ a lot in size.
// NOTE(review): presumably the lint is silenced because boxing the variant
// would require an allocator — confirm.
#[allow(clippy::large_enum_variant)]
enum State {
    /// Still being sent/retransmitted and waiting for a matching response.
    Pending(PendingQuery),
    /// A matching response with at least one address was received.
    Completed(CompletedQuery),
    /// The query failed (NXDomain, no usable addresses, or all servers tried).
    Failure,
}

/// Bookkeeping for a query that has not finished yet.
#[derive(Debug)]
struct PendingQuery {
    /// Wire-format name being resolved (length-prefixed labels, 0-terminated).
    name: Vec<u8, DNS_MAX_NAME_SIZE>,
    /// Record type requested in the question.
    type_: Type,
    port: u16, // UDP port (src for request, dst for response)
    txid: u16, // transaction ID
    /// When the current server is given up on; `None` until the first dispatch.
    timeout_at: Option<Instant>,
    /// Next (re)transmission time; `Instant::ZERO` means "send as soon as possible".
    retransmit_at: Instant,
    /// Current retransmission backoff, doubled after every send.
    delay: Duration,
    /// Index of the server currently being tried.
    server_idx: usize,
    /// Whether the query is sent as multicast DNS.
    mdns: MulticastDns,
}
/// Selects how a query is transmitted.
///
/// Passed to [`Socket::start_query_raw`]; [`Socket::start_query`] picks the
/// mode automatically from the name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MulticastDns {
    /// Regular unicast DNS: the query is sent to the socket's configured
    /// servers on UDP port 53.
    Disabled,
    /// Multicast DNS: the query is sent to the mDNS multicast group(s) on
    /// UDP port 5353 instead of the configured servers.
    #[cfg(feature = "socket-mdns")]
    Enabled,
}
/// Result of a successfully answered query.
#[derive(Debug)]
struct CompletedQuery {
    /// A/AAAA addresses found in the answer section (at most DNS_MAX_RESULT_COUNT;
    /// extra addresses are dropped).
    addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
}
/// A handle to an in-progress DNS query.
///
/// Returned by [`Socket::start_query`] / [`Socket::start_query_raw`]. It stays
/// valid until the query's slot is freed, either by
/// [`Socket::get_query_result`] returning a final result or by
/// [`Socket::cancel_query`]. After that the slot may be reused by a new query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHandle(usize);
- /// A Domain Name System socket.
- ///
- /// A UDP socket is bound to a specific endpoint, and owns transmit and receive
- /// packet buffers.
- #[derive(Debug)]
- pub struct Socket<'a> {
- servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
- queries: ManagedSlice<'a, Option<DnsQuery>>,
- /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
- hop_limit: Option<u8>,
- }
- impl<'a> Socket<'a> {
- /// Create a DNS socket.
- ///
- /// # Panics
- ///
- /// Panics if `servers.len() > MAX_SERVER_COUNT`
- pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
- where
- Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
- {
- Socket {
- servers: Vec::from_slice(servers).unwrap(),
- queries: queries.into(),
- hop_limit: None,
- }
- }
- /// Update the list of DNS servers, will replace all existing servers
- ///
- /// # Panics
- ///
- /// Panics if `servers.len() > MAX_SERVER_COUNT`
- pub fn update_servers(&mut self, servers: &[IpAddress]) {
- self.servers = Vec::from_slice(servers).unwrap();
- }
- /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
- ///
- /// See also the [set_hop_limit](#method.set_hop_limit) method
- pub fn hop_limit(&self) -> Option<u8> {
- self.hop_limit
- }
- /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
- ///
- /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
- /// value (64).
- ///
- /// # Panics
- ///
- /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
- ///
- /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
- /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
- pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
- // A host MUST NOT send a datagram with a hop limit value of 0
- if let Some(0) = hop_limit {
- panic!("the time-to-live value of a packet must not be zero")
- }
- self.hop_limit = hop_limit
- }
- fn find_free_query(&mut self) -> Option<QueryHandle> {
- for (i, q) in self.queries.iter().enumerate() {
- if q.is_none() {
- return Some(QueryHandle(i));
- }
- }
- match &mut self.queries {
- ManagedSlice::Borrowed(_) => None,
- #[cfg(feature = "alloc")]
- ManagedSlice::Owned(queries) => {
- queries.push(None);
- let index = queries.len() - 1;
- Some(QueryHandle(index))
- }
- }
- }
- /// Start a query.
- ///
- /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
- /// It accepts names both with and without trailing dot, and they're treated
- /// the same (there's no support for DNS search path).
- pub fn start_query(
- &mut self,
- cx: &mut Context,
- name: &str,
- query_type: Type,
- ) -> Result<QueryHandle, StartQueryError> {
- let mut name = name.as_bytes();
- if name.is_empty() {
- net_trace!("invalid name: zero length");
- return Err(StartQueryError::InvalidName);
- }
- // Remove trailing dot, if any
- if name[name.len() - 1] == b'.' {
- name = &name[..name.len() - 1];
- }
- let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
- let mut mdns = MulticastDns::Disabled;
- #[cfg(feature = "socket-mdns")]
- if name.split(|&c| c == b'.').last().unwrap() == b"local" {
- net_trace!("Starting a mDNS query");
- mdns = MulticastDns::Enabled;
- }
- for s in name.split(|&c| c == b'.') {
- if s.len() > 63 {
- net_trace!("invalid name: too long label");
- return Err(StartQueryError::InvalidName);
- }
- if s.is_empty() {
- net_trace!("invalid name: zero length label");
- return Err(StartQueryError::InvalidName);
- }
- // Push label
- raw_name
- .push(s.len() as u8)
- .map_err(|_| StartQueryError::NameTooLong)?;
- raw_name
- .extend_from_slice(s)
- .map_err(|_| StartQueryError::NameTooLong)?;
- }
- // Push terminator.
- raw_name
- .push(0x00)
- .map_err(|_| StartQueryError::NameTooLong)?;
- self.start_query_raw(cx, &raw_name, query_type, mdns)
- }
- /// Start a query with a raw (wire-format) DNS name.
- /// `b"\x09rust-lang\x03org\x00"`
- ///
- /// You probably want to use [`start_query`] instead.
- pub fn start_query_raw(
- &mut self,
- cx: &mut Context,
- raw_name: &[u8],
- query_type: Type,
- mdns: MulticastDns,
- ) -> Result<QueryHandle, StartQueryError> {
- let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
- self.queries[handle.0] = Some(DnsQuery {
- state: State::Pending(PendingQuery {
- name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
- type_: query_type,
- txid: cx.rand().rand_u16(),
- port: cx.rand().rand_source_port(),
- delay: RETRANSMIT_DELAY,
- timeout_at: None,
- retransmit_at: Instant::ZERO,
- server_idx: 0,
- mdns,
- }),
- #[cfg(feature = "async")]
- waker: WakerRegistration::new(),
- });
- Ok(handle)
- }
- /// Get the result of a query.
- ///
- /// If the query is completed, the query slot is automatically freed.
- ///
- /// # Panics
- /// Panics if the QueryHandle corresponds to a free slot.
- pub fn get_query_result(
- &mut self,
- handle: QueryHandle,
- ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
- let slot = &mut self.queries[handle.0];
- let q = slot.as_mut().unwrap();
- match &mut q.state {
- // Query is not done yet.
- State::Pending(_) => Err(GetQueryResultError::Pending),
- // Query is done
- State::Completed(q) => {
- let res = q.addresses.clone();
- *slot = None; // Free up the slot for recycling.
- Ok(res)
- }
- State::Failure => {
- *slot = None; // Free up the slot for recycling.
- Err(GetQueryResultError::Failed)
- }
- }
- }
- /// Cancels a query, freeing the slot.
- ///
- /// # Panics
- ///
- /// Panics if the QueryHandle corresponds to an already free slot.
- pub fn cancel_query(&mut self, handle: QueryHandle) {
- let slot = &mut self.queries[handle.0];
- if slot.is_none() {
- panic!("Canceling query in a free slot.")
- }
- *slot = None; // Free up the slot for recycling.
- }
- /// Assign a waker to a query slot
- ///
- /// The waker will be woken when the query completes, either successfully or failed.
- ///
- /// # Panics
- ///
- /// Panics if the QueryHandle corresponds to an already free slot.
- #[cfg(feature = "async")]
- pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
- self.queries[handle.0]
- .as_mut()
- .unwrap()
- .waker
- .register(waker);
- }
- pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
- (udp_repr.src_port == DNS_PORT
- && self
- .servers
- .iter()
- .any(|server| *server == ip_repr.src_addr()))
- || (udp_repr.src_port == MDNS_DNS_PORT)
- }
- pub(crate) fn process(
- &mut self,
- _cx: &mut Context,
- ip_repr: &IpRepr,
- udp_repr: &UdpRepr,
- payload: &[u8],
- ) {
- debug_assert!(self.accepts(ip_repr, udp_repr));
- let size = payload.len();
- net_trace!(
- "receiving {} octets from {:?}:{}",
- size,
- ip_repr.src_addr(),
- udp_repr.dst_port
- );
- let p = match Packet::new_checked(payload) {
- Ok(x) => x,
- Err(_) => {
- net_trace!("dns packet malformed");
- return;
- }
- };
- if p.opcode() != Opcode::Query {
- net_trace!("unwanted opcode {:?}", p.opcode());
- return;
- }
- if !p.flags().contains(Flags::RESPONSE) {
- net_trace!("packet doesn't have response bit set");
- return;
- }
- if p.question_count() != 1 {
- net_trace!("bad question count {:?}", p.question_count());
- return;
- }
- // Find pending query
- for q in self.queries.iter_mut().flatten() {
- if let State::Pending(pq) = &mut q.state {
- if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
- continue;
- }
- if p.rcode() == Rcode::NXDomain {
- net_trace!("rcode NXDomain");
- q.set_state(State::Failure);
- continue;
- }
- let payload = p.payload();
- let (mut payload, question) = match Question::parse(payload) {
- Ok(x) => x,
- Err(_) => {
- net_trace!("question malformed");
- return;
- }
- };
- if question.type_ != pq.type_ {
- net_trace!("question type mismatch");
- return;
- }
- match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
- Ok(true) => {}
- Ok(false) => {
- net_trace!("question name mismatch");
- return;
- }
- Err(_) => {
- net_trace!("dns question name malformed");
- return;
- }
- }
- let mut addresses = Vec::new();
- for _ in 0..p.answer_record_count() {
- let (payload2, r) = match Record::parse(payload) {
- Ok(x) => x,
- Err(_) => {
- net_trace!("dns answer record malformed");
- return;
- }
- };
- payload = payload2;
- match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
- Ok(true) => {}
- Ok(false) => {
- net_trace!("answer name mismatch: {:?}", r);
- continue;
- }
- Err(_) => {
- net_trace!("dns answer record name malformed");
- return;
- }
- }
- match r.data {
- #[cfg(feature = "proto-ipv4")]
- RecordData::A(addr) => {
- net_trace!("A: {:?}", addr);
- if addresses.push(addr.into()).is_err() {
- net_trace!("too many addresses in response, ignoring {:?}", addr);
- }
- }
- #[cfg(feature = "proto-ipv6")]
- RecordData::Aaaa(addr) => {
- net_trace!("AAAA: {:?}", addr);
- if addresses.push(addr.into()).is_err() {
- net_trace!("too many addresses in response, ignoring {:?}", addr);
- }
- }
- RecordData::Cname(name) => {
- net_trace!("CNAME: {:?}", name);
- // When faced with a CNAME, recursive resolvers are supposed to
- // resolve the CNAME and append the results for it.
- //
- // We update the query with the new name, so that we pick up the A/AAAA
- // records for the CNAME when we parse them later.
- // I believe it's mandatory the CNAME results MUST come *after* in the
- // packet, so it's enough to do one linear pass over it.
- if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
- net_trace!("dns answer cname malformed");
- return;
- }
- }
- RecordData::Other(type_, data) => {
- net_trace!("unknown: {:?} {:?}", type_, data)
- }
- }
- }
- q.set_state(if addresses.is_empty() {
- State::Failure
- } else {
- State::Completed(CompletedQuery { addresses })
- });
- // If we get here, packet matched the current query, stop processing.
- return;
- }
- }
- // If we get here, packet matched with no query.
- net_trace!("no query matched");
- }
- pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
- where
- F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
- {
- let hop_limit = self.hop_limit.unwrap_or(64);
- for q in self.queries.iter_mut().flatten() {
- if let State::Pending(pq) = &mut q.state {
- // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
- // so we internally overwrite the servers for any of those queries
- // in this function.
- let servers = match pq.mdns {
- #[cfg(feature = "socket-mdns")]
- MulticastDns::Enabled => &[
- #[cfg(feature = "proto-ipv6")]
- MDNS_IPV6_ADDR,
- #[cfg(feature = "proto-ipv4")]
- MDNS_IPV4_ADDR,
- ],
- MulticastDns::Disabled => self.servers.as_slice(),
- };
- let timeout = if let Some(timeout) = pq.timeout_at {
- timeout
- } else {
- let v = cx.now() + RETRANSMIT_TIMEOUT;
- pq.timeout_at = Some(v);
- v
- };
- // Check timeout
- if timeout < cx.now() {
- // DNS timeout
- pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
- pq.retransmit_at = Instant::ZERO;
- pq.delay = RETRANSMIT_DELAY;
- // Try next server. We check below whether we've tried all servers.
- pq.server_idx += 1;
- }
- // Check if we've run out of servers to try.
- if pq.server_idx >= servers.len() {
- net_trace!("already tried all servers.");
- q.set_state(State::Failure);
- continue;
- }
- // Check so the IP address is valid
- if servers[pq.server_idx].is_unspecified() {
- net_trace!("invalid unspecified DNS server addr.");
- q.set_state(State::Failure);
- continue;
- }
- if pq.retransmit_at > cx.now() {
- // query is waiting for retransmit
- continue;
- }
- let repr = Repr {
- transaction_id: pq.txid,
- flags: Flags::RECURSION_DESIRED,
- opcode: Opcode::Query,
- question: Question {
- name: &pq.name,
- type_: pq.type_,
- },
- };
- let mut payload = [0u8; 512];
- let payload = &mut payload[..repr.buffer_len()];
- repr.emit(&mut Packet::new_unchecked(payload));
- let dst_port = match pq.mdns {
- #[cfg(feature = "socket-mdns")]
- MulticastDns::Enabled => MDNS_DNS_PORT,
- MulticastDns::Disabled => DNS_PORT,
- };
- let udp_repr = UdpRepr {
- src_port: pq.port,
- dst_port,
- };
- let dst_addr = servers[pq.server_idx];
- let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
- let ip_repr = IpRepr::new(
- src_addr,
- dst_addr,
- IpProtocol::Udp,
- udp_repr.header_len() + payload.len(),
- hop_limit,
- );
- net_trace!(
- "sending {} octets to {} from port {}",
- payload.len(),
- ip_repr.dst_addr(),
- udp_repr.src_port
- );
- emit(cx, (ip_repr, udp_repr, payload))?;
- pq.retransmit_at = cx.now() + pq.delay;
- pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
- return Ok(());
- }
- }
- // Nothing to dispatch
- Ok(())
- }
- pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
- self.queries
- .iter()
- .flatten()
- .filter_map(|q| match &q.state {
- State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
- State::Completed(_) => None,
- State::Failure => None,
- })
- .min()
- .unwrap_or(PollAt::Ingress)
- }
- }
- fn eq_names<'a>(
- mut a: impl Iterator<Item = wire::Result<&'a [u8]>>,
- mut b: impl Iterator<Item = wire::Result<&'a [u8]>>,
- ) -> wire::Result<bool> {
- loop {
- match (a.next(), b.next()) {
- // Handle errors
- (Some(Err(e)), _) => return Err(e),
- (_, Some(Err(e))) => return Err(e),
- // Both finished -> equal
- (None, None) => return Ok(true),
- // One finished before the other -> not equal
- (None, _) => return Ok(false),
- (_, None) => return Ok(false),
- // Got two labels, check if they're equal
- (Some(Ok(la)), Some(Ok(lb))) => {
- if la != lb {
- return Ok(false);
- }
- }
- }
- }
- }
- fn copy_name<'a, const N: usize>(
- dest: &mut Vec<u8, N>,
- name: impl Iterator<Item = wire::Result<&'a [u8]>>,
- ) -> Result<(), wire::Error> {
- dest.truncate(0);
- for label in name {
- let label = label?;
- dest.push(label.len() as u8).map_err(|_| wire::Error)?;
- dest.extend_from_slice(label).map_err(|_| wire::Error)?;
- }
- // Write terminator 0x00
- dest.push(0).map_err(|_| wire::Error)?;
- Ok(())
- }
|