dns.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. #![allow(dead_code, unused)]
  2. use heapless::Vec;
  3. use managed::ManagedSlice;
  4. use crate::socket::{Context, PollAt, Socket};
  5. use crate::time::{Duration, Instant};
  6. use crate::wire::dns::{Flags, Opcode, Packet, Question, Record, RecordData, Repr, Type};
  7. use crate::wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr, Ipv4Address, UdpRepr};
  8. use crate::{rand, Error, Result};
/// UDP port DNS servers answer from; `accepts` only takes datagrams whose
/// source port equals this value.
const DNS_PORT: u16 = 53;
/// Capacity, in octets, of the buffer that stores the name of a pending query.
const MAX_NAME_LEN: usize = 255;
/// Maximum number of addresses kept from a response; further ones are ignored.
const MAX_ADDRESS_COUNT: usize = 4;
/// Retransmission delay a query starts with.
const RETRANSMIT_DELAY: Duration = Duration::from_millis(1000);
/// Upper bound of the retransmission delay, which doubles after every transmission.
const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10000);
/// State for an in-progress DNS query.
///
/// The only reason this struct is public is to allow the socket state
/// to be allocated externally.
#[derive(Debug)]
pub struct DnsQuery {
    // Whether the query is still waiting for an answer or already holds its result.
    state: State,
}
/// Lifecycle of an occupied query slot.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum State {
    /// The query was started and no response carrying an address has been processed yet.
    Pending(PendingQuery),
    /// A response with at least one address was processed; the result awaits pickup.
    Completed(CompletedQuery),
}
/// Bookkeeping for a query that has not been answered yet.
#[derive(Debug)]
struct PendingQuery {
    // Name being asked for; overwritten with the target when a response
    // carries a CNAME record.
    name: Vec<u8, MAX_NAME_LEN>,
    // Record type a response's question must carry to be accepted.
    type_: Type,
    port: u16, // UDP port (src for request, dst for response)
    txid: u16, // transaction ID
    // The query is (re)sent once the current time reaches this instant.
    retransmit_at: Instant,
    // Delay added to the current time after a transmission to schedule the next one.
    delay: Duration,
}
/// Result of a query that has been answered.
#[derive(Debug)]
struct CompletedQuery {
    // Addresses collected from the response, at most `MAX_ADDRESS_COUNT` of them.
    addresses: Vec<IpAddress, MAX_ADDRESS_COUNT>,
}
  41. /// A handle to an in-progress DNS query.
  42. #[derive(Clone, Copy)]
  43. pub struct QueryHandle(usize);
