Browse Source

socket/dns: add own error enums for public API.

Dario Nieuwenhuis 2 years ago
parent
commit
43329e696e
2 changed files with 86 additions and 38 deletions
  1. 2 3
      examples/dns.rs
  2. 84 35
      src/socket/dns.rs

+ 2 - 3
examples/dns.rs

@@ -10,12 +10,11 @@ mod utils;
 use smoltcp::iface::{InterfaceBuilder, NeighborCache, Routes};
 use smoltcp::phy::Device;
 use smoltcp::phy::{wait as phy_wait, Medium};
-use smoltcp::socket::dns;
+use smoltcp::socket::dns::{self, GetQueryResultError};
 use smoltcp::time::Instant;
 use smoltcp::wire::{
     EthernetAddress, HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address,
 };
-use smoltcp::Error;
 use std::collections::BTreeMap;
 use std::os::unix::io::AsRawFd;
 
@@ -90,7 +89,7 @@ fn main() {
                 println!("Query done: {:?}", addrs);
                 break;
             }
-            Err(Error::Exhausted) => {} // not done yet
+            Err(GetQueryResultError::Pending) => {} // not done yet
             Err(e) => panic!("query failed: {:?}", e),
         }
 

+ 84 - 35
src/socket/dns.rs

@@ -8,7 +8,7 @@ use crate::socket::{Context, PollAt};
 use crate::time::{Duration, Instant};
 use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
 use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
-use crate::{Error, Result};
+use crate::Error;
 
 #[cfg(feature = "async")]
 use super::WakerRegistration;
@@ -21,6 +21,25 @@ const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
 const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
 const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
 
+/// Error returned by [`Socket::start_query`]
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum StartQueryError {
+    NoFreeSlot,
+    InvalidName,
+    NameTooLong,
+}
+
+/// Error returned by [`Socket::get_query_result`]
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum GetQueryResultError {
+    /// Query is not done yet.
+    Pending,
+    /// Query failed.
+    Failed,
+}
+
 /// State for an in-progress DNS query.
 ///
 /// The only reason this struct is public is to allow the socket state
@@ -139,20 +158,20 @@ impl<'a> Socket<'a> {
         self.hop_limit = hop_limit
     }
 
