dns.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. #[cfg(feature = "async")]
  2. use core::task::Waker;
  3. use heapless::Vec;
  4. use managed::ManagedSlice;
  5. use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
  6. use crate::socket::{Context, PollAt};
  7. use crate::time::{Duration, Instant};
  8. use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
  9. use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
  10. #[cfg(feature = "async")]
  11. use super::WakerRegistration;
  /// Destination UDP port for unicast DNS queries; responses from configured
  /// servers are only accepted when they come from this port.
  const DNS_PORT: u16 = 53;
  /// UDP port used for multicast DNS queries; responses from this port are
  /// accepted regardless of the sender's address.
  const MDNS_DNS_PORT: u16 = 5353;
  /// Delay before the first retransmission of an unanswered query.
  const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
  /// Upper bound on the retransmission delay, which doubles after every send.
  const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
  /// How long to wait for an answer from one server before moving on to the next.
  const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
  /// IPv6 group `ff02::fb`, the destination of multicast DNS queries over IPv6.
  #[cfg(feature = "proto-ipv6")]
  const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
      0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
  ]));
  /// IPv4 group `224.0.0.251`, the destination of multicast DNS queries over IPv4.
  #[cfg(feature = "proto-ipv4")]
  const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
  /// Error returned by [`Socket::start_query`] and [`Socket::start_query_raw`].
  #[derive(Debug, PartialEq, Eq, Clone, Copy)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub enum StartQueryError {
      /// Every query slot is in use and the slot storage cannot grow.
      NoFreeSlot,
      /// The name is empty, contains an empty label, or has a label longer
      /// than 63 bytes.
      InvalidName,
      /// The wire-format name does not fit in `DNS_MAX_NAME_SIZE` bytes.
      NameTooLong,
  }

  impl core::fmt::Display for StartQueryError {
      fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
          let msg = match self {
              StartQueryError::NoFreeSlot => "No free slot",
              StartQueryError::InvalidName => "Invalid name",
              StartQueryError::NameTooLong => "Name too long",
          };
          f.write_str(msg)
      }
  }

  #[cfg(feature = "std")]
  impl std::error::Error for StartQueryError {}
  /// Error returned by [`Socket::get_query_result`]
  #[derive(Debug, PartialEq, Eq, Clone, Copy)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub enum GetQueryResultError {
      /// Query is not done yet.
      Pending,
      /// Query failed.
      Failed,
  }

  impl core::fmt::Display for GetQueryResultError {
      fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
          // Map each variant to its fixed human-readable description.
          let text = match *self {
              Self::Pending => "Query is not done yet",
              Self::Failed => "Query failed",
          };
          f.write_str(text)
      }
  }

  #[cfg(feature = "std")]
  impl std::error::Error for GetQueryResultError {}
  /// State for an in-progress DNS query.
  ///
  /// The only reason this struct is public is to allow the socket state
  /// to be allocated externally (as `Option<DnsQuery>` slots passed to
  /// `Socket::new`; `None` marks a free slot).
  #[derive(Debug)]
  pub struct DnsQuery {
      // Current lifecycle state; changed through `set_state` so the waker fires.
      state: State,
      // Task to wake whenever `state` changes (only with the `async` feature).
      #[cfg(feature = "async")]
      waker: WakerRegistration,
  }
  71. impl DnsQuery {
  72. fn set_state(&mut self, state: State) {
  73. self.state = state;
  74. #[cfg(feature = "async")]
  75. self.waker.wake();
  76. }
  77. }
  /// Lifecycle of an occupied query slot.
  #[derive(Debug)]
  // NOTE(review): `Pending` is much larger than the other variants; presumably it is
  // not boxed so the socket works without a heap allocator — confirm.
  #[allow(clippy::large_enum_variant)]
  enum State {
      /// Still resolving; holds the transmission/retry bookkeeping.
      Pending(PendingQuery),
      /// A response supplied at least one address.
      Completed(CompletedQuery),
      /// The query failed, e.g. NXDomain, a response without usable addresses,
      /// an unusable server address, or every server exhausted.
      Failure,
  }
  /// Bookkeeping for a query that has not completed yet.
  #[derive(Debug)]
  struct PendingQuery {
      /// Queried name in wire format (length-prefixed labels, zero terminator).
      name: Vec<u8, DNS_MAX_NAME_SIZE>,
      /// Record type placed in the question section.
      type_: Type,
      port: u16, // UDP port (src for request, dst for response)
      txid: u16, // transaction ID
      /// Deadline for the current server; `None` until the query is first dispatched.
      timeout_at: Option<Instant>,
      /// Earliest time at which the query may be (re)transmitted.
      retransmit_at: Instant,
      /// Wait applied after the next transmission; doubles after each send, up to a cap.
      delay: Duration,
      /// Index of the server currently being tried.
      server_idx: usize,
      /// Whether this query is sent via multicast DNS instead of the configured servers.
      mdns: MulticastDns,
  }
  /// Selects how a query is transmitted.
  ///
  /// Passed to [`Socket::start_query_raw`]; [`Socket::start_query`] picks it
  /// automatically from the queried name.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  pub enum MulticastDns {
      /// Send the query to the socket's configured unicast DNS servers.
      Disabled,
      /// Send the query to the multicast DNS groups on the mDNS port.
      #[cfg(feature = "socket-mdns")]
      Enabled,
  }
  /// Result of a query that received at least one address.
  #[derive(Debug)]
  struct CompletedQuery {
      /// Addresses from the A/AAAA answer records, in response order; records
      /// beyond `DNS_MAX_RESULT_COUNT` are dropped.
      addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
  }
  /// A handle to an in-progress DNS query.
  ///
  /// Returned by [`Socket::start_query`]; it indexes the query slot and becomes
  /// stale once the slot is freed by [`Socket::get_query_result`] or
  /// [`Socket::cancel_query`].
  #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
  pub struct QueryHandle(usize);
  110. /// A Domain Name System socket.
  111. ///
  112. /// A UDP socket is bound to a specific endpoint, and owns transmit and receive
  113. /// packet buffers.
  114. #[derive(Debug)]
  115. pub struct Socket<'a> {
  116. servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
  117. queries: ManagedSlice<'a, Option<DnsQuery>>,
  118. /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  119. hop_limit: Option<u8>,
  120. }
  121. impl<'a> Socket<'a> {
  122. /// Create a DNS socket.
  123. ///
  124. /// # Panics
  125. ///
  126. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  127. pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
  128. where
  129. Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
  130. {
  131. Socket {
  132. servers: Vec::from_slice(servers).unwrap(),
  133. queries: queries.into(),
  134. hop_limit: None,
  135. }
  136. }
  137. /// Update the list of DNS servers, will replace all existing servers
  138. ///
  139. /// # Panics
  140. ///
  141. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  142. pub fn update_servers(&mut self, servers: &[IpAddress]) {
  143. self.servers = Vec::from_slice(servers).unwrap();
  144. }
  145. /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  146. ///
  147. /// See also the [set_hop_limit](#method.set_hop_limit) method
  148. pub fn hop_limit(&self) -> Option<u8> {
  149. self.hop_limit
  150. }
  151. /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  152. ///
  153. /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
  154. /// value (64).
  155. ///
  156. /// # Panics
  157. ///
  158. /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
  159. ///
  160. /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
  161. /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
  162. pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
  163. // A host MUST NOT send a datagram with a hop limit value of 0
  164. if let Some(0) = hop_limit {
  165. panic!("the time-to-live value of a packet must not be zero")
  166. }
  167. self.hop_limit = hop_limit
  168. }
  169. fn find_free_query(&mut self) -> Option<QueryHandle> {
  170. for (i, q) in self.queries.iter().enumerate() {
  171. if q.is_none() {
  172. return Some(QueryHandle(i));
  173. }
  174. }
  175. match &mut self.queries {
  176. ManagedSlice::Borrowed(_) => None,
  177. #[cfg(feature = "alloc")]
  178. ManagedSlice::Owned(queries) => {
  179. queries.push(None);
  180. let index = queries.len() - 1;
  181. Some(QueryHandle(index))
  182. }
  183. }
  184. }
  185. /// Start a query.
  186. ///
  187. /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
  188. /// It accepts names both with and without trailing dot, and they're treated
  189. /// the same (there's no support for DNS search path).
  190. pub fn start_query(
  191. &mut self,
  192. cx: &mut Context,
  193. name: &str,
  194. query_type: Type,
  195. ) -> Result<QueryHandle, StartQueryError> {
  196. let mut name = name.as_bytes();
  197. if name.is_empty() {
  198. net_trace!("invalid name: zero length");
  199. return Err(StartQueryError::InvalidName);
  200. }
  201. // Remove trailing dot, if any
  202. if name[name.len() - 1] == b'.' {
  203. name = &name[..name.len() - 1];
  204. }
  205. let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
  206. let mut mdns = MulticastDns::Disabled;
  207. #[cfg(feature = "socket-mdns")]
  208. if name.split(|&c| c == b'.').last().unwrap() == b"local" {
  209. net_trace!("Starting a mDNS query");
  210. mdns = MulticastDns::Enabled;
  211. }
  212. for s in name.split(|&c| c == b'.') {
  213. if s.len() > 63 {
  214. net_trace!("invalid name: too long label");
  215. return Err(StartQueryError::InvalidName);
  216. }
  217. if s.is_empty() {
  218. net_trace!("invalid name: zero length label");
  219. return Err(StartQueryError::InvalidName);
  220. }
  221. // Push label
  222. raw_name
  223. .push(s.len() as u8)
  224. .map_err(|_| StartQueryError::NameTooLong)?;
  225. raw_name
  226. .extend_from_slice(s)
  227. .map_err(|_| StartQueryError::NameTooLong)?;
  228. }
  229. // Push terminator.
  230. raw_name
  231. .push(0x00)
  232. .map_err(|_| StartQueryError::NameTooLong)?;
  233. self.start_query_raw(cx, &raw_name, query_type, mdns)
  234. }
  235. /// Start a query with a raw (wire-format) DNS name.
  236. /// `b"\x09rust-lang\x03org\x00"`
  237. ///
  238. /// You probably want to use [`start_query`] instead.
  239. pub fn start_query_raw(
  240. &mut self,
  241. cx: &mut Context,
  242. raw_name: &[u8],
  243. query_type: Type,
  244. mdns: MulticastDns,
  245. ) -> Result<QueryHandle, StartQueryError> {
  246. let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
  247. self.queries[handle.0] = Some(DnsQuery {
  248. state: State::Pending(PendingQuery {
  249. name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
  250. type_: query_type,
  251. txid: cx.rand().rand_u16(),
  252. port: cx.rand().rand_source_port(),
  253. delay: RETRANSMIT_DELAY,
  254. timeout_at: None,
  255. retransmit_at: Instant::ZERO,
  256. server_idx: 0,
  257. mdns,
  258. }),
  259. #[cfg(feature = "async")]
  260. waker: WakerRegistration::new(),
  261. });
  262. Ok(handle)
  263. }
  264. /// Get the result of a query.
  265. ///
  266. /// If the query is completed, the query slot is automatically freed.
  267. ///
  268. /// # Panics
  269. /// Panics if the QueryHandle corresponds to a free slot.
  270. pub fn get_query_result(
  271. &mut self,
  272. handle: QueryHandle,
  273. ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
  274. let slot = &mut self.queries[handle.0];
  275. let q = slot.as_mut().unwrap();
  276. match &mut q.state {
  277. // Query is not done yet.
  278. State::Pending(_) => Err(GetQueryResultError::Pending),
  279. // Query is done
  280. State::Completed(q) => {
  281. let res = q.addresses.clone();
  282. *slot = None; // Free up the slot for recycling.
  283. Ok(res)
  284. }
  285. State::Failure => {
  286. *slot = None; // Free up the slot for recycling.
  287. Err(GetQueryResultError::Failed)
  288. }
  289. }
  290. }
  291. /// Cancels a query, freeing the slot.
  292. ///
  293. /// # Panics
  294. ///
  295. /// Panics if the QueryHandle corresponds to an already free slot.
  296. pub fn cancel_query(&mut self, handle: QueryHandle) {
  297. let slot = &mut self.queries[handle.0];
  298. if slot.is_none() {
  299. panic!("Canceling query in a free slot.")
  300. }
  301. *slot = None; // Free up the slot for recycling.
  302. }
  303. /// Assign a waker to a query slot
  304. ///
  305. /// The waker will be woken when the query completes, either successfully or failed.
  306. ///
  307. /// # Panics
  308. ///
  309. /// Panics if the QueryHandle corresponds to an already free slot.
  310. #[cfg(feature = "async")]
  311. pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
  312. self.queries[handle.0]
  313. .as_mut()
  314. .unwrap()
  315. .waker
  316. .register(waker);
  317. }
  318. pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
  319. (udp_repr.src_port == DNS_PORT
  320. && self
  321. .servers
  322. .iter()
  323. .any(|server| *server == ip_repr.src_addr()))
  324. || (udp_repr.src_port == MDNS_DNS_PORT)
  325. }
  326. pub(crate) fn process(
  327. &mut self,
  328. _cx: &mut Context,
  329. ip_repr: &IpRepr,
  330. udp_repr: &UdpRepr,
  331. payload: &[u8],
  332. ) {
  333. debug_assert!(self.accepts(ip_repr, udp_repr));
  334. let size = payload.len();
  335. net_trace!(
  336. "receiving {} octets from {:?}:{}",
  337. size,
  338. ip_repr.src_addr(),
  339. udp_repr.dst_port
  340. );
  341. let p = match Packet::new_checked(payload) {
  342. Ok(x) => x,
  343. Err(_) => {
  344. net_trace!("dns packet malformed");
  345. return;
  346. }
  347. };
  348. if p.opcode() != Opcode::Query {
  349. net_trace!("unwanted opcode {:?}", p.opcode());
  350. return;
  351. }
  352. if !p.flags().contains(Flags::RESPONSE) {
  353. net_trace!("packet doesn't have response bit set");
  354. return;
  355. }
  356. if p.question_count() != 1 {
  357. net_trace!("bad question count {:?}", p.question_count());
  358. return;
  359. }
  360. // Find pending query
  361. for q in self.queries.iter_mut().flatten() {
  362. if let State::Pending(pq) = &mut q.state {
  363. if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
  364. continue;
  365. }
  366. if p.rcode() == Rcode::NXDomain {
  367. net_trace!("rcode NXDomain");
  368. q.set_state(State::Failure);
  369. continue;
  370. }
  371. let payload = p.payload();
  372. let (mut payload, question) = match Question::parse(payload) {
  373. Ok(x) => x,
  374. Err(_) => {
  375. net_trace!("question malformed");
  376. return;
  377. }
  378. };
  379. if question.type_ != pq.type_ {
  380. net_trace!("question type mismatch");
  381. return;
  382. }
  383. match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
  384. Ok(true) => {}
  385. Ok(false) => {
  386. net_trace!("question name mismatch");
  387. return;
  388. }
  389. Err(_) => {
  390. net_trace!("dns question name malformed");
  391. return;
  392. }
  393. }
  394. let mut addresses = Vec::new();
  395. for _ in 0..p.answer_record_count() {
  396. let (payload2, r) = match Record::parse(payload) {
  397. Ok(x) => x,
  398. Err(_) => {
  399. net_trace!("dns answer record malformed");
  400. return;
  401. }
  402. };
  403. payload = payload2;
  404. match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
  405. Ok(true) => {}
  406. Ok(false) => {
  407. net_trace!("answer name mismatch: {:?}", r);
  408. continue;
  409. }
  410. Err(_) => {
  411. net_trace!("dns answer record name malformed");
  412. return;
  413. }
  414. }
  415. match r.data {
  416. #[cfg(feature = "proto-ipv4")]
  417. RecordData::A(addr) => {
  418. net_trace!("A: {:?}", addr);
  419. if addresses.push(addr.into()).is_err() {
  420. net_trace!("too many addresses in response, ignoring {:?}", addr);
  421. }
  422. }
  423. #[cfg(feature = "proto-ipv6")]
  424. RecordData::Aaaa(addr) => {
  425. net_trace!("AAAA: {:?}", addr);
  426. if addresses.push(addr.into()).is_err() {
  427. net_trace!("too many addresses in response, ignoring {:?}", addr);
  428. }
  429. }
  430. RecordData::Cname(name) => {
  431. net_trace!("CNAME: {:?}", name);
  432. // When faced with a CNAME, recursive resolvers are supposed to
  433. // resolve the CNAME and append the results for it.
  434. //
  435. // We update the query with the new name, so that we pick up the A/AAAA
  436. // records for the CNAME when we parse them later.
  437. // I believe it's mandatory the CNAME results MUST come *after* in the
  438. // packet, so it's enough to do one linear pass over it.
  439. if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
  440. net_trace!("dns answer cname malformed");
  441. return;
  442. }
  443. }
  444. RecordData::Other(type_, data) => {
  445. net_trace!("unknown: {:?} {:?}", type_, data)
  446. }
  447. }
  448. }
  449. q.set_state(if addresses.is_empty() {
  450. State::Failure
  451. } else {
  452. State::Completed(CompletedQuery { addresses })
  453. });
  454. // If we get here, packet matched the current query, stop processing.
  455. return;
  456. }
  457. }
  458. // If we get here, packet matched with no query.
  459. net_trace!("no query matched");
  460. }
  461. pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
  462. where
  463. F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
  464. {
  465. let hop_limit = self.hop_limit.unwrap_or(64);
  466. for q in self.queries.iter_mut().flatten() {
  467. if let State::Pending(pq) = &mut q.state {
  468. // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
  469. // so we internally overwrite the servers for any of those queries
  470. // in this function.
  471. let servers = match pq.mdns {
  472. #[cfg(feature = "socket-mdns")]
  473. MulticastDns::Enabled => &[
  474. #[cfg(feature = "proto-ipv6")]
  475. MDNS_IPV6_ADDR,
  476. #[cfg(feature = "proto-ipv4")]
  477. MDNS_IPV4_ADDR,
  478. ],
  479. MulticastDns::Disabled => self.servers.as_slice(),
  480. };
  481. let timeout = if let Some(timeout) = pq.timeout_at {
  482. timeout
  483. } else {
  484. let v = cx.now() + RETRANSMIT_TIMEOUT;
  485. pq.timeout_at = Some(v);
  486. v
  487. };
  488. // Check timeout
  489. if timeout < cx.now() {
  490. // DNS timeout
  491. pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
  492. pq.retransmit_at = Instant::ZERO;
  493. pq.delay = RETRANSMIT_DELAY;
  494. // Try next server. We check below whether we've tried all servers.
  495. pq.server_idx += 1;
  496. }
  497. // Check if we've run out of servers to try.
  498. if pq.server_idx >= servers.len() {
  499. net_trace!("already tried all servers.");
  500. q.set_state(State::Failure);
  501. continue;
  502. }
  503. // Check so the IP address is valid
  504. if servers[pq.server_idx].is_unspecified() {
  505. net_trace!("invalid unspecified DNS server addr.");
  506. q.set_state(State::Failure);
  507. continue;
  508. }
  509. if pq.retransmit_at > cx.now() {
  510. // query is waiting for retransmit
  511. continue;
  512. }
  513. let repr = Repr {
  514. transaction_id: pq.txid,
  515. flags: Flags::RECURSION_DESIRED,
  516. opcode: Opcode::Query,
  517. question: Question {
  518. name: &pq.name,
  519. type_: pq.type_,
  520. },
  521. };
  522. let mut payload = [0u8; 512];
  523. let payload = &mut payload[..repr.buffer_len()];
  524. repr.emit(&mut Packet::new_unchecked(payload));
  525. let dst_port = match pq.mdns {
  526. #[cfg(feature = "socket-mdns")]
  527. MulticastDns::Enabled => MDNS_DNS_PORT,
  528. MulticastDns::Disabled => DNS_PORT,
  529. };
  530. let udp_repr = UdpRepr {
  531. src_port: pq.port,
  532. dst_port,
  533. };
  534. let dst_addr = servers[pq.server_idx];
  535. let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
  536. let ip_repr = IpRepr::new(
  537. src_addr,
  538. dst_addr,
  539. IpProtocol::Udp,
  540. udp_repr.header_len() + payload.len(),
  541. hop_limit,
  542. );
  543. net_trace!(
  544. "sending {} octets to {} from port {}",
  545. payload.len(),
  546. ip_repr.dst_addr(),
  547. udp_repr.src_port
  548. );
  549. emit(cx, (ip_repr, udp_repr, payload))?;
  550. pq.retransmit_at = cx.now() + pq.delay;
  551. pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
  552. return Ok(());
  553. }
  554. }
  555. // Nothing to dispatch
  556. Ok(())
  557. }
  558. pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
  559. self.queries
  560. .iter()
  561. .flatten()
  562. .filter_map(|q| match &q.state {
  563. State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
  564. State::Completed(_) => None,
  565. State::Failure => None,
  566. })
  567. .min()
  568. .unwrap_or(PollAt::Ingress)
  569. }
  570. }
  /// Compare two DNS names label by label.
  ///
  /// Each argument yields a name's labels in order (e.g. the iterator returned
  /// by `Packet::parse_name`). Labels are compared ASCII case-insensitively,
  /// because DNS name comparisons are case-insensitive (RFC 4343).
  ///
  /// Returns `Ok(true)` when both names have the same labels and `Ok(false)` as
  /// soon as a label differs or one name ends first. If an iterator yields an
  /// error before a mismatch is found, that error is returned; when both fail
  /// at the same position, the error from `a` wins.
  ///
  /// Generic over the error type so it works with `wire::Result` as well as any
  /// other `Result`-yielding label iterator.
  fn eq_names<'a, E>(
      mut a: impl Iterator<Item = Result<&'a [u8], E>>,
      mut b: impl Iterator<Item = Result<&'a [u8], E>>,
  ) -> Result<bool, E> {
      loop {
          match (a.next(), b.next()) {
              // Handle errors
              (Some(Err(e)), _) => return Err(e),
              (_, Some(Err(e))) => return Err(e),

              // Both finished -> equal
              (None, None) => return Ok(true),

              // One finished before the other -> not equal
              (None, _) | (_, None) => return Ok(false),

              // Got two labels, check if they're equal (ignoring ASCII case)
              (Some(Ok(la)), Some(Ok(lb))) => {
                  if !la.eq_ignore_ascii_case(lb) {
                      return Ok(false);
                  }
              }
          }
      }
  }
  594. fn copy_name<'a, const N: usize>(
  595. dest: &mut Vec<u8, N>,
  596. name: impl Iterator<Item = wire::Result<&'a [u8]>>,
  597. ) -> Result<(), wire::Error> {
  598. dest.truncate(0);
  599. for label in name {
  600. let label = label?;
  601. dest.push(label.len() as u8).map_err(|_| wire::Error)?;
  602. dest.extend_from_slice(label).map_err(|_| wire::Error)?;
  603. }
  604. // Write terminator 0x00
  605. dest.push(0).map_err(|_| wire::Error)?;
  606. Ok(())
  607. }