dns.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. #[cfg(feature = "async")]
  2. use core::task::Waker;
  3. use heapless::Vec;
  4. use managed::ManagedSlice;
  5. use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
  6. use crate::socket::{Context, PollAt};
  7. use crate::time::{Duration, Instant};
  8. use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
  9. use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
  10. #[cfg(feature = "async")]
  11. use super::WakerRegistration;
  12. const DNS_PORT: u16 = 53;
  13. const MDNS_DNS_PORT: u16 = 5353;
  14. const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
  15. const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
  16. const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
  /// IPv6 multicast group used by mDNS, `ff02::fb` (RFC 6762 §3).
  // Only referenced when the `socket-mdns` feature is enabled, hence `allow(unused)`.
  #[cfg(feature = "proto-ipv6")]
  #[allow(unused)]
  const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
      0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfb,
  ]));

  /// IPv4 multicast group used by mDNS, `224.0.0.251` (RFC 6762 §3).
  // Only referenced when the `socket-mdns` feature is enabled, hence `allow(unused)`.
  #[cfg(feature = "proto-ipv4")]
  #[allow(unused)]
  const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
  /// Error returned by [`Socket::start_query`]
  #[derive(Debug, PartialEq, Eq, Clone, Copy)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub enum StartQueryError {
      /// Every query slot is in use and the slot storage cannot grow.
      NoFreeSlot,
      /// The name is empty or contains an empty label or a label longer than 63 bytes.
      InvalidName,
      /// The encoded name does not fit in `DNS_MAX_NAME_SIZE` bytes.
      NameTooLong,
  }

  impl core::fmt::Display for StartQueryError {
      fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
          f.write_str(match self {
              StartQueryError::NoFreeSlot => "No free slot",
              StartQueryError::InvalidName => "Invalid name",
              StartQueryError::NameTooLong => "Name too long",
          })
      }
  }

  #[cfg(feature = "std")]
  impl std::error::Error for StartQueryError {}
  /// Error returned by [`Socket::get_query_result`]
  #[derive(Debug, PartialEq, Eq, Clone, Copy)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub enum GetQueryResultError {
      /// Query is not done yet.
      Pending,
      /// Query failed.
      Failed,
  }

  impl core::fmt::Display for GetQueryResultError {
      fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
          let text = match *self {
              Self::Pending => "Query is not done yet",
              Self::Failed => "Query failed",
          };
          f.write_str(text)
      }
  }

  #[cfg(feature = "std")]
  impl std::error::Error for GetQueryResultError {}
  63. /// State for an in-progress DNS query.
  64. ///
  65. /// The only reason this struct is public is to allow the socket state
  66. /// to be allocated externally.
  67. #[derive(Debug)]
  68. pub struct DnsQuery {
  69. state: State,
  70. #[cfg(feature = "async")]
  71. waker: WakerRegistration,
  72. }
  73. impl DnsQuery {
  74. fn set_state(&mut self, state: State) {
  75. self.state = state;
  76. #[cfg(feature = "async")]
  77. self.waker.wake();
  78. }
  79. }
  80. #[derive(Debug)]
  81. #[allow(clippy::large_enum_variant)]
  82. enum State {
  83. Pending(PendingQuery),
  84. Completed(CompletedQuery),
  85. Failure,
  86. }
  87. #[derive(Debug)]
  88. struct PendingQuery {
  89. name: Vec<u8, DNS_MAX_NAME_SIZE>,
  90. type_: Type,
  91. port: u16, // UDP port (src for request, dst for response)
  92. txid: u16, // transaction ID
  93. timeout_at: Option<Instant>,
  94. retransmit_at: Instant,
  95. delay: Duration,
  96. server_idx: usize,
  97. mdns: MulticastDns,
  98. }
  /// Selects where a query is sent: the socket's configured DNS servers, or the
  /// multicast DNS group (RFC 6762).
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub enum MulticastDns {
      /// Send the query to the socket's configured DNS servers on port 53.
      Disabled,
      /// Send the query to the mDNS multicast group(s) on port 5353.
      #[cfg(feature = "socket-mdns")]
      Enabled,
  }
  105. #[derive(Debug)]
  106. struct CompletedQuery {
  107. addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
  108. }
  /// A handle to an in-progress DNS query.
  ///
  /// It indexes the query slots of the socket that returned it. It stops being
  /// valid once the slot is freed by `Socket::get_query_result` (on completion or
  /// failure) or `Socket::cancel_query`.
  #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub struct QueryHandle(usize);
  112. /// A Domain Name System socket.
  113. ///
  114. /// A UDP socket is bound to a specific endpoint, and owns transmit and receive
  115. /// packet buffers.
  116. #[derive(Debug)]
  117. pub struct Socket<'a> {
  118. servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
  119. queries: ManagedSlice<'a, Option<DnsQuery>>,
  120. /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  121. hop_limit: Option<u8>,
  122. }
  123. impl<'a> Socket<'a> {
  124. /// Create a DNS socket.
  125. ///
  126. /// # Panics
  127. ///
  128. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  129. pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
  130. where
  131. Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
  132. {
  133. Socket {
  134. servers: Vec::from_slice(servers).unwrap(),
  135. queries: queries.into(),
  136. hop_limit: None,
  137. }
  138. }
  139. /// Update the list of DNS servers, will replace all existing servers
  140. ///
  141. /// # Panics
  142. ///
  143. /// Panics if `servers.len() > MAX_SERVER_COUNT`
  144. pub fn update_servers(&mut self, servers: &[IpAddress]) {
  145. self.servers = Vec::from_slice(servers).unwrap();
  146. }
  147. /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  148. ///
  149. /// See also the [set_hop_limit](#method.set_hop_limit) method
  150. pub fn hop_limit(&self) -> Option<u8> {
  151. self.hop_limit
  152. }
  153. /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
  154. ///
  155. /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
  156. /// value (64).
  157. ///
  158. /// # Panics
  159. ///
  160. /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
  161. ///
  162. /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
  163. /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
  164. pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
  165. // A host MUST NOT send a datagram with a hop limit value of 0
  166. if let Some(0) = hop_limit {
  167. panic!("the time-to-live value of a packet must not be zero")
  168. }
  169. self.hop_limit = hop_limit
  170. }
  171. fn find_free_query(&mut self) -> Option<QueryHandle> {
  172. for (i, q) in self.queries.iter().enumerate() {
  173. if q.is_none() {
  174. return Some(QueryHandle(i));
  175. }
  176. }
  177. match &mut self.queries {
  178. ManagedSlice::Borrowed(_) => None,
  179. #[cfg(feature = "alloc")]
  180. ManagedSlice::Owned(queries) => {
  181. queries.push(None);
  182. let index = queries.len() - 1;
  183. Some(QueryHandle(index))
  184. }
  185. }
  186. }
  187. /// Start a query.
  188. ///
  189. /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
  190. /// It accepts names both with and without trailing dot, and they're treated
  191. /// the same (there's no support for DNS search path).
  192. pub fn start_query(
  193. &mut self,
  194. cx: &mut Context,
  195. name: &str,
  196. query_type: Type,
  197. ) -> Result<QueryHandle, StartQueryError> {
  198. let mut name = name.as_bytes();
  199. if name.is_empty() {
  200. net_trace!("invalid name: zero length");
  201. return Err(StartQueryError::InvalidName);
  202. }
  203. // Remove trailing dot, if any
  204. if name[name.len() - 1] == b'.' {
  205. name = &name[..name.len() - 1];
  206. }
  207. let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
  208. let mut mdns = MulticastDns::Disabled;
  209. #[cfg(feature = "socket-mdns")]
  210. if name.split(|&c| c == b'.').last().unwrap() == b"local" {
  211. net_trace!("Starting a mDNS query");
  212. mdns = MulticastDns::Enabled;
  213. }
  214. for s in name.split(|&c| c == b'.') {
  215. if s.len() > 63 {
  216. net_trace!("invalid name: too long label");
  217. return Err(StartQueryError::InvalidName);
  218. }
  219. if s.is_empty() {
  220. net_trace!("invalid name: zero length label");
  221. return Err(StartQueryError::InvalidName);
  222. }
  223. // Push label
  224. raw_name
  225. .push(s.len() as u8)
  226. .map_err(|_| StartQueryError::NameTooLong)?;
  227. raw_name
  228. .extend_from_slice(s)
  229. .map_err(|_| StartQueryError::NameTooLong)?;
  230. }
  231. // Push terminator.
  232. raw_name
  233. .push(0x00)
  234. .map_err(|_| StartQueryError::NameTooLong)?;
  235. self.start_query_raw(cx, &raw_name, query_type, mdns)
  236. }
  237. /// Start a query with a raw (wire-format) DNS name.
  238. /// `b"\x09rust-lang\x03org\x00"`
  239. ///
  240. /// You probably want to use [`start_query`] instead.
  241. pub fn start_query_raw(
  242. &mut self,
  243. cx: &mut Context,
  244. raw_name: &[u8],
  245. query_type: Type,
  246. mdns: MulticastDns,
  247. ) -> Result<QueryHandle, StartQueryError> {
  248. let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
  249. self.queries[handle.0] = Some(DnsQuery {
  250. state: State::Pending(PendingQuery {
  251. name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
  252. type_: query_type,
  253. txid: cx.rand().rand_u16(),
  254. port: cx.rand().rand_source_port(),
  255. delay: RETRANSMIT_DELAY,
  256. timeout_at: None,
  257. retransmit_at: Instant::ZERO,
  258. server_idx: 0,
  259. mdns,
  260. }),
  261. #[cfg(feature = "async")]
  262. waker: WakerRegistration::new(),
  263. });
  264. Ok(handle)
  265. }
  266. /// Get the result of a query.
  267. ///
  268. /// If the query is completed, the query slot is automatically freed.
  269. ///
  270. /// # Panics
  271. /// Panics if the QueryHandle corresponds to a free slot.
  272. pub fn get_query_result(
  273. &mut self,
  274. handle: QueryHandle,
  275. ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
  276. let slot = &mut self.queries[handle.0];
  277. let q = slot.as_mut().unwrap();
  278. match &mut q.state {
  279. // Query is not done yet.
  280. State::Pending(_) => Err(GetQueryResultError::Pending),
  281. // Query is done
  282. State::Completed(q) => {
  283. let res = q.addresses.clone();
  284. *slot = None; // Free up the slot for recycling.
  285. Ok(res)
  286. }
  287. State::Failure => {
  288. *slot = None; // Free up the slot for recycling.
  289. Err(GetQueryResultError::Failed)
  290. }
  291. }
  292. }
  293. /// Cancels a query, freeing the slot.
  294. ///
  295. /// # Panics
  296. ///
  297. /// Panics if the QueryHandle corresponds to an already free slot.
  298. pub fn cancel_query(&mut self, handle: QueryHandle) {
  299. let slot = &mut self.queries[handle.0];
  300. if slot.is_none() {
  301. panic!("Canceling query in a free slot.")
  302. }
  303. *slot = None; // Free up the slot for recycling.
  304. }
  305. /// Assign a waker to a query slot
  306. ///
  307. /// The waker will be woken when the query completes, either successfully or failed.
  308. ///
  309. /// # Panics
  310. ///
  311. /// Panics if the QueryHandle corresponds to an already free slot.
  312. #[cfg(feature = "async")]
  313. pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
  314. self.queries[handle.0]
  315. .as_mut()
  316. .unwrap()
  317. .waker
  318. .register(waker);
  319. }
  320. pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
  321. (udp_repr.src_port == DNS_PORT
  322. && self
  323. .servers
  324. .iter()
  325. .any(|server| *server == ip_repr.src_addr()))
  326. || (udp_repr.src_port == MDNS_DNS_PORT)
  327. }
  328. pub(crate) fn process(
  329. &mut self,
  330. _cx: &mut Context,
  331. ip_repr: &IpRepr,
  332. udp_repr: &UdpRepr,
  333. payload: &[u8],
  334. ) {
  335. debug_assert!(self.accepts(ip_repr, udp_repr));
  336. let size = payload.len();
  337. net_trace!(
  338. "receiving {} octets from {:?}:{}",
  339. size,
  340. ip_repr.src_addr(),
  341. udp_repr.dst_port
  342. );
  343. let p = match Packet::new_checked(payload) {
  344. Ok(x) => x,
  345. Err(_) => {
  346. net_trace!("dns packet malformed");
  347. return;
  348. }
  349. };
  350. if p.opcode() != Opcode::Query {
  351. net_trace!("unwanted opcode {:?}", p.opcode());
  352. return;
  353. }
  354. if !p.flags().contains(Flags::RESPONSE) {
  355. net_trace!("packet doesn't have response bit set");
  356. return;
  357. }
  358. if p.question_count() != 1 {
  359. net_trace!("bad question count {:?}", p.question_count());
  360. return;
  361. }
  362. // Find pending query
  363. for q in self.queries.iter_mut().flatten() {
  364. if let State::Pending(pq) = &mut q.state {
  365. if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
  366. continue;
  367. }
  368. if p.rcode() == Rcode::NXDomain {
  369. net_trace!("rcode NXDomain");
  370. q.set_state(State::Failure);
  371. continue;
  372. }
  373. let payload = p.payload();
  374. let (mut payload, question) = match Question::parse(payload) {
  375. Ok(x) => x,
  376. Err(_) => {
  377. net_trace!("question malformed");
  378. return;
  379. }
  380. };
  381. if question.type_ != pq.type_ {
  382. net_trace!("question type mismatch");
  383. return;
  384. }
  385. match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
  386. Ok(true) => {}
  387. Ok(false) => {
  388. net_trace!("question name mismatch");
  389. return;
  390. }
  391. Err(_) => {
  392. net_trace!("dns question name malformed");
  393. return;
  394. }
  395. }
  396. let mut addresses = Vec::new();
  397. for _ in 0..p.answer_record_count() {
  398. let (payload2, r) = match Record::parse(payload) {
  399. Ok(x) => x,
  400. Err(_) => {
  401. net_trace!("dns answer record malformed");
  402. return;
  403. }
  404. };
  405. payload = payload2;
  406. match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
  407. Ok(true) => {}
  408. Ok(false) => {
  409. net_trace!("answer name mismatch: {:?}", r);
  410. continue;
  411. }
  412. Err(_) => {
  413. net_trace!("dns answer record name malformed");
  414. return;
  415. }
  416. }
  417. match r.data {
  418. #[cfg(feature = "proto-ipv4")]
  419. RecordData::A(addr) => {
  420. net_trace!("A: {:?}", addr);
  421. if addresses.push(addr.into()).is_err() {
  422. net_trace!("too many addresses in response, ignoring {:?}", addr);
  423. }
  424. }
  425. #[cfg(feature = "proto-ipv6")]
  426. RecordData::Aaaa(addr) => {
  427. net_trace!("AAAA: {:?}", addr);
  428. if addresses.push(addr.into()).is_err() {
  429. net_trace!("too many addresses in response, ignoring {:?}", addr);
  430. }
  431. }
  432. RecordData::Cname(name) => {
  433. net_trace!("CNAME: {:?}", name);
  434. // When faced with a CNAME, recursive resolvers are supposed to
  435. // resolve the CNAME and append the results for it.
  436. //
  437. // We update the query with the new name, so that we pick up the A/AAAA
  438. // records for the CNAME when we parse them later.
  439. // I believe it's mandatory the CNAME results MUST come *after* in the
  440. // packet, so it's enough to do one linear pass over it.
  441. if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
  442. net_trace!("dns answer cname malformed");
  443. return;
  444. }
  445. }
  446. RecordData::Other(type_, data) => {
  447. net_trace!("unknown: {:?} {:?}", type_, data)
  448. }
  449. }
  450. }
  451. q.set_state(if addresses.is_empty() {
  452. State::Failure
  453. } else {
  454. State::Completed(CompletedQuery { addresses })
  455. });
  456. // If we get here, packet matched the current query, stop processing.
  457. return;
  458. }
  459. }
  460. // If we get here, packet matched with no query.
  461. net_trace!("no query matched");
  462. }
  463. pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
  464. where
  465. F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
  466. {
  467. let hop_limit = self.hop_limit.unwrap_or(64);
  468. for q in self.queries.iter_mut().flatten() {
  469. if let State::Pending(pq) = &mut q.state {
  470. // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
  471. // so we internally overwrite the servers for any of those queries
  472. // in this function.
  473. let servers = match pq.mdns {
  474. #[cfg(feature = "socket-mdns")]
  475. MulticastDns::Enabled => &[
  476. #[cfg(feature = "proto-ipv6")]
  477. MDNS_IPV6_ADDR,
  478. #[cfg(feature = "proto-ipv4")]
  479. MDNS_IPV4_ADDR,
  480. ],
  481. MulticastDns::Disabled => self.servers.as_slice(),
  482. };
  483. let timeout = if let Some(timeout) = pq.timeout_at {
  484. timeout
  485. } else {
  486. let v = cx.now() + RETRANSMIT_TIMEOUT;
  487. pq.timeout_at = Some(v);
  488. v
  489. };
  490. // Check timeout
  491. if timeout < cx.now() {
  492. // DNS timeout
  493. pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
  494. pq.retransmit_at = Instant::ZERO;
  495. pq.delay = RETRANSMIT_DELAY;
  496. // Try next server. We check below whether we've tried all servers.
  497. pq.server_idx += 1;
  498. }
  499. // Check if we've run out of servers to try.
  500. if pq.server_idx >= servers.len() {
  501. net_trace!("already tried all servers.");
  502. q.set_state(State::Failure);
  503. continue;
  504. }
  505. // Check so the IP address is valid
  506. if servers[pq.server_idx].is_unspecified() {
  507. net_trace!("invalid unspecified DNS server addr.");
  508. q.set_state(State::Failure);
  509. continue;
  510. }
  511. if pq.retransmit_at > cx.now() {
  512. // query is waiting for retransmit
  513. continue;
  514. }
  515. let repr = Repr {
  516. transaction_id: pq.txid,
  517. flags: Flags::RECURSION_DESIRED,
  518. opcode: Opcode::Query,
  519. question: Question {
  520. name: &pq.name,
  521. type_: pq.type_,
  522. },
  523. };
  524. let mut payload = [0u8; 512];
  525. let payload = &mut payload[..repr.buffer_len()];
  526. repr.emit(&mut Packet::new_unchecked(payload));
  527. let dst_port = match pq.mdns {
  528. #[cfg(feature = "socket-mdns")]
  529. MulticastDns::Enabled => MDNS_DNS_PORT,
  530. MulticastDns::Disabled => DNS_PORT,
  531. };
  532. let udp_repr = UdpRepr {
  533. src_port: pq.port,
  534. dst_port,
  535. };
  536. let dst_addr = servers[pq.server_idx];
  537. let src_addr = cx.get_source_address(&dst_addr).unwrap(); // TODO remove unwrap
  538. let ip_repr = IpRepr::new(
  539. src_addr,
  540. dst_addr,
  541. IpProtocol::Udp,
  542. udp_repr.header_len() + payload.len(),
  543. hop_limit,
  544. );
  545. net_trace!(
  546. "sending {} octets to {} from port {}",
  547. payload.len(),
  548. ip_repr.dst_addr(),
  549. udp_repr.src_port
  550. );
  551. emit(cx, (ip_repr, udp_repr, payload))?;
  552. pq.retransmit_at = cx.now() + pq.delay;
  553. pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
  554. return Ok(());
  555. }
  556. }
  557. // Nothing to dispatch
  558. Ok(())
  559. }
  560. pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
  561. self.queries
  562. .iter()
  563. .flatten()
  564. .filter_map(|q| match &q.state {
  565. State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
  566. State::Completed(_) => None,
  567. State::Failure => None,
  568. })
  569. .min()
  570. .unwrap_or(PollAt::Ingress)
  571. }
  572. }
  573. fn eq_names<'a>(
  574. mut a: impl Iterator<Item = wire::Result<&'a [u8]>>,
  575. mut b: impl Iterator<Item = wire::Result<&'a [u8]>>,
  576. ) -> wire::Result<bool> {
  577. loop {
  578. match (a.next(), b.next()) {
  579. // Handle errors
  580. (Some(Err(e)), _) => return Err(e),
  581. (_, Some(Err(e))) => return Err(e),
  582. // Both finished -> equal
  583. (None, None) => return Ok(true),
  584. // One finished before the other -> not equal
  585. (None, _) => return Ok(false),
  586. (_, None) => return Ok(false),
  587. // Got two labels, check if they're equal
  588. (Some(Ok(la)), Some(Ok(lb))) => {
  589. if la != lb {
  590. return Ok(false);
  591. }
  592. }
  593. }
  594. }
  595. }
  596. fn copy_name<'a, const N: usize>(
  597. dest: &mut Vec<u8, N>,
  598. name: impl Iterator<Item = wire::Result<&'a [u8]>>,
  599. ) -> Result<(), wire::Error> {
  600. dest.truncate(0);
  601. for label in name {
  602. let label = label?;
  603. dest.push(label.len() as u8).map_err(|_| wire::Error)?;
  604. dest.extend_from_slice(label).map_err(|_| wire::Error)?;
  605. }
  606. // Write terminator 0x00
  607. dest.push(0).map_err(|_| wire::Error)?;
  608. Ok(())
  609. }