|  | @@ -0,0 +1,411 @@
 | 
	
		
			
				|  |  | +#![allow(dead_code, unused)]
 | 
	
		
			
				|  |  | +use heapless::Vec;
 | 
	
		
			
				|  |  | +use managed::ManagedSlice;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +use crate::socket::{Context, PollAt, Socket};
 | 
	
		
			
				|  |  | +use crate::time::{Duration, Instant};
 | 
	
		
			
				|  |  | +use crate::wire::dns::{Flags, Opcode, Packet, Question, Record, RecordData, Repr, Type};
 | 
	
		
			
				|  |  | +use crate::wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr, Ipv4Address, UdpRepr};
 | 
	
		
			
				|  |  | +use crate::{rand, Error, Result};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +const DNS_PORT: u16 = 53;
 | 
	
		
			
				|  |  | +const MAX_NAME_LEN: usize = 255;
 | 
	
		
			
				|  |  | +const MAX_ADDRESS_COUNT: usize = 4;
 | 
	
		
			
				|  |  | +const RETRANSMIT_DELAY: Duration = Duration::from_millis(1000);
 | 
	
		
			
				|  |  | +const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10000);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/// State for an in-progress DNS query.
 | 
	
		
			
				|  |  | +///
 | 
	
		
			
				|  |  | +/// The only reason this struct is public is to allow the socket state
 | 
	
		
			
				|  |  | +/// to be allocated externally.
 | 
	
		
			
				|  |  | +#[derive(Debug)]
 | 
	
		
			
				|  |  | +pub struct DnsQuery {
 | 
	
		
			
				|  |  | +    state: State,
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#[derive(Debug)]
 | 
	
		
			
				|  |  | +#[allow(clippy::large_enum_variant)]
 | 
	
		
			
				|  |  | +enum State {
 | 
	
		
			
				|  |  | +    Pending(PendingQuery),
 | 
	
		
			
				|  |  | +    Completed(CompletedQuery),
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#[derive(Debug)]
 | 
	
		
			
				|  |  | +struct PendingQuery {
 | 
	
		
			
				|  |  | +    name: Vec<u8, MAX_NAME_LEN>,
 | 
	
		
			
				|  |  | +    type_: Type,
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    port: u16, // UDP port (src for request, dst for response)
 | 
	
		
			
				|  |  | +    txid: u16, // transaction ID
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    retransmit_at: Instant,
 | 
	
		
			
				|  |  | +    delay: Duration,
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#[derive(Debug)]
 | 
	
		
			
				|  |  | +struct CompletedQuery {
 | 
	
		
			
				|  |  | +    addresses: Vec<IpAddress, MAX_ADDRESS_COUNT>,
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/// A handle to an in-progress DNS query.
 | 
	
		
			
				|  |  | +#[derive(Clone, Copy)]
 | 
	
		
			
				|  |  | +pub struct QueryHandle(usize);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/// A Domain Name System socket.
 | 
	
		
			
				|  |  | +///
 | 
	
		
			
				|  |  | +/// A UDP socket is bound to a specific endpoint, and owns transmit and receive
 | 
	
		
			
				|  |  | +/// packet buffers.
 | 
	
		
			
				|  |  | +#[derive(Debug)]
 | 
	
		
			
				|  |  | +pub struct DnsSocket<'a> {
 | 
	
		
			
				|  |  | +    servers: ManagedSlice<'a, IpAddress>,
 | 
	
		
			
				|  |  | +    queries: ManagedSlice<'a, Option<DnsQuery>>,
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
 | 
	
		
			
				|  |  | +    hop_limit: Option<u8>,
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +impl<'a> DnsSocket<'a> {
 | 
	
		
			
				|  |  | +    /// Create a DNS socket with the given buffers.
 | 
	
		
			
				|  |  | +    pub fn new<Q, S>(servers: S, queries: Q) -> DnsSocket<'a>
 | 
	
		
			
				|  |  | +    where
 | 
	
		
			
				|  |  | +        S: Into<ManagedSlice<'a, IpAddress>>,
 | 
	
		
			
				|  |  | +        Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
 | 
	
		
			
				|  |  | +    {
 | 
	
		
			
				|  |  | +        DnsSocket {
 | 
	
		
			
				|  |  | +            servers: servers.into(),
 | 
	
		
			
				|  |  | +            queries: queries.into(),
 | 
	
		
			
				|  |  | +            hop_limit: None,
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
 | 
	
		
			
				|  |  | +    ///
 | 
	
		
			
				|  |  | +    /// See also the [set_hop_limit](#method.set_hop_limit) method
 | 
	
		
			
				|  |  | +    pub fn hop_limit(&self) -> Option<u8> {
 | 
	
		
			
				|  |  | +        self.hop_limit
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
 | 
	
		
			
				|  |  | +    ///
 | 
	
		
			
				|  |  | +    /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
 | 
	
		
			
				|  |  | +    /// value (64).
 | 
	
		
			
				|  |  | +    ///
 | 
	
		
			
				|  |  | +    /// # Panics
 | 
	
		
			
				|  |  | +    ///
 | 
	
		
			
				|  |  | +    /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
 | 
	
		
			
				|  |  | +    ///
 | 
	
		
			
				|  |  | +    /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
 | 
	
		
			
				|  |  | +    /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
 | 
	
		
			
				|  |  | +    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
 | 
	
		
			
				|  |  | +        // A host MUST NOT send a datagram with a hop limit value of 0
 | 
	
		
			
				|  |  | +        if let Some(0) = hop_limit {
 | 
	
		
			
				|  |  | +            panic!("the time-to-live value of a packet must not be zero")
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.hop_limit = hop_limit
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    fn find_free_query(&mut self) -> Result<QueryHandle> {
 | 
	
		
			
				|  |  | +        for (i, q) in self.queries.iter().enumerate() {
 | 
	
		
			
				|  |  | +            if q.is_none() {
 | 
	
		
			
				|  |  | +                return Ok(QueryHandle(i));
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        match self.queries {
 | 
	
		
			
				|  |  | +            ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
 | 
	
		
			
				|  |  | +            #[cfg(any(feature = "std", feature = "alloc"))]
 | 
	
		
			
				|  |  | +            ManagedSlice::Owned(ref mut queries) => {
 | 
	
		
			
				|  |  | +                queries.push(None);
 | 
	
		
			
				|  |  | +                let index = queries.len() - 1;
 | 
	
		
			
				|  |  | +                Ok(QueryHandle(index))
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub fn start_query(&mut self, cx: &mut Context, name: &[u8]) -> Result<QueryHandle> {
 | 
	
		
			
				|  |  | +        let handle = self.find_free_query()?;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.queries[handle.0] = Some(DnsQuery {
 | 
	
		
			
				|  |  | +            state: State::Pending(PendingQuery {
 | 
	
		
			
				|  |  | +                name: Vec::from_slice(name).map_err(|_| Error::Truncated)?,
 | 
	
		
			
				|  |  | +                type_: Type::A,
 | 
	
		
			
				|  |  | +                txid: cx.rand().rand_u16(),
 | 
	
		
			
				|  |  | +                port: cx.rand().rand_source_port(),
 | 
	
		
			
				|  |  | +                delay: RETRANSMIT_DELAY,
 | 
	
		
			
				|  |  | +                retransmit_at: Instant::ZERO,
 | 
	
		
			
				|  |  | +            }),
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +        Ok(handle)
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub fn get_query_result(
 | 
	
		
			
				|  |  | +        &mut self,
 | 
	
		
			
				|  |  | +        handle: QueryHandle,
 | 
	
		
			
				|  |  | +    ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
 | 
	
		
			
				|  |  | +        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
 | 
	
		
			
				|  |  | +        let q = slot.as_mut().ok_or(Error::Illegal)?;
 | 
	
		
			
				|  |  | +        match &mut q.state {
 | 
	
		
			
				|  |  | +            // Query is not done yet.
 | 
	
		
			
				|  |  | +            State::Pending(_) => Err(Error::Exhausted),
 | 
	
		
			
				|  |  | +            // Query is done
 | 
	
		
			
				|  |  | +            State::Completed(q) => {
 | 
	
		
			
				|  |  | +                let res = q.addresses.clone();
 | 
	
		
			
				|  |  | +                *slot = None; // Free up the slot for recycling.
 | 
	
		
			
				|  |  | +                Ok(res)
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
 | 
	
		
			
				|  |  | +        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
 | 
	
		
			
				|  |  | +        let q = slot.as_mut().ok_or(Error::Illegal)?;
 | 
	
		
			
				|  |  | +        *slot = None; // Free up the slot for recycling.
 | 
	
		
			
				|  |  | +        Ok(())
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
 | 
	
		
			
				|  |  | +        udp_repr.src_port == DNS_PORT
 | 
	
		
			
				|  |  | +            && self
 | 
	
		
			
				|  |  | +                .servers
 | 
	
		
			
				|  |  | +                .iter()
 | 
	
		
			
				|  |  | +                .any(|server| *server == ip_repr.src_addr())
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub(crate) fn process(
 | 
	
		
			
				|  |  | +        &mut self,
 | 
	
		
			
				|  |  | +        cx: &mut Context,
 | 
	
		
			
				|  |  | +        ip_repr: &IpRepr,
 | 
	
		
			
				|  |  | +        udp_repr: &UdpRepr,
 | 
	
		
			
				|  |  | +        payload: &[u8],
 | 
	
		
			
				|  |  | +    ) -> Result<()> {
 | 
	
		
			
				|  |  | +        debug_assert!(self.accepts(ip_repr, udp_repr));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        let size = payload.len();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        net_trace!(
 | 
	
		
			
				|  |  | +            "receiving {} octets from {:?}:{}",
 | 
	
		
			
				|  |  | +            size,
 | 
	
		
			
				|  |  | +            ip_repr.src_addr(),
 | 
	
		
			
				|  |  | +            udp_repr.dst_port
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        let p = Packet::new_checked(payload)?;
 | 
	
		
			
				|  |  | +        if p.opcode() != Opcode::Query {
 | 
	
		
			
				|  |  | +            net_trace!("unwanted opcode {:?}", p.opcode());
 | 
	
		
			
				|  |  | +            return Err(Error::Malformed);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if !p.flags().contains(Flags::RESPONSE) {
 | 
	
		
			
				|  |  | +            net_trace!("packet doesn't have response bit set");
 | 
	
		
			
				|  |  | +            return Err(Error::Malformed);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if p.question_count() != 1 {
 | 
	
		
			
				|  |  | +            net_trace!("bad question count {:?}", p.question_count());
 | 
	
		
			
				|  |  | +            return Err(Error::Malformed);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Find pending query
 | 
	
		
			
				|  |  | +        for q in self.queries.iter_mut().flatten() {
 | 
	
		
			
				|  |  | +            if let State::Pending(pq) = &mut q.state {
 | 
	
		
			
				|  |  | +                if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
 | 
	
		
			
				|  |  | +                    continue;
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                let payload = p.payload();
 | 
	
		
			
				|  |  | +                let (mut payload, question) = Question::parse(payload)?;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                if question.type_ != pq.type_ {
 | 
	
		
			
				|  |  | +                    net_trace!("question type mismatch");
 | 
	
		
			
				|  |  | +                    return Err(Error::Malformed);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +                if !eq_names(p.parse_name(question.name), p.parse_name(&pq.name))? {
 | 
	
		
			
				|  |  | +                    net_trace!("question name mismatch");
 | 
	
		
			
				|  |  | +                    return Err(Error::Malformed);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                let mut addresses = Vec::new();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                for _ in 0..p.answer_record_count() {
 | 
	
		
			
				|  |  | +                    let (payload2, r) = Record::parse(payload)?;
 | 
	
		
			
				|  |  | +                    payload = payload2;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                    if !eq_names(p.parse_name(r.name), p.parse_name(&pq.name))? {
 | 
	
		
			
				|  |  | +                        net_trace!("answer name mismatch: {:?}", r);
 | 
	
		
			
				|  |  | +                        continue;
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                    match r.data {
 | 
	
		
			
				|  |  | +                        RecordData::A(addr) => {
 | 
	
		
			
				|  |  | +                            net_trace!("A: {:?}", addr);
 | 
	
		
			
				|  |  | +                            if addresses.push(addr.into()).is_err() {
 | 
	
		
			
				|  |  | +                                net_trace!("too many addresses in response, ignoring {:?}", addr);
 | 
	
		
			
				|  |  | +                            }
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +                        RecordData::Aaaa(addr) => {
 | 
	
		
			
				|  |  | +                            net_trace!("AAAA: {:?}", addr);
 | 
	
		
			
				|  |  | +                            if addresses.push(addr.into()).is_err() {
 | 
	
		
			
				|  |  | +                                net_trace!("too many addresses in response, ignoring {:?}", addr);
 | 
	
		
			
				|  |  | +                            }
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +                        RecordData::Cname(name) => {
 | 
	
		
			
				|  |  | +                            net_trace!("CNAME: {:?}", name);
 | 
	
		
			
				|  |  | +                            copy_name(&mut pq.name, p.parse_name(name))?;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                            // Relaunch query with the new name.
 | 
	
		
			
				|  |  | +                            // If the server has bundled A records for the CNAME in the same packet,
 | 
	
		
			
				|  |  | +                            // we'll process them in next iterations, and cancel the query relaunch.
 | 
	
		
			
				|  |  | +                            pq.retransmit_at = Instant::ZERO;
 | 
	
		
			
				|  |  | +                            pq.delay = RETRANSMIT_DELAY;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                            pq.txid = cx.rand().rand_u16();
 | 
	
		
			
				|  |  | +                            pq.port = cx.rand().rand_source_port();
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +                        RecordData::Other(type_, data) => {
 | 
	
		
			
				|  |  | +                            net_trace!("unknown: {:?} {:?}", type_, data)
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                if !addresses.is_empty() {
 | 
	
		
			
				|  |  | +                    q.state = State::Completed(CompletedQuery { addresses })
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +                // If we get here, packet matched the current query, stop processing.
 | 
	
		
			
				|  |  | +                return Ok(());
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // If we get here, packet matched with no query.
 | 
	
		
			
				|  |  | +        net_trace!("no query matched");
 | 
	
		
			
				|  |  | +        Ok(())
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
 | 
	
		
			
				|  |  | +    where
 | 
	
		
			
				|  |  | +        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
 | 
	
		
			
				|  |  | +    {
 | 
	
		
			
				|  |  | +        let hop_limit = self.hop_limit.unwrap_or(64);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for q in self.queries.iter_mut().flatten() {
 | 
	
		
			
				|  |  | +            if let State::Pending(pq) = &mut q.state {
 | 
	
		
			
				|  |  | +                if pq.retransmit_at > cx.now() {
 | 
	
		
			
				|  |  | +                    // query is waiting for retransmit
 | 
	
		
			
				|  |  | +                    continue;
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                let repr = Repr {
 | 
	
		
			
				|  |  | +                    transaction_id: pq.txid,
 | 
	
		
			
				|  |  | +                    flags: Flags::RECURSION_DESIRED,
 | 
	
		
			
				|  |  | +                    opcode: Opcode::Query,
 | 
	
		
			
				|  |  | +                    question: Question {
 | 
	
		
			
				|  |  | +                        name: &pq.name,
 | 
	
		
			
				|  |  | +                        type_: Type::A,
 | 
	
		
			
				|  |  | +                    },
 | 
	
		
			
				|  |  | +                };
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                let mut payload = [0u8; 512];
 | 
	
		
			
				|  |  | +                let payload = &mut payload[..repr.buffer_len()];
 | 
	
		
			
				|  |  | +                repr.emit(&mut Packet::new_unchecked(payload));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                let udp_repr = UdpRepr {
 | 
	
		
			
				|  |  | +                    src_port: pq.port,
 | 
	
		
			
				|  |  | +                    dst_port: 53,
 | 
	
		
			
				|  |  | +                };
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                let dst_addr = self.servers[0];
 | 
	
		
			
				|  |  | +                let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
 | 
	
		
			
				|  |  | +                let ip_repr = IpRepr::new(
 | 
	
		
			
				|  |  | +                    src_addr,
 | 
	
		
			
				|  |  | +                    dst_addr,
 | 
	
		
			
				|  |  | +                    IpProtocol::Udp,
 | 
	
		
			
				|  |  | +                    udp_repr.header_len() + payload.len(),
 | 
	
		
			
				|  |  | +                    hop_limit,
 | 
	
		
			
				|  |  | +                );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                net_trace!(
 | 
	
		
			
				|  |  | +                    "sending {} octets to {:?}:{}",
 | 
	
		
			
				|  |  | +                    payload.len(),
 | 
	
		
			
				|  |  | +                    ip_repr.dst_addr(),
 | 
	
		
			
				|  |  | +                    udp_repr.src_port
 | 
	
		
			
				|  |  | +                );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
 | 
	
		
			
				|  |  | +                    net_trace!("DNS emit error {:?}", e);
 | 
	
		
			
				|  |  | +                    return Ok(());
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                pq.retransmit_at = cx.now() + pq.delay;
 | 
	
		
			
				|  |  | +                pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                return Ok(());
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Nothing to dispatch
 | 
	
		
			
				|  |  | +        Err(Error::Exhausted)
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
 | 
	
		
			
				|  |  | +        self.queries
 | 
	
		
			
				|  |  | +            .iter()
 | 
	
		
			
				|  |  | +            .flatten()
 | 
	
		
			
				|  |  | +            .filter_map(|q| match &q.state {
 | 
	
		
			
				|  |  | +                State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
 | 
	
		
			
				|  |  | +                State::Completed(_) => None,
 | 
	
		
			
				|  |  | +            })
 | 
	
		
			
				|  |  | +            .min()
 | 
	
		
			
				|  |  | +            .unwrap_or(PollAt::Ingress)
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +impl<'a> From<DnsSocket<'a>> for Socket<'a> {
 | 
	
		
			
				|  |  | +    fn from(val: DnsSocket<'a>) -> Self {
 | 
	
		
			
				|  |  | +        Socket::Dns(val)
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +fn eq_names<'a>(
 | 
	
		
			
				|  |  | +    mut a: impl Iterator<Item = Result<&'a [u8]>>,
 | 
	
		
			
				|  |  | +    mut b: impl Iterator<Item = Result<&'a [u8]>>,
 | 
	
		
			
				|  |  | +) -> Result<bool> {
 | 
	
		
			
				|  |  | +    loop {
 | 
	
		
			
				|  |  | +        match (a.next(), b.next()) {
 | 
	
		
			
				|  |  | +            // Handle errors
 | 
	
		
			
				|  |  | +            (Some(Err(e)), _) => return Err(e),
 | 
	
		
			
				|  |  | +            (_, Some(Err(e))) => return Err(e),
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // Both finished -> equal
 | 
	
		
			
				|  |  | +            (None, None) => return Ok(true),
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // One finished before the other -> not equal
 | 
	
		
			
				|  |  | +            (None, _) => return Ok(false),
 | 
	
		
			
				|  |  | +            (_, None) => return Ok(false),
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // Got two labels, check if they're equal
 | 
	
		
			
				|  |  | +            (Some(Ok(la)), Some(Ok(lb))) => {
 | 
	
		
			
				|  |  | +                if la != lb {
 | 
	
		
			
				|  |  | +                    return Ok(false);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +fn copy_name<'a, const N: usize>(
 | 
	
		
			
				|  |  | +    dest: &mut Vec<u8, N>,
 | 
	
		
			
				|  |  | +    name: impl Iterator<Item = Result<&'a [u8]>>,
 | 
	
		
			
				|  |  | +) -> Result<()> {
 | 
	
		
			
				|  |  | +    dest.truncate(0);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for label in name {
 | 
	
		
			
				|  |  | +        let label = label?;
 | 
	
		
			
				|  |  | +        dest.push(label.len() as u8).map_err(|_| Error::Truncated);
 | 
	
		
			
				|  |  | +        dest.extend_from_slice(label).map_err(|_| Error::Truncated);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    // Write terminator 0x00
 | 
	
		
			
				|  |  | +    dest.push(0).map_err(|_| Error::Truncated);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    Ok(())
 | 
	
		
			
				|  |  | +}
 |