dns.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. #[cfg(feature = "async")]
  2. use core::task::Waker;
  3. use heapless::Vec;
  4. use managed::ManagedSlice;
  5. use crate::socket::{Context, PollAt, Socket};
  6. use crate::time::{Duration, Instant};
  7. use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
  8. use crate::wire::{IpAddress, IpProtocol, IpRepr, UdpRepr};
  9. use crate::{Error, Result};
  10. #[cfg(feature = "async")]
  11. use super::WakerRegistration;
  12. const DNS_PORT: u16 = 53;
  13. const MAX_NAME_LEN: usize = 255;
  14. const MAX_ADDRESS_COUNT: usize = 4;
  15. const MAX_SERVER_COUNT: usize = 4;
  16. const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
  17. const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
  18. const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
  19. /// State for an in-progress DNS query.
  20. ///
  21. /// The only reason this struct is public is to allow the socket state
  22. /// to be allocated externally.
  23. #[derive(Debug)]
  24. pub struct DnsQuery {
  25. state: State,
  26. #[cfg(feature = "async")]
  27. waker: WakerRegistration,
  28. }
  29. impl DnsQuery {
  30. fn set_state(&mut self, state: State) {
  31. self.state = state;
  32. #[cfg(feature = "async")]
  33. self.waker.wake();
  34. }
  35. }
  36. #[derive(Debug)]
  37. #[allow(clippy::large_enum_variant)]
  38. enum State {
  39. Pending(PendingQuery),
  40. Completed(CompletedQuery),
  41. Failure,
  42. }
  43. #[derive(Debug)]
  44. struct PendingQuery {
  45. name: Vec<u8, MAX_NAME_LEN>,
  46. type_: Type,
  47. port: u16, // UDP port (src for request, dst for response)
  48. txid: u16, // transaction ID
  49. timeout_at: Option<Instant>,
  50. retransmit_at: Instant,
  51. delay: Duration,
  52. server_idx: usize,
  53. }
  54. #[derive(Debug)]
  55. struct CompletedQuery {
  56. addresses: Vec<IpAddress, MAX_ADDRESS_COUNT>,
  57. }
  /// Opaque handle identifying a query slot of a [`DnsSocket`].
  ///
  /// Returned by `start_query`/`start_query_raw` and passed back to
  /// `get_query_result`, `cancel_query` and friends.
  #[derive(Copy, Clone)]
  pub struct QueryHandle(usize);
  61. /// A Domain Name System socket.
  62. ///
  63. /// A UDP socket is bound to a specific endpoint, and owns transmit and receive
  64. /// packet buffers.
  65. #[derive(Debug)]
  66. pub struct DnsSocket<'a> {
  67. servers: Vec<IpAddress, MAX_SERVER_COUNT>,
  68. queries: ManagedSlice<'a, Option<DnsQuery>>,
  69. /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  70. hop_limit: Option<u8>,
  71. }
  72. impl<'a> DnsSocket<'a> {
  73. /// Create a DNS socket.
  74. ///
  75. /// # Panics
  76. ///
  77. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  78. pub fn new<Q>(servers: &[IpAddress], queries: Q) -> DnsSocket<'a>
  79. where
  80. Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
  81. {
  82. DnsSocket {
  83. servers: Vec::from_slice(servers).unwrap(),
  84. queries: queries.into(),
  85. hop_limit: None,
  86. }
  87. }
  88. /// Update the list of DNS servers, will replace all existing servers
  89. ///
  90. /// # Panics
  91. ///
  92. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  93. pub fn update_servers(&mut self, servers: &[IpAddress]) {
  94. self.servers = Vec::from_slice(servers).unwrap();
  95. }
  96. /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  97. ///
  98. /// See also the [set_hop_limit](#method.set_hop_limit) method
  99. pub fn hop_limit(&self) -> Option<u8> {
  100. self.hop_limit
  101. }
  102. /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  103. ///
  104. /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
  105. /// value (64).
  106. ///
  107. /// # Panics
  108. ///
  109. /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
  110. ///
  111. /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
  112. /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
  113. pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
  114. // A host MUST NOT send a datagram with a hop limit value of 0
  115. if let Some(0) = hop_limit {
  116. panic!("the time-to-live value of a packet must not be zero")
  117. }
  118. self.hop_limit = hop_limit
  119. }
  120. fn find_free_query(&mut self) -> Result<QueryHandle> {
  121. for (i, q) in self.queries.iter().enumerate() {
  122. if q.is_none() {
  123. return Ok(QueryHandle(i));
  124. }
  125. }
  126. match self.queries {
  127. ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
  128. #[cfg(any(feature = "std", feature = "alloc"))]
  129. ManagedSlice::Owned(ref mut queries) => {
  130. queries.push(None);
  131. let index = queries.len() - 1;
  132. Ok(QueryHandle(index))
  133. }
  134. }
  135. }
  136. /// Start a query.
  137. ///
  138. /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
  139. /// It accepts names both with and without trailing dot, and they're treated
  140. /// the same (there's no support for DNS search path).
  141. pub fn start_query(&mut self, cx: &mut Context, name: &str) -> Result<QueryHandle> {
  142. let mut name = name.as_bytes();
  143. if name.is_empty() {
  144. net_trace!("invalid name: zero length");
  145. return Err(Error::Illegal);
  146. }
  147. // Remove trailing dot, if any
  148. if name[name.len() - 1] == b'.' {
  149. name = &name[..name.len() - 1];
  150. }
  151. let mut raw_name: Vec<u8, MAX_NAME_LEN> = Vec::new();
  152. for s in name.split(|&c| c == b'.') {
  153. if s.len() > 255 {
  154. net_trace!("invalid name: too long label");
  155. return Err(Error::Illegal);
  156. }
  157. if s.is_empty() {
  158. net_trace!("invalid name: zero length label");
  159. return Err(Error::Illegal);
  160. }
  161. // Push label
  162. raw_name.push(s.len() as u8).map_err(|_| Error::Exhausted)?;
  163. raw_name
  164. .extend_from_slice(s)
  165. .map_err(|_| Error::Exhausted)?;
  166. }
  167. // Push terminator.
  168. raw_name.push(0x00).map_err(|_| Error::Exhausted)?;
  169. self.start_query_raw(cx, &raw_name)
  170. }
  171. /// Start a query with a raw (wire-format) DNS name.
  172. /// `b"\x09rust-lang\x03org\x00"`
  173. ///
  174. /// You probably want to use [`start_query`] instead.
  175. pub fn start_query_raw(&mut self, cx: &mut Context, raw_name: &[u8]) -> Result<QueryHandle> {
  176. let handle = self.find_free_query()?;
  177. self.queries[handle.0] = Some(DnsQuery {
  178. state: State::Pending(PendingQuery {
  179. name: Vec::from_slice(raw_name).map_err(|_| Error::Exhausted)?,
  180. type_: Type::A,
  181. txid: cx.rand().rand_u16(),
  182. port: cx.rand().rand_source_port(),
  183. delay: RETRANSMIT_DELAY,
  184. timeout_at: None,
  185. retransmit_at: Instant::ZERO,
  186. server_idx: 0,
  187. }),
  188. #[cfg(feature = "async")]
  189. waker: WakerRegistration::new(),
  190. });
  191. Ok(handle)
  192. }
  193. pub fn get_query_result(
  194. &mut self,
  195. handle: QueryHandle,
  196. ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
  197. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  198. let q = slot.as_mut().ok_or(Error::Illegal)?;
  199. match &mut q.state {
  200. // Query is not done yet.
  201. State::Pending(_) => Err(Error::Exhausted),
  202. // Query is done
  203. State::Completed(q) => {
  204. let res = q.addresses.clone();
  205. *slot = None; // Free up the slot for recycling.
  206. Ok(res)
  207. }
  208. State::Failure => {
  209. *slot = None; // Free up the slot for recycling.
  210. Err(Error::Unaddressable)
  211. }
  212. }
  213. }
  214. pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
  215. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  216. if slot.is_none() {
  217. return Err(Error::Illegal);
  218. }
  219. *slot = None; // Free up the slot for recycling.
  220. Ok(())
  221. }
  222. #[cfg(feature = "async")]
  223. pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) -> Result<()> {
  224. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  225. slot.as_mut().ok_or(Error::Illegal)?.waker.register(waker);
  226. Ok(())
  227. }
  228. pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
  229. udp_repr.src_port == DNS_PORT
  230. && self
  231. .servers
  232. .iter()
  233. .any(|server| *server == ip_repr.src_addr())
  234. }
  235. pub(crate) fn process(
  236. &mut self,
  237. _cx: &mut Context,
  238. ip_repr: &IpRepr,
  239. udp_repr: &UdpRepr,
  240. payload: &[u8],
  241. ) -> Result<()> {
  242. debug_assert!(self.accepts(ip_repr, udp_repr));
  243. let size = payload.len();
  244. net_trace!(
  245. "receiving {} octets from {:?}:{}",
  246. size,
  247. ip_repr.src_addr(),
  248. udp_repr.dst_port
  249. );
  250. let p = Packet::new_checked(payload)?;
  251. if p.opcode() != Opcode::Query {
  252. net_trace!("unwanted opcode {:?}", p.opcode());
  253. return Err(Error::Malformed);
  254. }
  255. if !p.flags().contains(Flags::RESPONSE) {
  256. net_trace!("packet doesn't have response bit set");
  257. return Err(Error::Malformed);
  258. }
  259. if p.question_count() != 1 {
  260. net_trace!("bad question count {:?}", p.question_count());
  261. return Err(Error::Malformed);
  262. }
  263. // Find pending query
  264. for q in self.queries.iter_mut().flatten() {
  265. if let State::Pending(pq) = &mut q.state {
  266. if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
  267. continue;
  268. }
  269. if p.rcode() == Rcode::NXDomain {
  270. net_trace!("rcode NXDomain");
  271. q.set_state(State::Failure);
  272. continue;
  273. }
  274. let payload = p.payload();
  275. let (mut payload, question) = Question::parse(payload)?;
  276. if question.type_ != pq.type_ {
  277. net_trace!("question type mismatch");
  278. return Err(Error::Malformed);
  279. }
  280. if !eq_names(p.parse_name(question.name), p.parse_name(&pq.name))? {
  281. net_trace!("question name mismatch");
  282. return Err(Error::Malformed);
  283. }
  284. let mut addresses = Vec::new();
  285. for _ in 0..p.answer_record_count() {
  286. let (payload2, r) = Record::parse(payload)?;
  287. payload = payload2;
  288. if !eq_names(p.parse_name(r.name), p.parse_name(&pq.name))? {
  289. net_trace!("answer name mismatch: {:?}", r);
  290. continue;
  291. }
  292. match r.data {
  293. #[cfg(feature = "proto-ipv4")]
  294. RecordData::A(addr) => {
  295. net_trace!("A: {:?}", addr);
  296. if addresses.push(addr.into()).is_err() {
  297. net_trace!("too many addresses in response, ignoring {:?}", addr);
  298. }
  299. }
  300. #[cfg(feature = "proto-ipv6")]
  301. RecordData::Aaaa(addr) => {
  302. net_trace!("AAAA: {:?}", addr);
  303. if addresses.push(addr.into()).is_err() {
  304. net_trace!("too many addresses in response, ignoring {:?}", addr);
  305. }
  306. }
  307. RecordData::Cname(name) => {
  308. net_trace!("CNAME: {:?}", name);
  309. // When faced with a CNAME, recursive resolvers are supposed to
  310. // resolve the CNAME and append the results for it.
  311. //
  312. // We update the query with the new name, so that we pick up the A/AAAA
  313. // records for the CNAME when we parse them later.
  314. // I believe it's mandatory the CNAME results MUST come *after* in the
  315. // packet, so it's enough to do one linear pass over it.
  316. copy_name(&mut pq.name, p.parse_name(name))?;
  317. }
  318. RecordData::Other(type_, data) => {
  319. net_trace!("unknown: {:?} {:?}", type_, data)
  320. }
  321. }
  322. }
  323. q.set_state(if addresses.is_empty() {
  324. State::Failure
  325. } else {
  326. State::Completed(CompletedQuery { addresses })
  327. });
  328. // If we get here, packet matched the current query, stop processing.
  329. return Ok(());
  330. }
  331. }
  332. // If we get here, packet matched with no query.
  333. net_trace!("no query matched");
  334. Ok(())
  335. }
  336. pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
  337. where
  338. F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
  339. {
  340. let hop_limit = self.hop_limit.unwrap_or(64);
  341. for q in self.queries.iter_mut().flatten() {
  342. if let State::Pending(pq) = &mut q.state {
  343. let timeout = if let Some(timeout) = pq.timeout_at {
  344. timeout
  345. } else {
  346. let v = cx.now() + RETRANSMIT_TIMEOUT;
  347. pq.timeout_at = Some(v);
  348. v
  349. };
  350. // Check timeout
  351. if timeout < cx.now() {
  352. // DNS timeout
  353. pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
  354. pq.retransmit_at = Instant::ZERO;
  355. pq.delay = RETRANSMIT_DELAY;
  356. // Try next server. We check below whether we've tried all servers.
  357. pq.server_idx += 1;
  358. }
  359. // Check if we've run out of servers to try.
  360. if pq.server_idx >= self.servers.len() {
  361. net_trace!("already tried all servers.");
  362. q.set_state(State::Failure);
  363. continue;
  364. }
  365. // Check so the IP address is valid
  366. if self.servers[pq.server_idx].is_unspecified() {
  367. net_trace!("invalid unspecified DNS server addr.");
  368. q.set_state(State::Failure);
  369. continue;
  370. }
  371. if pq.retransmit_at > cx.now() {
  372. // query is waiting for retransmit
  373. continue;
  374. }
  375. let repr = Repr {
  376. transaction_id: pq.txid,
  377. flags: Flags::RECURSION_DESIRED,
  378. opcode: Opcode::Query,
  379. question: Question {
  380. name: &pq.name,
  381. type_: Type::A,
  382. },
  383. };
  384. let mut payload = [0u8; 512];
  385. let payload = &mut payload[..repr.buffer_len()];
  386. repr.emit(&mut Packet::new_unchecked(payload));
  387. let udp_repr = UdpRepr {
  388. src_port: pq.port,
  389. dst_port: 53,
  390. };
  391. let dst_addr = self.servers[pq.server_idx];
  392. let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
  393. let ip_repr = IpRepr::new(
  394. src_addr,
  395. dst_addr,
  396. IpProtocol::Udp,
  397. udp_repr.header_len() + payload.len(),
  398. hop_limit,
  399. );
  400. net_trace!(
  401. "sending {} octets to {:?}:{}",
  402. payload.len(),
  403. ip_repr.dst_addr(),
  404. udp_repr.src_port
  405. );
  406. if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
  407. net_trace!("DNS emit error {:?}", e);
  408. return Ok(());
  409. }
  410. pq.retransmit_at = cx.now() + pq.delay;
  411. pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
  412. return Ok(());
  413. }
  414. }
  415. // Nothing to dispatch
  416. Err(Error::Exhausted)
  417. }
  418. pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
  419. self.queries
  420. .iter()
  421. .flatten()
  422. .filter_map(|q| match &q.state {
  423. State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
  424. State::Completed(_) => None,
  425. State::Failure => None,
  426. })
  427. .min()
  428. .unwrap_or(PollAt::Ingress)
  429. }
  430. }
  431. impl<'a> From<DnsSocket<'a>> for Socket<'a> {
  432. fn from(val: DnsSocket<'a>) -> Self {
  433. Socket::Dns(val)
  434. }
  435. }
  436. fn eq_names<'a>(
  437. mut a: impl Iterator<Item = Result<&'a [u8]>>,
  438. mut b: impl Iterator<Item = Result<&'a [u8]>>,
  439. ) -> Result<bool> {
  440. loop {
  441. match (a.next(), b.next()) {
  442. // Handle errors
  443. (Some(Err(e)), _) => return Err(e),
  444. (_, Some(Err(e))) => return Err(e),
  445. // Both finished -> equal
  446. (None, None) => return Ok(true),
  447. // One finished before the other -> not equal
  448. (None, _) => return Ok(false),
  449. (_, None) => return Ok(false),
  450. // Got two labels, check if they're equal
  451. (Some(Ok(la)), Some(Ok(lb))) => {
  452. if la != lb {
  453. return Ok(false);
  454. }
  455. }
  456. }
  457. }
  458. }
  459. fn copy_name<'a, const N: usize>(
  460. dest: &mut Vec<u8, N>,
  461. name: impl Iterator<Item = Result<&'a [u8]>>,
  462. ) -> Result<()> {
  463. dest.truncate(0);
  464. for label in name {
  465. let label = label?;
  466. dest.push(label.len() as u8).map_err(|_| Error::Truncated)?;
  467. dest.extend_from_slice(label)
  468. .map_err(|_| Error::Truncated)?;
  469. }
  470. // Write terminator 0x00
  471. dest.push(0).map_err(|_| Error::Truncated)?;
  472. Ok(())
  473. }