| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411 | #![allow(dead_code, unused)]use heapless::Vec;use managed::ManagedSlice;use crate::socket::{Context, PollAt, Socket};use crate::time::{Duration, Instant};use crate::wire::dns::{Flags, Opcode, Packet, Question, Record, RecordData, Repr, Type};use crate::wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr, Ipv4Address, UdpRepr};use crate::{rand, Error, Result};const DNS_PORT: u16 = 53;const MAX_NAME_LEN: usize = 255;const MAX_ADDRESS_COUNT: usize = 4;const RETRANSMIT_DELAY: Duration = Duration::from_millis(1000);const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10000);/// State for an in-progress DNS query.////// The only reason this struct is public is to allow the socket state/// to be allocated externally.#[derive(Debug)]pub struct DnsQuery {    state: State,}#[derive(Debug)]#[allow(clippy::large_enum_variant)]enum State {    Pending(PendingQuery),    Completed(CompletedQuery),}#[derive(Debug)]struct PendingQuery {    name: Vec<u8, MAX_NAME_LEN>,    type_: Type,    port: u16, // UDP port (src for request, dst for response)    txid: u16, // transaction ID    retransmit_at: Instant,    delay: Duration,}#[derive(Debug)]struct CompletedQuery {    addresses: Vec<IpAddress, MAX_ADDRESS_COUNT>,}/// A handle to an in-progress DNS query.#[derive(Clone, Copy)]pub struct QueryHandle(usize);/// A Domain Name System socket.////// A UDP socket is bound to a specific endpoint, and owns transmit and receive/// packet buffers.#[derive(Debug)]pub struct DnsSocket<'a> {    servers: ManagedSlice<'a, IpAddress>,    queries: ManagedSlice<'a, Option<DnsQuery>>,    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.    hop_limit: Option<u8>,}impl<'a> DnsSocket<'a> {    /// Create a DNS socket with the given buffers.    pub fn new<Q, S>(servers: S, queries: Q) -> DnsSocket<'a>    where        S: Into<ManagedSlice<'a, IpAddress>>,        Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,    {        DnsSocket {            servers: servers.into(),            queries: queries.into(),            hop_limit: None,        }    }    /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.    ///    /// See also the [set_hop_limit](#method.set_hop_limit) method    pub fn hop_limit(&self) -> Option<u8> {        self.hop_limit    }    /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.    ///    /// A socket without an explicitly set hop limit value uses the default [IANA recommended]    /// value (64).    ///    /// # Panics    ///    /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].    ///    /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml    /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {        // A host MUST NOT send a datagram with a hop limit value of 0        if let Some(0) = hop_limit {            panic!("the time-to-live value of a packet must not be zero")        }        self.hop_limit = hop_limit    }    fn find_free_query(&mut self) -> Result<QueryHandle> {        for (i, q) in self.queries.iter().enumerate() {            if q.is_none() {                return Ok(QueryHandle(i));            }        }        match self.queries {            ManagedSlice::Borrowed(_) => Err(Error::Exhausted),            #[cfg(any(feature = "std", feature = "alloc"))]            ManagedSlice::Owned(ref mut queries) => {                queries.push(None);                let index = queries.len() - 1;                Ok(QueryHandle(index))            }        }    }    pub fn start_query(&mut self, cx: &mut Context, name: &[u8]) -> Result<QueryHandle> {        let handle = self.find_free_query()?;        self.queries[handle.0] = Some(DnsQuery {            state: State::Pending(PendingQuery {                name: Vec::from_slice(name).map_err(|_| Error::Truncated)?,                type_: Type::A,                txid: cx.rand().rand_u16(),                port: cx.rand().rand_source_port(),                delay: RETRANSMIT_DELAY,                retransmit_at: Instant::ZERO,            }),        });        Ok(handle)    }    pub fn get_query_result(        &mut self,        handle: QueryHandle,    ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;        let q = slot.as_mut().ok_or(Error::Illegal)?;        match &mut q.state {            // Query is not done yet.            State::Pending(_) => Err(Error::Exhausted),            // Query is done            State::Completed(q) => {                let res = q.addresses.clone();                *slot = None; // Free up the slot for recycling.                Ok(res)            }        }    }    pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;        let q = slot.as_mut().ok_or(Error::Illegal)?;        *slot = None; // Free up the slot for recycling.        Ok(())    }    pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {        udp_repr.src_port == DNS_PORT            && self                .servers                .iter()                .any(|server| *server == ip_repr.src_addr())    }    pub(crate) fn process(        &mut self,        cx: &mut Context,        ip_repr: &IpRepr,        udp_repr: &UdpRepr,        payload: &[u8],    ) -> Result<()> {        debug_assert!(self.accepts(ip_repr, udp_repr));        let size = payload.len();        net_trace!(            "receiving {} octets from {:?}:{}",            size,            ip_repr.src_addr(),            udp_repr.dst_port        );        let p = Packet::new_checked(payload)?;        if p.opcode() != Opcode::Query {            net_trace!("unwanted opcode {:?}", p.opcode());            return Err(Error::Malformed);        }        if !p.flags().contains(Flags::RESPONSE) {            net_trace!("packet doesn't have response bit set");            return Err(Error::Malformed);        }        if p.question_count() != 1 {            net_trace!("bad question count {:?}", p.question_count());            return Err(Error::Malformed);        }        // Find pending query        for q in self.queries.iter_mut().flatten() {            if let State::Pending(pq) = &mut q.state {                if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {                    continue;                }                let payload = p.payload();                let (mut payload, question) = Question::parse(payload)?;                if question.type_ != pq.type_ {                    net_trace!("question type mismatch");                    return Err(Error::Malformed);                }                if !eq_names(p.parse_name(question.name), p.parse_name(&pq.name))? {                    net_trace!("question name mismatch");                    return Err(Error::Malformed);                }                let mut addresses = Vec::new();                for _ in 0..p.answer_record_count() {                    let (payload2, r) = Record::parse(payload)?;                    payload = payload2;                    if !eq_names(p.parse_name(r.name), p.parse_name(&pq.name))? {                        net_trace!("answer name mismatch: {:?}", r);                        continue;                    }                    match r.data {                        RecordData::A(addr) => {                            net_trace!("A: {:?}", addr);                            if addresses.push(addr.into()).is_err() {                                net_trace!("too many addresses in response, ignoring {:?}", addr);                            }                        }                        RecordData::Aaaa(addr) => {                            net_trace!("AAAA: {:?}", addr);                            if addresses.push(addr.into()).is_err() {                                net_trace!("too many addresses in response, ignoring {:?}", addr);                            }                        }                        RecordData::Cname(name) => {                            net_trace!("CNAME: {:?}", name);                            copy_name(&mut pq.name, p.parse_name(name))?;                            // Relaunch query with the new name.                            // If the server has bundled A records for the CNAME in the same packet,                            // we'll process them in next iterations, and cancel the query relaunch.                            pq.retransmit_at = Instant::ZERO;                            pq.delay = RETRANSMIT_DELAY;                            pq.txid = cx.rand().rand_u16();                            pq.port = cx.rand().rand_source_port();                        }                        RecordData::Other(type_, data) => {                            net_trace!("unknown: {:?} {:?}", type_, data)                        }                    }                }                if !addresses.is_empty() {                    q.state = State::Completed(CompletedQuery { addresses })                }                // If we get here, packet matched the current query, stop processing.                return Ok(());            }        }        // If we get here, packet matched with no query.        net_trace!("no query matched");        Ok(())    }    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>    where        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,    {        let hop_limit = self.hop_limit.unwrap_or(64);        for q in self.queries.iter_mut().flatten() {            if let State::Pending(pq) = &mut q.state {                if pq.retransmit_at > cx.now() {                    // query is waiting for retransmit                    continue;                }                let repr = Repr {                    transaction_id: pq.txid,                    flags: Flags::RECURSION_DESIRED,                    opcode: Opcode::Query,                    question: Question {                        name: &pq.name,                        type_: Type::A,                    },                };                let mut payload = [0u8; 512];                let payload = &mut payload[..repr.buffer_len()];                repr.emit(&mut Packet::new_unchecked(payload));                let udp_repr = UdpRepr {                    src_port: pq.port,                    dst_port: 53,                };                let dst_addr = self.servers[0];                let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap                let ip_repr = IpRepr::new(                    src_addr,                    dst_addr,                    IpProtocol::Udp,                    udp_repr.header_len() + payload.len(),                    hop_limit,                );                net_trace!(                    "sending {} octets to {:?}:{}",                    payload.len(),                    ip_repr.dst_addr(),                    udp_repr.src_port                );                if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {                    net_trace!("DNS emit error {:?}", e);                    return Ok(());                }                pq.retransmit_at = cx.now() + pq.delay;                pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);                return Ok(());            }        }        // Nothing to dispatch        Err(Error::Exhausted)    }    pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {        self.queries            .iter()            .flatten()            .filter_map(|q| match &q.state {                State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),                State::Completed(_) => None,            })            .min()            .unwrap_or(PollAt::Ingress)    }}impl<'a> From<DnsSocket<'a>> for Socket<'a> {    fn from(val: DnsSocket<'a>) -> Self {        Socket::Dns(val)    }}fn eq_names<'a>(    mut a: impl Iterator<Item = Result<&'a [u8]>>,    mut b: impl Iterator<Item = Result<&'a [u8]>>,) -> Result<bool> {    loop {        match (a.next(), b.next()) {            // Handle errors            (Some(Err(e)), _) => return Err(e),            (_, Some(Err(e))) => return Err(e),            // Both finished -> equal            (None, None) => return Ok(true),            // One finished before the other -> not equal            (None, _) => return Ok(false),            (_, None) => return Ok(false),            // Got two labels, check if they're equal            (Some(Ok(la)), Some(Ok(lb))) => {                if la != lb {                    return Ok(false);                }            }        }    }}fn copy_name<'a, const N: usize>(    dest: &mut Vec<u8, N>,    name: impl Iterator<Item = Result<&'a [u8]>>,) -> Result<()> {    dest.truncate(0);    for label in name {        let label = label?;        dest.push(label.len() as u8).map_err(|_| Error::Truncated);        dest.extend_from_slice(label).map_err(|_| Error::Truncated);    }    // Write terminator 0x00    dest.push(0).map_err(|_| Error::Truncated);    Ok(())}
 |