|
@@ -1,3 +1,6 @@
|
|
|
+#[cfg(feature = "async")]
|
|
|
+use core::task::Waker;
|
|
|
+
|
|
|
use heapless::Vec;
|
|
|
use managed::ManagedSlice;
|
|
|
|
|
@@ -7,6 +10,9 @@ use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordDat
|
|
|
use crate::wire::{IpAddress, IpProtocol, IpRepr, UdpRepr};
|
|
|
use crate::{Error, Result};
|
|
|
|
|
|
+#[cfg(feature = "async")]
|
|
|
+use super::WakerRegistration;
|
|
|
+
|
|
|
const DNS_PORT: u16 = 53;
|
|
|
const MAX_NAME_LEN: usize = 255;
|
|
|
const MAX_ADDRESS_COUNT: usize = 4;
|
|
@@ -22,6 +28,17 @@ const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should ge
|
|
|
#[derive(Debug)]
|
|
|
pub struct DnsQuery {
|
|
|
state: State,
|
|
|
+
|
|
|
+ #[cfg(feature = "async")]
|
|
|
+ waker: WakerRegistration,
|
|
|
+}
|
|
|
+
|
|
|
+impl DnsQuery {
|
|
|
+ fn set_state(&mut self, state: State) {
|
|
|
+ self.state = state;
|
|
|
+ #[cfg(feature = "async")]
|
|
|
+ self.waker.wake();
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#[derive(Debug)]
|
|
@@ -201,6 +218,8 @@ impl<'a> DnsSocket<'a> {
|
|
|
retransmit_at: Instant::ZERO,
|
|
|
server_idx: 0,
|
|
|
}),
|
|
|
+ #[cfg(feature = "async")]
|
|
|
+ waker: WakerRegistration::new(),
|
|
|
});
|
|
|
Ok(handle)
|
|
|
}
|
|
@@ -236,6 +255,13 @@ impl<'a> DnsSocket<'a> {
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
+ #[cfg(feature = "async")]
|
|
|
+ pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) -> Result<()> {
|
|
|
+ let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
|
|
|
+ slot.as_mut().ok_or(Error::Illegal)?.waker.register(waker);
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
|
|
|
udp_repr.src_port == DNS_PORT
|
|
|
&& self
|
|
@@ -287,8 +313,7 @@ impl<'a> DnsSocket<'a> {
|
|
|
|
|
|
if p.rcode() == Rcode::NXDomain {
|
|
|
net_trace!("rcode NXDomain");
|
|
|
-
|
|
|
- q.state = State::Failure;
|
|
|
+ q.set_state(State::Failure);
|
|
|
continue;
|
|
|
}
|
|
|
|
|
@@ -349,11 +374,11 @@ impl<'a> DnsSocket<'a> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if addresses.is_empty() {
|
|
|
- q.state = State::Failure;
|
|
|
+ q.set_state(if addresses.is_empty() {
|
|
|
+ State::Failure
|
|
|
} else {
|
|
|
- q.state = State::Completed(CompletedQuery { addresses })
|
|
|
- }
|
|
|
+ State::Completed(CompletedQuery { addresses })
|
|
|
+ });
|
|
|
|
|
|
// If we get here, packet matched the current query, stop processing.
|
|
|
return Ok(());
|
|
@@ -395,14 +420,14 @@ impl<'a> DnsSocket<'a> {
|
|
|
// Check if we've run out of servers to try.
|
|
|
if pq.server_idx >= self.servers.len() {
|
|
|
net_trace!("already tried all servers.");
|
|
|
- q.state = State::Failure;
|
|
|
+ q.set_state(State::Failure);
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
// Check so the IP address is valid
|
|
|
if self.servers[pq.server_idx].is_unspecified() {
|
|
|
net_trace!("invalid unspecified DNS server addr.");
|
|
|
- q.state = State::Failure;
|
|
|
+ q.set_state(State::Failure);
|
|
|
continue;
|
|
|
}
|
|
|
|