/// A Domain Name System socket.
///
/// The socket keeps a list of DNS server addresses and a table of query
/// slots. Queries are started with `start_query` and their results are
/// fetched with `get_query_result`.
#[derive(Debug)]
pub struct DnsSocket<'a> {
    // Addresses of the DNS servers; responses are only accepted from these.
    servers: ManagedSlice<'a, IpAddress>,
    // Query slots; `None` marks a slot that is free for reuse.
    queries: ManagedSlice<'a, Option<DnsQuery>>,
    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
    hop_limit: Option<u8>,
}
  55. impl<'a> DnsSocket<'a> {
  56. /// Create a DNS socket with the given buffers.
  57. pub fn new<Q, S>(servers: S, queries: Q) -> DnsSocket<'a>
  58. where
  59. S: Into<ManagedSlice<'a, IpAddress>>,
  60. Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
  61. {
  62. DnsSocket {
  63. servers: servers.into(),
  64. queries: queries.into(),
  65. hop_limit: None,
  66. }
  67. }
  68. /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  69. ///
  70. /// See also the [set_hop_limit](#method.set_hop_limit) method
  71. pub fn hop_limit(&self) -> Option<u8> {
  72. self.hop_limit
  73. }
  74. /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  75. ///
  76. /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
  77. /// value (64).
  78. ///
  79. /// # Panics
  80. ///
  81. /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
  82. ///
  83. /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
  84. /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
  85. pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
  86. // A host MUST NOT send a datagram with a hop limit value of 0
  87. if let Some(0) = hop_limit {
  88. panic!("the time-to-live value of a packet must not be zero")
  89. }
  90. self.hop_limit = hop_limit
  91. }
  92. fn find_free_query(&mut self) -> Result<QueryHandle> {
  93. for (i, q) in self.queries.iter().enumerate() {
  94. if q.is_none() {
  95. return Ok(QueryHandle(i));
  96. }
  97. }
  98. match self.queries {
  99. ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
  100. #[cfg(any(feature = "std", feature = "alloc"))]
  101. ManagedSlice::Owned(ref mut queries) => {
  102. queries.push(None);
  103. let index = queries.len() - 1;
  104. Ok(QueryHandle(index))
  105. }
  106. }
  107. }
  108. pub fn start_query(&mut self, cx: &mut Context, name: &[u8]) -> Result<QueryHandle> {
  109. let handle = self.find_free_query()?;
  110. self.queries[handle.0] = Some(DnsQuery {
  111. state: State::Pending(PendingQuery {
  112. name: Vec::from_slice(name).map_err(|_| Error::Truncated)?,
  113. type_: Type::A,
  114. txid: cx.rand().rand_u16(),
  115. port: cx.rand().rand_source_port(),
  116. delay: RETRANSMIT_DELAY,
  117. retransmit_at: Instant::ZERO,
  118. }),
  119. });
  120. Ok(handle)
  121. }
  122. pub fn get_query_result(
  123. &mut self,
  124. handle: QueryHandle,
  125. ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
  126. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  127. let q = slot.as_mut().ok_or(Error::Illegal)?;
  128. match &mut q.state {
  129. // Query is not done yet.
  130. State::Pending(_) => Err(Error::Exhausted),
  131. // Query is done
  132. State::Completed(q) => {
  133. let res = q.addresses.clone();
  134. *slot = None; // Free up the slot for recycling.
  135. Ok(res)
  136. }
  137. }
  138. }
  139. pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
  140. let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
  141. let q = slot.as_mut().ok_or(Error::Illegal)?;
  142. *slot = None; // Free up the slot for recycling.
  143. Ok(())
  144. }
  145. pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
  146. udp_repr.src_port == DNS_PORT
  147. && self
  148. .servers
  149. .iter()
  150. .any(|server| *server == ip_repr.src_addr())
  151. }
  152. pub(crate) fn process(
  153. &mut self,
  154. cx: &mut Context,
  155. ip_repr: &IpRepr,
  156. udp_repr: &UdpRepr,
  157. payload: &[u8],
  158. ) -> Result<()> {
  159. debug_assert!(self.accepts(ip_repr, udp_repr));
  160. let size = payload.len();
  161. net_trace!(
  162. "receiving {} octets from {:?}:{}",
  163. size,
  164. ip_repr.src_addr(),
  165. udp_repr.dst_port
  166. );
  167. let p = Packet::new_checked(payload)?;
  168. if p.opcode() != Opcode::Query {
  169. net_trace!("unwanted opcode {:?}", p.opcode());
  170. return Err(Error::Malformed);
  171. }
  172. if !p.flags().contains(Flags::RESPONSE) {
  173. net_trace!("packet doesn't have response bit set");
  174. return Err(Error::Malformed);
  175. }
  176. if p.question_count() != 1 {
  177. net_trace!("bad question count {:?}", p.question_count());
  178. return Err(Error::Malformed);
  179. }
  180. // Find pending query
  181. for q in self.queries.iter_mut().flatten() {
  182. if let State::Pending(pq) = &mut q.state {
  183. if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
  184. continue;
  185. }
  186. let payload = p.payload();
  187. let (mut payload, question) = Question::parse(payload)?;
  188. if question.type_ != pq.type_ {
  189. net_trace!("question type mismatch");
  190. return Err(Error::Malformed);
  191. }
  192. if !eq_names(p.parse_name(question.name), p.parse_name(&pq.name))? {
  193. net_trace!("question name mismatch");
  194. return Err(Error::Malformed);
  195. }
  196. let mut addresses = Vec::new();
  197. for _ in 0..p.answer_record_count() {
  198. let (payload2, r) = Record::parse(payload)?;
  199. payload = payload2;
  200. if !eq_names(p.parse_name(r.name), p.parse_name(&pq.name))? {
  201. net_trace!("answer name mismatch: {:?}", r);
  202. continue;
  203. }
  204. match r.data {
  205. RecordData::A(addr) => {
  206. net_trace!("A: {:?}", addr);
  207. if addresses.push(addr.into()).is_err() {
  208. net_trace!("too many addresses in response, ignoring {:?}", addr);
  209. }
  210. }
  211. RecordData::Aaaa(addr) => {
  212. net_trace!("AAAA: {:?}", addr);
  213. if addresses.push(addr.into()).is_err() {
  214. net_trace!("too many addresses in response, ignoring {:?}", addr);
  215. }
  216. }
  217. RecordData::Cname(name) => {
  218. net_trace!("CNAME: {:?}", name);
  219. copy_name(&mut pq.name, p.parse_name(name))?;
  220. // Relaunch query with the new name.
  221. // If the server has bundled A records for the CNAME in the same packet,
  222. // we'll process them in next iterations, and cancel the query relaunch.
  223. pq.retransmit_at = Instant::ZERO;
  224. pq.delay = RETRANSMIT_DELAY;
  225. pq.txid = cx.rand().rand_u16();
  226. pq.port = cx.rand().rand_source_port();
  227. }
  228. RecordData::Other(type_, data) => {
  229. net_trace!("unknown: {:?} {:?}", type_, data)
  230. }
  231. }
  232. }
  233. if !addresses.is_empty() {
  234. q.state = State::Completed(CompletedQuery { addresses })
  235. }
  236. // If we get here, packet matched the current query, stop processing.
  237. return Ok(());
  238. }
  239. }
  240. // If we get here, packet matched with no query.
  241. net_trace!("no query matched");
  242. Ok(())
  243. }
  244. pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
  245. where
  246. F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
  247. {
  248. let hop_limit = self.hop_limit.unwrap_or(64);
  249. for q in self.queries.iter_mut().flatten() {
  250. if let State::Pending(pq) = &mut q.state {
  251. if pq.retransmit_at > cx.now() {
  252. // query is waiting for retransmit
  253. continue;
  254. }
  255. let repr = Repr {
  256. transaction_id: pq.txid,
  257. flags: Flags::RECURSION_DESIRED,
  258. opcode: Opcode::Query,
  259. question: Question {
  260. name: &pq.name,
  261. type_: Type::A,
  262. },
  263. };
  264. let mut payload = [0u8; 512];
  265. let payload = &mut payload[..repr.buffer_len()];
  266. repr.emit(&mut Packet::new_unchecked(payload));
  267. let udp_repr = UdpRepr {
  268. src_port: pq.port,
  269. dst_port: 53,
  270. };
  271. let dst_addr = self.servers[0];
  272. let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
  273. let ip_repr = IpRepr::new(
  274. src_addr,
  275. dst_addr,
  276. IpProtocol::Udp,
  277. udp_repr.header_len() + payload.len(),
  278. hop_limit,
  279. );
  280. net_trace!(
  281. "sending {} octets to {:?}:{}",
  282. payload.len(),
  283. ip_repr.dst_addr(),
  284. udp_repr.src_port
  285. );
  286. if let Err(e) = emit(cx, (ip_repr, udp_repr, payload)) {
  287. net_trace!("DNS emit error {:?}", e);
  288. return Ok(());
  289. }
  290. pq.retransmit_at = cx.now() + pq.delay;
  291. pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
  292. return Ok(());
  293. }
  294. }
  295. // Nothing to dispatch
  296. Err(Error::Exhausted)
  297. }
  298. pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
  299. self.queries
  300. .iter()
  301. .flatten()
  302. .filter_map(|q| match &q.state {
  303. State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
  304. State::Completed(_) => None,
  305. })
  306. .min()
  307. .unwrap_or(PollAt::Ingress)
  308. }
  309. }
  310. impl<'a> From<DnsSocket<'a>> for Socket<'a> {
  311. fn from(val: DnsSocket<'a>) -> Self {
  312. Socket::Dns(val)
  313. }
  314. }
  315. fn eq_names<'a>(
  316. mut a: impl Iterator<Item = Result<&'a [u8]>>,
  317. mut b: impl Iterator<Item = Result<&'a [u8]>>,
  318. ) -> Result<bool> {
  319. loop {
  320. match (a.next(), b.next()) {
  321. // Handle errors
  322. (Some(Err(e)), _) => return Err(e),
  323. (_, Some(Err(e))) => return Err(e),
  324. // Both finished -> equal
  325. (None, None) => return Ok(true),
  326. // One finished before the other -> not equal
  327. (None, _) => return Ok(false),
  328. (_, None) => return Ok(false),
  329. // Got two labels, check if they're equal
  330. (Some(Ok(la)), Some(Ok(lb))) => {
  331. if la != lb {
  332. return Ok(false);
  333. }
  334. }
  335. }
  336. }
  337. }
  338. fn copy_name<'a, const N: usize>(
  339. dest: &mut Vec<u8, N>,
  340. name: impl Iterator<Item = Result<&'a [u8]>>,
  341. ) -> Result<()> {
  342. dest.truncate(0);
  343. for label in name {
  344. let label = label?;
  345. dest.push(label.len() as u8).map_err(|_| Error::Truncated);
  346. dest.extend_from_slice(label).map_err(|_| Error::Truncated);
  347. }
  348. // Write terminator 0x00
  349. dest.push(0).map_err(|_| Error::Truncated);
  350. Ok(())
  351. }