123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480 |
- #![allow(dead_code, unused)]
- use heapless::Vec;
- use managed::ManagedSlice;
- use crate::socket::{Context, PollAt, Socket};
- use crate::time::{Duration, Instant};
- use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
- use crate::wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr, Ipv4Address, UdpRepr};
- use crate::{rand, Error, Result};
- const DNS_PORT: u16 = 53;
- const MAX_NAME_LEN: usize = 255;
- const MAX_ADDRESS_COUNT: usize = 4;
- const MAX_SERVER_COUNT: usize = 4;
- const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
- const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
- const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
- /// State for an in-progress DNS query.
- ///
- /// The only reason this struct is public is to allow the socket state
- /// to be allocated externally.
- #[derive(Debug)]
- pub struct DnsQuery {
- state: State,
- }
- #[derive(Debug)]
- #[allow(clippy::large_enum_variant)]
- enum State {
- Pending(PendingQuery),
- Completed(CompletedQuery),
- Failure,
- }
- #[derive(Debug)]
- struct PendingQuery {
- name: Vec<u8, MAX_NAME_LEN>,
- type_: Type,
- port: u16, // UDP port (src for request, dst for response)
- txid: u16, // transaction ID
- timeout_at: Option<Instant>,
- retransmit_at: Instant,
- delay: Duration,
- server_idx: usize,
- }
- #[derive(Debug)]
- struct CompletedQuery {
- addresses: Vec<IpAddress, MAX_ADDRESS_COUNT>,
- }
/// A handle to an in-progress DNS query.
///
/// Returned by `DnsSocket::start_query` and accepted by `get_query_result` and
/// `cancel_query`. A handle is just a slot index: once the query's result has been taken or
/// the query cancelled, the slot may be reused by a new query, so the handle must not be used
/// again.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHandle(usize);
- /// A Domain Name System socket.
- ///
- /// A UDP socket is bound to a specific endpoint, and owns transmit and receive
- /// packet buffers.
- #[derive(Debug)]
- pub struct DnsSocket<'a> {
- servers: Vec<IpAddress, MAX_SERVER_COUNT>,
- queries: ManagedSlice<'a, Option<DnsQuery>>,
- /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
- hop_limit: Option<u8>,
- }
- impl<'a> DnsSocket<'a> {
- /// Create a DNS socket.
- ///
- /// # Panics
- ///
- /// Panics if `servers.len() > MAX_SERVER_COUNT`
- pub fn new<Q>(servers: &[IpAddress], queries: Q) -> DnsSocket<'a>
- where
- Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
- {
- DnsSocket {
- servers: Vec::from_slice(servers).unwrap(),
- queries: queries.into(),
- hop_limit: None,
- }
- }
- /// Update the list of DNS servers, will replace all existing servers
- ///
- /// # Panics
- ///
- /// Panics if `servers.len() > MAX_SERVER_COUNT`
- pub fn update_servers(&mut self, servers: &[IpAddress]) {
- self.servers = Vec::from_slice(servers).unwrap();
- }
- /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
- ///
- /// See also the [set_hop_limit](#method.set_hop_limit) method
- pub fn hop_limit(&self) -> Option<u8> {
- self.hop_limit
- }
- /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
- ///
- /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
- /// value (64).
- ///
- /// # Panics
- ///
- /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
- ///
- /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
- /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
- pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
- // A host MUST NOT send a datagram with a hop limit value of 0
- if let Some(0) = hop_limit {
- panic!("the time-to-live value of a packet must not be zero")
- }
- self.hop_limit = hop_limit
- }
- fn find_free_query(&mut self) -> Result<QueryHandle> {
- for (i, q) in self.queries.iter().enumerate() {
- if q.is_none() {
- return Ok(QueryHandle(i));
- }
- }
- match self.queries {
- ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
- #[cfg(any(feature = "std", feature = "alloc"))]
- ManagedSlice::Owned(ref mut queries) => {
- queries.push(None);
- let index = queries.len() - 1;
- Ok(QueryHandle(index))
- }
- }
- }
- pub fn start_query(&mut self, cx: &mut Context, name: &[u8]) -> Result<QueryHandle> {
- let handle = self.find_free_query()?;
- self.queries[handle.0] = Some(DnsQuery {
- state: State::Pending(PendingQuery {
- name: Vec::from_slice(name).map_err(|_| Error::Truncated)?,
- type_: Type::A,
- txid: cx.rand().rand_u16(),
- port: cx.rand().rand_source_port(),
- delay: RETRANSMIT_DELAY,
- timeout_at: None,
- retransmit_at: Instant::ZERO,
- server_idx: 0,
- }),
- });
- Ok(handle)
- }
- pub fn get_query_result(
- &mut self,
- handle: QueryHandle,
- ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
- let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
- let q = slot.as_mut().ok_or(Error::Illegal)?;
- match &mut q.state {
- // Query is not done yet.
- State::Pending(_) => Err(Error::Exhausted),
- // Query is done
- State::Completed(q) => {
- let res = q.addresses.clone();
- *slot = None; // Free up the slot for recycling.
- Ok(res)
- }
- State::Failure => {
- *slot = None; // Free up the slot for recycling.
- Err(Error::Unaddressable)
- }
- }
- }
- pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
- let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
- let q = slot.as_mut().ok_or(Error::Illegal)?;
- *slot = None; // Free up the slot for recycling.
- Ok(())
- }
- pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
- udp_repr.src_port == DNS_PORT
- && self
- .servers
- .iter()
- .any(|server| *server == ip_repr.src_addr())
- }
- pub(crate) fn process(
- &mut self,
- cx: &mut Context,
- ip_repr: &IpRepr,
- udp_repr: &UdpRepr,
- payload: &[u8],
- ) -> Result<()> {
- debug_assert!(self.accepts(ip_repr, udp_repr));
- let size = payload.len();
- net_trace!(
- "receiving {} octets from {:?}:{}",
- size,
- ip_repr.src_addr(),
- udp_repr.dst_port
- );
- let p = Packet::new_checked(payload)?;
- if p.opcode() != Opcode::Query {
- net_trace!("unwanted opcode {:?}", p.opcode());
- return Err(Error::Malformed);
- }
- if !p.flags().contains(Flags::RESPONSE) {
- net_trace!("packet doesn't have response bit set");
- return Err(Error::Malformed);
- }
- if p.question_count() != 1 {
- net_trace!("bad question count {:?}", p.question_count());
- return Err(Error::Malformed);
- }
- // Find pending query
- for q in self.queries.iter_mut().flatten() {
- if let State::Pending(pq) = &mut q.state {
- if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
- continue;
- }
- if p.rcode() == Rcode::NXDomain {
- net_trace!("rcode NXDomain");
- q.state = State::Failure;
- continue;
- }
- let payload = p.payload();
- let (mut payload, question) = Question::parse(payload)?;
- if question.type_ != pq.type_ {
- net_trace!("question type mismatch");
- return Err(Error::Malformed);
- }
- if !eq_names(p.parse_name(question.name), p.parse_name(&pq.name))? {
- net_trace!("question name mismatch");
- return Err(Error::Malformed);
- }
- let mut addresses = Vec::new();
- for _ in 0..p.answer_record_count() {
- let (payload2, r) = Record::parse(payload)?;
- payload = payload2;
- if !eq_names(p.parse_name(r.name), p.parse_name(&pq.name))? {
- net_trace!("answer name mismatch: {:?}", r);
- continue;
- }
- match r.data {
- RecordData::A(addr) => {
- net_trace!("A: {:?}", addr);
- if addresses.push(addr.into()).is_err() {
- net_trace!("too many addresses in response, ignoring {:?}", addr);
- }
- }
- RecordData::Aaaa(addr) => {
- net_trace!("AAAA: {:?}", addr);
- if addresses.push(addr.into()).is_err() {
- net_trace!("too many addresses in response, ignoring {:?}", addr);
- }
- }
- RecordData::Cname(name) => {
- net_trace!("CNAME: {:?}", name);
- copy_name(&mut pq.name, p.parse_name(name))?;
- // Relaunch query with the new name.
- // If the server has bundled A records for the CNAME in the same packet,
- // we'll process them in next iterations, and cancel the query relaunch.
- pq.retransmit_at = Instant::ZERO;
- pq.delay = RETRANSMIT_DELAY;
- pq.txid = cx.rand().rand_u16();
- pq.port = cx.rand().rand_source_port();
- }
- RecordData::Other(type_, data) => {
- net_trace!("unknown: {:?} {:?}", type_, data)
- }
- }
- }
- if !addresses.is_empty() {
- q.state = State::Completed(CompletedQuery { addresses })
- } else {
- q.state = State::Failure;
- }
- // If we get here, packet matched the current query, stop processing.
- return Ok(());
- }
- }
- // If we get here, packet matched with no query.
- net_trace!("no query matched");
- Ok(())
- }
- pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
- where
- F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
- {
- let hop_limit = self.hop_limit.unwrap_or(64);
- for q in self.queries.iter_mut().flatten() {
- if let State::Pending(pq) = &mut q.state {
- let timeout = if let Some(timeout) = pq.timeout_at {
- timeout
- } else {
- let v = cx.now() + RETRANSMIT_TIMEOUT;
- pq.timeout_at = Some(v);
- v
- };
- // Check timeout
- if timeout < cx.now() {
- // DNS timeout
- pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
- pq.retransmit_at = Instant::ZERO;
- pq.delay = RETRANSMIT_DELAY;
- // Try next server. We check below whether we've tried all servers.
- pq.server_idx += 1;
- }
- // Check if we've run out of servers to try.
- if pq.server_idx >= self.servers.len() {
- net_trace!("already tried all servers.");
- q.state = State::Failure;
- continue;
- }
- // Check so the IP address is valid
- if self.servers[pq.server_idx].is_unspecified() {
- net_trace!("invalid unspecified DNS server addr.");
- q.state = State::Failure;
- continue;
- }
- if pq.retransmit_at > cx.now() {
- // query is waiting for retransmit
- continue;
- }
- let repr = Repr {
- transaction_id: pq.txid,
- flags: Flags::RECURSION_DESIRED,
- opcode: Opcode::Query,
- question: Question {
- name: &pq.name,
- type_: Type::A,
- },
- };
- let mut payload = [0u8; 512];
- let payload = &mut payload[..repr.buffer_len()];
- repr.emit(&mut Packet::new_unchecked(payload));
- let udp_repr = UdpRepr {
- src_port: pq.port,
- dst_port: 53,
- };
- let dst_addr = self.servers[pq.server_idx];
- let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
- let ip_repr = IpRepr::new(
- src_addr,
- dst_addr,
- IpProtocol::Udp,
- udp_repr.header_len() + payload.len(),
- hop_limit,
- );
- net_trace!(
- "sending {} octets to {:?}:{}",
- payload.len(),
- ip_repr.dst_addr(),
- udp_repr.src_port
- );
- if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
- net_trace!("DNS emit error {:?}", e);
- return Ok(());
- }
- pq.retransmit_at = cx.now() + pq.delay;
- pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
- return Ok(());
- }
- }
- // Nothing to dispatch
- Err(Error::Exhausted)
- }
- pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
- self.queries
- .iter()
- .flatten()
- .filter_map(|q| match &q.state {
- State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
- State::Completed(_) => None,
- State::Failure => None,
- })
- .min()
- .unwrap_or(PollAt::Ingress)
- }
- }
- impl<'a> From<DnsSocket<'a>> for Socket<'a> {
- fn from(val: DnsSocket<'a>) -> Self {
- Socket::Dns(val)
- }
- }
- fn eq_names<'a>(
- mut a: impl Iterator<Item = Result<&'a [u8]>>,
- mut b: impl Iterator<Item = Result<&'a [u8]>>,
- ) -> Result<bool> {
- loop {
- match (a.next(), b.next()) {
- // Handle errors
- (Some(Err(e)), _) => return Err(e),
- (_, Some(Err(e))) => return Err(e),
- // Both finished -> equal
- (None, None) => return Ok(true),
- // One finished before the other -> not equal
- (None, _) => return Ok(false),
- (_, None) => return Ok(false),
- // Got two labels, check if they're equal
- (Some(Ok(la)), Some(Ok(lb))) => {
- if la != lb {
- return Ok(false);
- }
- }
- }
- }
- }
- fn copy_name<'a, const N: usize>(
- dest: &mut Vec<u8, N>,
- name: impl Iterator<Item = Result<&'a [u8]>>,
- ) -> Result<()> {
- dest.truncate(0);
- for label in name {
- let label = label?;
- dest.push(label.len() as u8).map_err(|_| Error::Truncated);
- dest.extend_from_slice(label).map_err(|_| Error::Truncated);
- }
- // Write terminator 0x00
- dest.push(0).map_err(|_| Error::Truncated);
- Ok(())
- }
|