Kaynağa Gözat

Don't overwrite the DNS servers list permanently

Benjamin Brittain 2 yıl önce
ebeveyn
işleme
c6e2d8af01
1 değiştirilmiş dosya ile 16 ekleme ve 14 silme
  1. 16 14
      src/socket/dns.rs

+ 16 - 14
src/socket/dns.rs

@@ -257,17 +257,6 @@ impl<'a> Socket<'a> {
     ) -> Result<QueryHandle, StartQueryError> {
         let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
 
-        if mdns {
-            // as per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
-            // so we internally overwrite the servers for any of those queries
-            self.update_servers(&[
-                #[cfg(feature = "proto-ipv6")]
-                MDNS_IPV6_ADDR,
-                #[cfg(feature = "proto-ipv4")]
-                MDNS_IPV4_ADDR,
-            ]);
-        }
-
         self.queries[handle.0] = Some(DnsQuery {
             state: State::Pending(PendingQuery {
                 name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
@@ -514,6 +503,19 @@ impl<'a> Socket<'a> {
 
         for q in self.queries.iter_mut().flatten() {
             if let State::Pending(pq) = &mut q.state {
+                // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
+                // so we internally overwrite the servers for any of those queries
+                // in this function.
+                let servers = if pq.mdns {
+                    &[
+                        #[cfg(feature = "proto-ipv6")]
+                        MDNS_IPV6_ADDR,
+                        #[cfg(feature = "proto-ipv4")]
+                        MDNS_IPV4_ADDR,
+                    ]
+                } else {
+                    self.servers.as_slice()
+                };
                 let timeout = if let Some(timeout) = pq.timeout_at {
                     timeout
                 } else {
@@ -533,14 +535,14 @@ impl<'a> Socket<'a> {
                     pq.server_idx += 1;
                 }
                 // Check if we've run out of servers to try.
-                if pq.server_idx >= self.servers.len() {
+                if pq.server_idx >= servers.len() {
                     net_trace!("already tried all servers.");
                     q.set_state(State::Failure);
                     continue;
                 }
 
                 // Check so the IP address is valid
-                if self.servers[pq.server_idx].is_unspecified() {
+                if servers[pq.server_idx].is_unspecified() {
                     net_trace!("invalid unspecified DNS server addr.");
                     q.set_state(State::Failure);
                     continue;
@@ -572,7 +574,7 @@ impl<'a> Socket<'a> {
                     dst_port,
                 };
 
-                let dst_addr = self.servers[pq.server_idx];
+                let dst_addr = servers[pq.server_idx];
                 let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
                 let ip_repr = IpRepr::new(
                     src_addr,