|
@@ -8,7 +8,7 @@ use crate::socket::{Context, PollAt};
|
|
|
use crate::time::{Duration, Instant};
|
|
|
use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
|
|
|
use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
|
|
|
-use crate::{Error, Result};
|
|
|
+use crate::Error;
|
|
|
|
|
|
#[cfg(feature = "async")]
|
|
|
use super::WakerRegistration;
|
|
@@ -21,6 +21,25 @@ const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
|
|
|
const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
|
|
|
const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
|
|
|
|
|
|
+/// Error returned by [`Socket::start_query`]
|
|
|
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
|
|
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
|
|
+pub enum StartQueryError {
|
|
|
+ NoFreeSlot,
|
|
|
+ InvalidName,
|
|
|
+ NameTooLong,
|
|
|
+}
|
|
|
+
|
|
|
+/// Error returned by [`Socket::get_query_result`]
|
|
|
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
|
|
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
|
|
+pub enum GetQueryResultError {
|
|
|
+ /// Query is not done yet.
|
|
|
+ Pending,
|
|
|
+ /// Query failed.
|
|
|
+ Failed,
|
|
|
+}
|
|
|
+
|
|
|
/// State for an in-progress DNS query.
|
|
|
///
|
|
|
/// The only reason this struct is public is to allow the socket state
|
|
@@ -139,20 +158,20 @@ impl<'a> Socket<'a> {
|
|
|
self.hop_limit = hop_limit
|
|
|
}
|
|
|
|
|
|
- fn find_free_query(&mut self) -> Result<QueryHandle> {
|
|
|
+ fn find_free_query(&mut self) -> Option<QueryHandle> {
|
|
|
for (i, q) in self.queries.iter().enumerate() {
|
|
|
if q.is_none() {
|
|
|
- return Ok(QueryHandle(i));
|
|
|
+ return Some(QueryHandle(i));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
match self.queries {
|
|
|
- ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
|
|
|
+ ManagedSlice::Borrowed(_) => None,
|
|
|
#[cfg(any(feature = "std", feature = "alloc"))]
|
|
|
ManagedSlice::Owned(ref mut queries) => {
|
|
|
queries.push(None);
|
|
|
let index = queries.len() - 1;
|
|
|
- Ok(QueryHandle(index))
|
|
|
+ Some(QueryHandle(index))
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -162,12 +181,16 @@ impl<'a> Socket<'a> {
|
|
|
/// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
|
|
|
/// It accepts names both with and without trailing dot, and they're treated
|
|
|
/// the same (there's no support for DNS search path).
|
|
|
- pub fn start_query(&mut self, cx: &mut Context, name: &str) -> Result<QueryHandle> {
|
|
|
+ pub fn start_query(
|
|
|
+ &mut self,
|
|
|
+ cx: &mut Context,
|
|
|
+ name: &str,
|
|
|
+ ) -> Result<QueryHandle, StartQueryError> {
|
|
|
let mut name = name.as_bytes();
|
|
|
|
|
|
if name.is_empty() {
|
|
|
net_trace!("invalid name: zero length");
|
|
|
- return Err(Error::Illegal);
|
|
|
+ return Err(StartQueryError::InvalidName);
|
|
|
}
|
|
|
|
|
|
// Remove trailing dot, if any
|
|
@@ -180,22 +203,26 @@ impl<'a> Socket<'a> {
|
|
|
for s in name.split(|&c| c == b'.') {
|
|
|
if s.len() > 255 {
|
|
|
net_trace!("invalid name: too long label");
|
|
|
- return Err(Error::Illegal);
|
|
|
+ return Err(StartQueryError::InvalidName);
|
|
|
}
|
|
|
if s.is_empty() {
|
|
|
net_trace!("invalid name: zero length label");
|
|
|
- return Err(Error::Illegal);
|
|
|
+ return Err(StartQueryError::InvalidName);
|
|
|
}
|
|
|
|
|
|
// Push label
|
|
|
- raw_name.push(s.len() as u8).map_err(|_| Error::Exhausted)?;
|
|
|
+ raw_name
|
|
|
+ .push(s.len() as u8)
|
|
|
+ .map_err(|_| StartQueryError::NameTooLong)?;
|
|
|
raw_name
|
|
|
.extend_from_slice(s)
|
|
|
- .map_err(|_| Error::Exhausted)?;
|
|
|
+ .map_err(|_| StartQueryError::NameTooLong)?;
|
|
|
}
|
|
|
|
|
|
// Push terminator.
|
|
|
- raw_name.push(0x00).map_err(|_| Error::Exhausted)?;
|
|
|
+ raw_name
|
|
|
+ .push(0x00)
|
|
|
+ .map_err(|_| StartQueryError::NameTooLong)?;
|
|
|
|
|
|
self.start_query_raw(cx, &raw_name)
|
|
|
}
|
|
@@ -204,12 +231,16 @@ impl<'a> Socket<'a> {
|
|
|
/// `b"\x09rust-lang\x03org\x00"`
|
|
|
///
|
|
|
/// You probably want to use [`start_query`] instead.
|
|
|
- pub fn start_query_raw(&mut self, cx: &mut Context, raw_name: &[u8]) -> Result<QueryHandle> {
|
|
|
- let handle = self.find_free_query()?;
|
|
|
+ pub fn start_query_raw(
|
|
|
+ &mut self,
|
|
|
+ cx: &mut Context,
|
|
|
+ raw_name: &[u8],
|
|
|
+ ) -> Result<QueryHandle, StartQueryError> {
|
|
|
+ let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
|
|
|
|
|
|
self.queries[handle.0] = Some(DnsQuery {
|
|
|
state: State::Pending(PendingQuery {
|
|
|
- name: Vec::from_slice(raw_name).map_err(|_| Error::Exhausted)?,
|
|
|
+ name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
|
|
|
type_: Type::A,
|
|
|
txid: cx.rand().rand_u16(),
|
|
|
port: cx.rand().rand_source_port(),
|
|
@@ -224,15 +255,21 @@ impl<'a> Socket<'a> {
|
|
|
Ok(handle)
|
|
|
}
|
|
|
|
|
|
+ /// Get the result of a query.
|
|
|
+ ///
|
|
|
+ /// If the query is completed, the query slot is automatically freed.
|
|
|
+ ///
|
|
|
+ /// # Panics
|
|
|
+ /// Panics if the QueryHandle corresponds to a free slot.
|
|
|
pub fn get_query_result(
|
|
|
&mut self,
|
|
|
handle: QueryHandle,
|
|
|
- ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
|
|
|
- let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
|
|
|
- let q = slot.as_mut().ok_or(Error::Illegal)?;
|
|
|
+ ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, GetQueryResultError> {
|
|
|
+ let slot = &mut self.queries[handle.0];
|
|
|
+ let q = slot.as_mut().unwrap();
|
|
|
match &mut q.state {
|
|
|
// Query is not done yet.
|
|
|
- State::Pending(_) => Err(Error::Exhausted),
|
|
|
+ State::Pending(_) => Err(GetQueryResultError::Pending),
|
|
|
// Query is done
|
|
|
State::Completed(q) => {
|
|
|
let res = q.addresses.clone();
|
|
@@ -241,25 +278,38 @@ impl<'a> Socket<'a> {
|
|
|
}
|
|
|
State::Failure => {
|
|
|
*slot = None; // Free up the slot for recycling.
|
|
|
- Err(Error::Unaddressable)
|
|
|
+ Err(GetQueryResultError::Failed)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
|
|
|
- let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
|
|
|
+ /// Cancels a query, freeing the slot.
|
|
|
+ ///
|
|
|
+ /// # Panics
|
|
|
+ ///
|
|
|
+ /// Panics if the QueryHandle corresponds to an already free slot.
|
|
|
+ pub fn cancel_query(&mut self, handle: QueryHandle) {
|
|
|
+ let slot = &mut self.queries[handle.0];
|
|
|
if slot.is_none() {
|
|
|
- return Err(Error::Illegal);
|
|
|
+ panic!("Canceling query in a free slot.")
|
|
|
}
|
|
|
*slot = None; // Free up the slot for recycling.
|
|
|
- Ok(())
|
|
|
}
|
|
|
|
|
|
+ /// Assign a waker to a query slot
|
|
|
+ ///
|
|
|
+ /// The waker will be woken when the query completes, either successfully or failed.
|
|
|
+ ///
|
|
|
+ /// # Panics
|
|
|
+ ///
|
|
|
+ /// Panics if the QueryHandle corresponds to an already free slot.
|
|
|
#[cfg(feature = "async")]
|
|
|
- pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) -> Result<()> {
|
|
|
- let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
|
|
|
- slot.as_mut().ok_or(Error::Illegal)?.waker.register(waker);
|
|
|
- Ok(())
|
|
|
+ pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
|
|
|
+ self.queries[handle.0]
|
|
|
+ .as_mut()
|
|
|
+ .unwrap()
|
|
|
+ .waker
|
|
|
+ .register(waker);
|
|
|
}
|
|
|
|
|
|
pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
|
|
@@ -424,9 +474,9 @@ impl<'a> Socket<'a> {
|
|
|
net_trace!("no query matched");
|
|
|
}
|
|
|
|
|
|
- pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
|
|
|
+ pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
|
|
|
where
|
|
|
- F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
|
|
|
+ F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), Error>,
|
|
|
{
|
|
|
let hop_limit = self.hop_limit.unwrap_or(64);
|
|
|
|
|
@@ -566,18 +616,17 @@ fn eq_names<'a>(
|
|
|
fn copy_name<'a, const N: usize>(
|
|
|
dest: &mut Vec<u8, N>,
|
|
|
name: impl Iterator<Item = wire::Result<&'a [u8]>>,
|
|
|
-) -> Result<()> {
|
|
|
+) -> Result<(), wire::Error> {
|
|
|
dest.truncate(0);
|
|
|
|
|
|
for label in name {
|
|
|
let label = label?;
|
|
|
- dest.push(label.len() as u8).map_err(|_| Error::Truncated)?;
|
|
|
- dest.extend_from_slice(label)
|
|
|
- .map_err(|_| Error::Truncated)?;
|
|
|
+ dest.push(label.len() as u8).map_err(|_| wire::Error)?;
|
|
|
+ dest.extend_from_slice(label).map_err(|_| wire::Error)?;
|
|
|
}
|
|
|
|
|
|
// Write terminator 0x00
|
|
|
- dest.push(0).map_err(|_| Error::Truncated)?;
|
|
|
+ dest.push(0).map_err(|_| wire::Error)?;
|
|
|
|
|
|
Ok(())
|
|
|
}
|