瀏覽代碼

Fix: Prevent panic in DNS Socket when server list exceeds max count

- Truncate the servers list to DNS_MAX_SERVER_COUNT to prevent panics.
- Ensure only the first `DNS_MAX_SERVER_COUNT` servers are used when constructing the `Socket`.
- This prevents overflow issues when the provided server list is larger than the allowed maximum.
Jamie Bird 6 月之前
父節點
當前提交
4739cc7fc6
共有 1 個文件被更改,包括 12 次插入8 次删除
  1. 12 8
      src/socket/dns.rs

+ 12 - 8
src/socket/dns.rs

@@ -1,3 +1,4 @@
+use core::cmp::min;
 #[cfg(feature = "async")]
 use core::task::Waker;
 
@@ -149,15 +150,15 @@ pub struct Socket<'a> {
 impl<'a> Socket<'a> {
     /// Create a DNS socket.
     ///
-    /// # Panics
-    ///
-    /// Panics if `servers.len() > MAX_SERVER_COUNT`
+    /// Truncates the server list if `servers.len() > MAX_SERVER_COUNT`
     pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
     where
         Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
     {
+        let truncated_servers = &servers[..min(servers.len(), DNS_MAX_SERVER_COUNT)];
+
         Socket {
-            servers: Vec::from_slice(servers).unwrap(),
+            servers: Vec::from_slice(truncated_servers).unwrap(),
             queries: queries.into(),
             hop_limit: None,
         }
@@ -165,11 +166,14 @@ impl<'a> Socket<'a> {
 
     /// Update the list of DNS servers, will replace all existing servers
     ///
-    /// # Panics
-    ///
-    /// Panics if `servers.len() > MAX_SERVER_COUNT`
+    /// Truncates the server list if `servers.len() > MAX_SERVER_COUNT`
     pub fn update_servers(&mut self, servers: &[IpAddress]) {
-        self.servers = Vec::from_slice(servers).unwrap();
+        if servers.len() > DNS_MAX_SERVER_COUNT {
+            net_trace!("Max DNS Servers exceeded. Increase MAX_SERVER_COUNT");
+            self.servers = Vec::from_slice(&servers[..DNS_MAX_SERVER_COUNT]).unwrap();
+        } else {
+            self.servers = Vec::from_slice(servers).unwrap();
+        }
     }
 
     /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.