dns.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. #![allow(dead_code, unused)]
  2. use heapless::Vec;
  3. use managed::ManagedSlice;
  4. use crate::socket::{Context, PollAt, Socket};
  5. use crate::time::{Duration, Instant};
  6. use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
  7. use crate::wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr, Ipv4Address, UdpRepr};
  8. use crate::{rand, Error, Result};
  9. const DNS_PORT: u16 = 53;
  10. const MAX_NAME_LEN: usize = 255;
  11. const MAX_ADDRESS_COUNT: usize = 4;
  12. const MAX_SERVER_COUNT: usize = 4;
  13. const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
  14. const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
  15. const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
  16. /// State for an in-progress DNS query.
  17. ///
  18. /// The only reason this struct is public is to allow the socket state
  19. /// to be allocated externally.
  20. #[derive(Debug)]
  21. pub struct DnsQuery {
  22. state: State,
  23. }
  24. #[derive(Debug)]
  25. #[allow(clippy::large_enum_variant)]
  26. enum State {
  27. Pending(PendingQuery),
  28. Completed(CompletedQuery),
  29. Failure,
  30. }
  31. #[derive(Debug)]
  32. struct PendingQuery {
  33. name: Vec<u8, MAX_NAME_LEN>,
  34. type_: Type,
  35. port: u16, // UDP port (src for request, dst for response)
  36. txid: u16, // transaction ID
  37. timeout_at: Option<Instant>,
  38. retransmit_at: Instant,
  39. delay: Duration,
  40. server_idx: usize,
  41. }
  42. #[derive(Debug)]
  43. struct CompletedQuery {
  44. addresses: Vec<IpAddress, MAX_ADDRESS_COUNT>,
  45. }
  46. /// A handle to an in-progress DNS query.
  47. #[derive(Clone, Copy)]
  48. pub struct QueryHandle(usize);
  49. /// A Domain Name System socket.
  50. ///
  51. /// A UDP socket is bound to a specific endpoint, and owns transmit and receive
  52. /// packet buffers.
  53. #[derive(Debug)]
  54. pub struct DnsSocket<'a> {
  55. servers: Vec<IpAddress, MAX_SERVER_COUNT>,
  56. queries: ManagedSlice<'a, Option<DnsQuery>>,
  57. /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  58. hop_limit: Option<u8>,
  59. }
  60. impl<'a> DnsSocket<'a> {
  61. /// Create a DNS socket.
  62. ///
  63. /// # Panics
  64. ///
  65. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  66. pub fn new<Q>(servers: &[IpAddress], queries: Q) -> DnsSocket<'a>
  67. where
  68. Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
  69. {
  70. DnsSocket {
  71. servers: Vec::from_slice(servers).unwrap(),
  72. queries: queries.into(),
  73. hop_limit: None,
  74. }
  75. }
  76. /// Update the list of DNS servers, will replace all existing servers
  77. ///
  78. /// # Panics
  79. ///
  80. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  81. pub fn update_servers(&mut self, servers: &[IpAddress]) {
  82. self.servers = Vec::from_slice(servers).unwrap();
  83. }
  84. /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  85. ///
  86. /// See also the [set_hop_limit](#method.set_hop_limit) method
  87. pub fn hop_limit(&self) -> Option<u8> {
  88. self.hop_limit
  89. }
  90. /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  91. ///
  92. /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
  93. /// value (64).
  94. ///
  95. /// # Panics
  96. ///
  97. /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
  98. ///
  99. /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
  100. /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
  101. pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
  102. // A host MUST NOT send a datagram with a hop limit value of 0
  103. if let Some(0) = hop_limit {
  104. panic!("the time-to-live value of a packet must not be zero")
  105. }
  106. self.hop_limit = hop_limit
  107. }
  108. fn find_free_query(&mut self) -> Result<QueryHandle> {
  109. for (i, q) in self.queries.iter().enumerate() {
  110. if q.is_none() {
  111. return Ok(QueryHandle(i));
  112. }
  113. }
  114. match self.queries {
  115. ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
  116. #[cfg(any(feature = "std", feature = "alloc"))]
  117. ManagedSlice::Owned(ref mut queries) => {
  118. queries.push(None);
  119. let index = queries.len() - 1;
  120. Ok(QueryHandle(index))
  121. }
  122. }
  123. }
  124. pub fn start_query(&mut self, cx: &mut Context, name: &[u8]) -> Result<QueryHandle> {
  125. let handle = self.find_free_query()?;
  126. self.queries[handle.0] = Some(DnsQuery {
  127. state: State::Pending(PendingQuery {
  128. name: Vec::from_slice(name).map_err(|_| Error::Truncated)?,
  129. type_: Type::A,
  130. txid: cx.rand().rand_u16(),
  131. port: cx.rand().rand_source_port(),
  132. delay: RETRANSMIT_DELAY,
  133. timeout_at: None,
  134. retransmit_at: Instant::ZERO,
  135. server_idx: 0,
  136. }),
  137. });
  138. Ok(handle)
  139. }
  140. pub fn get_query_result(
  141. &mut self,
  142. handle: QueryHandle,
  143. ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
  144. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  145. let q = slot.as_mut().ok_or(Error::Illegal)?;
  146. match &mut q.state {
  147. // Query is not done yet.
  148. State::Pending(_) => Err(Error::Exhausted),
  149. // Query is done
  150. State::Completed(q) => {
  151. let res = q.addresses.clone();
  152. *slot = None; // Free up the slot for recycling.
  153. Ok(res)
  154. }
  155. State::Failure => {
  156. *slot = None; // Free up the slot for recycling.
  157. Err(Error::Unaddressable)
  158. }
  159. }
  160. }
  161. pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
  162. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  163. let q = slot.as_mut().ok_or(Error::Illegal)?;
  164. *slot = None; // Free up the slot for recycling.
  165. Ok(())
  166. }
  167. pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
  168. udp_repr.src_port == DNS_PORT
  169. && self
  170. .servers
  171. .iter()
  172. .any(|server| *server == ip_repr.src_addr())
  173. }
  174. pub(crate) fn process(
  175. &mut self,
  176. cx: &mut Context,
  177. ip_repr: &IpRepr,
  178. udp_repr: &UdpRepr,
  179. payload: &[u8],
  180. ) -> Result<()> {
  181. debug_assert!(self.accepts(ip_repr, udp_repr));
  182. let size = payload.len();
  183. net_trace!(
  184. "receiving {} octets from {:?}:{}",
  185. size,
  186. ip_repr.src_addr(),
  187. udp_repr.dst_port
  188. );
  189. let p = Packet::new_checked(payload)?;
  190. if p.opcode() != Opcode::Query {
  191. net_trace!("unwanted opcode {:?}", p.opcode());
  192. return Err(Error::Malformed);
  193. }
  194. if !p.flags().contains(Flags::RESPONSE) {
  195. net_trace!("packet doesn't have response bit set");
  196. return Err(Error::Malformed);
  197. }
  198. if p.question_count() != 1 {
  199. net_trace!("bad question count {:?}", p.question_count());
  200. return Err(Error::Malformed);
  201. }
  202. // Find pending query
  203. for q in self.queries.iter_mut().flatten() {
  204. if let State::Pending(pq) = &mut q.state {
  205. if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
  206. continue;
  207. }
  208. if p.rcode() == Rcode::NXDomain {
  209. net_trace!("rcode NXDomain");
  210. q.state = State::Failure;
  211. continue;
  212. }
  213. let payload = p.payload();
  214. let (mut payload, question) = Question::parse(payload)?;
  215. if question.type_ != pq.type_ {
  216. net_trace!("question type mismatch");
  217. return Err(Error::Malformed);
  218. }
  219. if !eq_names(p.parse_name(question.name), p.parse_name(&pq.name))? {
  220. net_trace!("question name mismatch");
  221. return Err(Error::Malformed);
  222. }
  223. let mut addresses = Vec::new();
  224. for _ in 0..p.answer_record_count() {
  225. let (payload2, r) = Record::parse(payload)?;
  226. payload = payload2;
  227. if !eq_names(p.parse_name(r.name), p.parse_name(&pq.name))? {
  228. net_trace!("answer name mismatch: {:?}", r);
  229. continue;
  230. }
  231. match r.data {
  232. RecordData::A(addr) => {
  233. net_trace!("A: {:?}", addr);
  234. if addresses.push(addr.into()).is_err() {
  235. net_trace!("too many addresses in response, ignoring {:?}", addr);
  236. }
  237. }
  238. RecordData::Aaaa(addr) => {
  239. net_trace!("AAAA: {:?}", addr);
  240. if addresses.push(addr.into()).is_err() {
  241. net_trace!("too many addresses in response, ignoring {:?}", addr);
  242. }
  243. }
  244. RecordData::Cname(name) => {
  245. net_trace!("CNAME: {:?}", name);
  246. copy_name(&mut pq.name, p.parse_name(name))?;
  247. // Relaunch query with the new name.
  248. // If the server has bundled A records for the CNAME in the same packet,
  249. // we'll process them in next iterations, and cancel the query relaunch.
  250. pq.retransmit_at = Instant::ZERO;
  251. pq.delay = RETRANSMIT_DELAY;
  252. pq.txid = cx.rand().rand_u16();
  253. pq.port = cx.rand().rand_source_port();
  254. }
  255. RecordData::Other(type_, data) => {
  256. net_trace!("unknown: {:?} {:?}", type_, data)
  257. }
  258. }
  259. }
  260. if !addresses.is_empty() {
  261. q.state = State::Completed(CompletedQuery { addresses })
  262. } else {
  263. q.state = State::Failure;
  264. }
  265. // If we get here, packet matched the current query, stop processing.
  266. return Ok(());
  267. }
  268. }
  269. // If we get here, packet matched with no query.
  270. net_trace!("no query matched");
  271. Ok(())
  272. }
  273. pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
  274. where
  275. F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
  276. {
  277. let hop_limit = self.hop_limit.unwrap_or(64);
  278. for q in self.queries.iter_mut().flatten() {
  279. if let State::Pending(pq) = &mut q.state {
  280. let timeout = if let Some(timeout) = pq.timeout_at {
  281. timeout
  282. } else {
  283. let v = cx.now() + RETRANSMIT_TIMEOUT;
  284. pq.timeout_at = Some(v);
  285. v
  286. };
  287. // Check timeout
  288. if timeout < cx.now() {
  289. // DNS timeout
  290. pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
  291. pq.retransmit_at = Instant::ZERO;
  292. pq.delay = RETRANSMIT_DELAY;
  293. // Try next server. We check below whether we've tried all servers.
  294. pq.server_idx += 1;
  295. }
  296. // Check if we've run out of servers to try.
  297. if pq.server_idx >= self.servers.len() {
  298. net_trace!("already tried all servers.");
  299. q.state = State::Failure;
  300. continue;
  301. }
  302. // Check so the IP address is valid
  303. if self.servers[pq.server_idx].is_unspecified() {
  304. net_trace!("invalid unspecified DNS server addr.");
  305. q.state = State::Failure;
  306. continue;
  307. }
  308. if pq.retransmit_at > cx.now() {
  309. // query is waiting for retransmit
  310. continue;
  311. }
  312. let repr = Repr {
  313. transaction_id: pq.txid,
  314. flags: Flags::RECURSION_DESIRED,
  315. opcode: Opcode::Query,
  316. question: Question {
  317. name: &pq.name,
  318. type_: Type::A,
  319. },
  320. };
  321. let mut payload = [0u8; 512];
  322. let payload = &mut payload[..repr.buffer_len()];
  323. repr.emit(&mut Packet::new_unchecked(payload));
  324. let udp_repr = UdpRepr {
  325. src_port: pq.port,
  326. dst_port: 53,
  327. };
  328. let dst_addr = self.servers[pq.server_idx];
  329. let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
  330. let ip_repr = IpRepr::new(
  331. src_addr,
  332. dst_addr,
  333. IpProtocol::Udp,
  334. udp_repr.header_len() + payload.len(),
  335. hop_limit,
  336. );
  337. net_trace!(
  338. "sending {} octets to {:?}:{}",
  339. payload.len(),
  340. ip_repr.dst_addr(),
  341. udp_repr.src_port
  342. );
  343. if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
  344. net_trace!("DNS emit error {:?}", e);
  345. return Ok(());
  346. }
  347. pq.retransmit_at = cx.now() + pq.delay;
  348. pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
  349. return Ok(());
  350. }
  351. }
  352. // Nothing to dispatch
  353. Err(Error::Exhausted)
  354. }
  355. pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
  356. self.queries
  357. .iter()
  358. .flatten()
  359. .filter_map(|q| match &q.state {
  360. State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
  361. State::Completed(_) => None,
  362. State::Failure => None,
  363. })
  364. .min()
  365. .unwrap_or(PollAt::Ingress)
  366. }
  367. }
  368. impl<'a> From<DnsSocket<'a>> for Socket<'a> {
  369. fn from(val: DnsSocket<'a>) -> Self {
  370. Socket::Dns(val)
  371. }
  372. }
  373. fn eq_names<'a>(
  374. mut a: impl Iterator<Item = Result<&'a [u8]>>,
  375. mut b: impl Iterator<Item = Result<&'a [u8]>>,
  376. ) -> Result<bool> {
  377. loop {
  378. match (a.next(), b.next()) {
  379. // Handle errors
  380. (Some(Err(e)), _) => return Err(e),
  381. (_, Some(Err(e))) => return Err(e),
  382. // Both finished -> equal
  383. (None, None) => return Ok(true),
  384. // One finished before the other -> not equal
  385. (None, _) => return Ok(false),
  386. (_, None) => return Ok(false),
  387. // Got two labels, check if they're equal
  388. (Some(Ok(la)), Some(Ok(lb))) => {
  389. if la != lb {
  390. return Ok(false);
  391. }
  392. }
  393. }
  394. }
  395. }
  396. fn copy_name<'a, const N: usize>(
  397. dest: &mut Vec<u8, N>,
  398. name: impl Iterator<Item = Result<&'a [u8]>>,
  399. ) -> Result<()> {
  400. dest.truncate(0);
  401. for label in name {
  402. let label = label?;
  403. dest.push(label.len() as u8).map_err(|_| Error::Truncated);
  404. dest.extend_from_slice(label).map_err(|_| Error::Truncated);
  405. }
  406. // Write terminator 0x00
  407. dest.push(0).map_err(|_| Error::Truncated);
  408. Ok(())
  409. }