-    fn find_free_query(&mut self) -> Result<QueryHandle> {
+    fn find_free_query(&mut self) -> Option<QueryHandle> {
         for (i, q) in self.queries.iter().enumerate() {
             if q.is_none() {
-                return Ok(QueryHandle(i));
+                return Some(QueryHandle(i));
             }
         }
 
         match self.queries {
-            ManagedSlice::Borrowed(_) => Err(Error::Exhausted),
+            ManagedSlice::Borrowed(_) => None,
             #[cfg(any(feature = "std", feature = "alloc"))]
             ManagedSlice::Owned(ref mut queries) => {
                 queries.push(None);
                 let index = queries.len() - 1;
-                Ok(QueryHandle(index))
+                Some(QueryHandle(index))
             }
         }
     }
@@ -162,12 +181,16 @@ impl<'a> Socket<'a> {
     /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
     /// It accepts names both with and without trailing dot, and they're treated
     /// the same (there's no support for DNS search path).
-    pub fn start_query(&mut self, cx: &mut Context, name: &str) -> Result<QueryHandle> {
+    pub fn start_query(
+        &mut self,
+        cx: &mut Context,
+        name: &str,
+    ) -> Result<QueryHandle, StartQueryError> {
         let mut name = name.as_bytes();
 
         if name.is_empty() {
             net_trace!("invalid name: zero length");
-            return Err(Error::Illegal);
+            return Err(StartQueryError::InvalidName);
         }
 
         // Remove trailing dot, if any
@@ -180,22 +203,26 @@ impl<'a> Socket<'a> {
         for s in name.split(|&c| c == b'.') {
             if s.len() > 255 {
                 net_trace!("invalid name: too long label");
-                return Err(Error::Illegal);
+                return Err(StartQueryError::InvalidName);
             }
             if s.is_empty() {
                 net_trace!("invalid name: zero length label");
-                return Err(Error::Illegal);
+                return Err(StartQueryError::InvalidName);
             }
 
             // Push label
-            raw_name.push(s.len() as u8).map_err(|_| Error::Exhausted)?;
+            raw_name
+                .push(s.len() as u8)
+                .map_err(|_| StartQueryError::NameTooLong)?;
             raw_name
                 .extend_from_slice(s)
-                .map_err(|_| Error::Exhausted)?;
+                .map_err(|_| StartQueryError::NameTooLong)?;
         }
 
         // Push terminator.
-        raw_name.push(0x00).map_err(|_| Error::Exhausted)?;
+        raw_name
+            .push(0x00)
+            .map_err(|_| StartQueryError::NameTooLong)?;
 
         self.start_query_raw(cx, &raw_name)
     }
@@ -204,12 +231,16 @@ impl<'a> Socket<'a> {
     /// `b"\x09rust-lang\x03org\x00"`
     ///
     /// You probably want to use [`start_query`] instead.
-    pub fn start_query_raw(&mut self, cx: &mut Context, raw_name: &[u8]) -> Result<QueryHandle> {
-        let handle = self.find_free_query()?;
+    pub fn start_query_raw(
+        &mut self,
+        cx: &mut Context,
+        raw_name: &[u8],
+    ) -> Result<QueryHandle, StartQueryError> {
+        let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
 
         self.queries[handle.0] = Some(DnsQuery {
             state: State::Pending(PendingQuery {
-                name: Vec::from_slice(raw_name).map_err(|_| Error::Exhausted)?,
+                name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
                 type_: Type::A,
                 txid: cx.rand().rand_u16(),
                 port: cx.rand().rand_source_port(),
@@ -224,15 +255,21 @@ impl<'a> Socket<'a> {
         Ok(handle)
     }
 
+    /// Get the result of a query.
+    ///
+    /// If the query is completed, the query slot is automatically freed.
+    ///
+    /// # Panics
+    /// Panics if the QueryHandle corresponds to a free slot.
     pub fn get_query_result(
         &mut self,
         handle: QueryHandle,
-    ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>> {
-        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
-        let q = slot.as_mut().ok_or(Error::Illegal)?;
+    ) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, GetQueryResultError> {
+        let slot = &mut self.queries[handle.0];
+        let q = slot.as_mut().unwrap();
         match &mut q.state {
             // Query is not done yet.
-            State::Pending(_) => Err(Error::Exhausted),
+            State::Pending(_) => Err(GetQueryResultError::Pending),
             // Query is done
             State::Completed(q) => {
                 let res = q.addresses.clone();
@@ -241,25 +278,38 @@ impl<'a> Socket<'a> {
             }
             State::Failure => {
                 *slot = None; // Free up the slot for recycling.
-                Err(Error::Unaddressable)
+                Err(GetQueryResultError::Failed)
             }
         }
     }
 
-    pub fn cancel_query(&mut self, handle: QueryHandle) -> Result<()> {
-        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
+    /// Cancels a query, freeing the slot.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the QueryHandle corresponds to an already free slot.
+    pub fn cancel_query(&mut self, handle: QueryHandle) {
+        let slot = &mut self.queries[handle.0];
         if slot.is_none() {
-            return Err(Error::Illegal);
+            panic!("Canceling query in a free slot.")
         }
         *slot = None; // Free up the slot for recycling.
-        Ok(())
     }
 
+    /// Assign a waker to a query slot
+    ///
+    /// The waker will be woken when the query completes, either successfully or failed.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the QueryHandle corresponds to an already free slot.
     #[cfg(feature = "async")]
-    pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) -> Result<()> {
-        let slot = self.queries.get_mut(handle.0).ok_or(Error::Illegal)?;
-        slot.as_mut().ok_or(Error::Illegal)?.waker.register(waker);
-        Ok(())
+    pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
+        self.queries[handle.0]
+            .as_mut()
+            .unwrap()
+            .waker
+            .register(waker);
     }
 
     pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
@@ -424,9 +474,9 @@ impl<'a> Socket<'a> {
         net_trace!("no query matched");
     }
 
-    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<()>
+    pub(crate) fn dispatch<F>(&mut self, cx: &mut Context, emit: F) -> Result<(), Error>
     where
-        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<()>,
+        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), Error>,
     {
         let hop_limit = self.hop_limit.unwrap_or(64);
 
@@ -566,18 +616,17 @@ fn eq_names<'a>(
 fn copy_name<'a, const N: usize>(
     dest: &mut Vec<u8, N>,
     name: impl Iterator<Item = wire::Result<&'a [u8]>>,
-) -> Result<()> {
+) -> Result<(), wire::Error> {
     dest.truncate(0);
 
     for label in name {
         let label = label?;
-        dest.push(label.len() as u8).map_err(|_| Error::Truncated)?;
-        dest.extend_from_slice(label)
-            .map_err(|_| Error::Truncated)?;
+        dest.push(label.len() as u8).map_err(|_| wire::Error)?;
+        dest.extend_from_slice(label).map_err(|_| wire::Error)?;
     }
 
     // Write terminator 0x00
-    dest.push(0).map_err(|_| Error::Truncated)?;
+    dest.push(0).map_err(|_| wire::Error)?;
 
     Ok(())
 }