Dario Nieuwenhuis 2 жил өмнө
parent
commit
74bc0b49d7
1 өөрчлөгдсөн 33 нэмэгдсэн , 8 устгасан
  1. 33 8
      src/socket/dns.rs

+ 33 - 8
src/socket/dns.rs

@@ -1,3 +1,6 @@
+#[cfg(feature = "async")]
+use core::task::Waker;
+
 use heapless::Vec;
 use managed::ManagedSlice;
 
@@ -7,6 +10,9 @@ use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordDat
 use crate::wire::{IpAddress, IpProtocol, IpRepr, UdpRepr};
 use crate::{Error, Result};
 
+#[cfg(feature = "async")]
+use super::WakerRegistration;
+
 const DNS_PORT: u16 = 53;
 const MAX_NAME_LEN: usize = 255;
 const MAX_ADDRESS_COUNT: usize = 4;
@@ -22,6 +28,17 @@ const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should ge
 #[derive(Debug)]
 pub struct DnsQuery {
     state: State,
+
+    #[cfg(feature = "async")]
+    waker: WakerRegistration,
+}
+
+impl DnsQuery {
+    fn set_state(&mut self, state: State) {
+        self.state = state;
+        #[cfg(feature = "async")]
+        self.waker.wake();
+    }
 }
 
 #[derive(Debug)]
@@ -201,6 +218,8 @@ impl<'a> DnsSocket<'a> {
                 retransmit_at: Instant::ZERO,
                 server_idx: 0,
             }),
+            #[cfg(feature = "async")]
+            waker: WakerRegistration::new(),
         });
         Ok(handle)
     }
@@ -236,6 +255,13 @@ impl<'a> DnsSocket<'a> {
         Ok(())
     }
 
+    #[cfg(feature = "async")]
+    pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) -> Result<()> {
+        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
+        slot.as_mut().ok_or(Error::Illegal)?.waker.register(waker);
+        Ok(())
+    }
+
     pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
         udp_repr.src_port == DNS_PORT
             && self
@@ -287,8 +313,7 @@ impl<'a> DnsSocket<'a> {
 
                 if p.rcode() == Rcode::NXDomain {
                     net_trace!("rcode NXDomain");
-
-                    q.state = State::Failure;
+                    q.set_state(State::Failure);
                     continue;
                 }
 
@@ -349,11 +374,11 @@ impl<'a> DnsSocket<'a> {
                     }
                 }
 
-                if addresses.is_empty() {
-                    q.state = State::Failure;
+                q.set_state(if addresses.is_empty() {
+                    State::Failure
                 } else {
-                    q.state = State::Completed(CompletedQuery { addresses })
-                }
+                    State::Completed(CompletedQuery { addresses })
+                });
 
                 // If we get here, packet matched the current query, stop processing.
                 return Ok(());
@@ -395,14 +420,14 @@ impl<'a> DnsSocket<'a> {
                 // Check if we've run out of servers to try.
                 if pq.server_idx >= self.servers.len() {
                     net_trace!("already tried all servers.");
-                    q.state = State::Failure;
+                    q.set_state(State::Failure);
                     continue;
                 }
 
                 // Check so the IP address is valid
                 if self.servers[pq.server_idx].is_unspecified() {
                     net_trace!("invalid unspecified DNS server addr.");
-                    q.state = State::Failure;
+                    q.set_state(State::Failure);
                     continue;
                 }