Quellcode durchsuchen

Merge pull request #51 from sanxiyn/search-list

Implement search option in resolv.conf
Benjamin Sago vor 4 Jahren
Ursprung
Commit
0d0baa8df2
7 geänderte Dateien mit 118 neuen und 65 gelöschten Zeilen
  1. 1 1
      dns/src/record/opt.rs
  2. 12 0
      dns/src/strings.rs
  3. 1 3
      src/connect.rs
  4. 24 15
      src/main.rs
  5. 9 9
      src/options.rs
  6. 14 13
      src/requests.rs
  7. 57 24
      src/resolve.rs

+ 1 - 1
dns/src/record/opt.rs

@@ -30,7 +30,7 @@ use crate::wire::*;
 ///
 /// - [RFC 6891](https://tools.ietf.org/html/rfc6891) — Extension Mechanisms
 ///   for DNS (April 2013)
-#[derive(PartialEq, Debug)]
+#[derive(PartialEq, Debug, Clone)]
 pub struct OPT {
 
     /// The maximum size of a UDP packet that the client supports.

+ 12 - 0
dns/src/strings.rs

@@ -63,6 +63,18 @@ impl Labels {
 
         Ok(Self { segments })
     }
+
+    /// Returns the number of segments.
+    pub fn len(&self) -> usize {
+        self.segments.len()
+    }
+
+    /// Returns a new set of labels concatenating two names.
+    pub fn extend(&self, other: &Self) -> Self {
+        let mut segments = self.segments.clone();
+        segments.extend_from_slice(&other.segments);
+        Self { segments }
+    }
 }
 
 impl fmt::Display for Labels {

+ 1 - 3
src/connect.rs

@@ -2,8 +2,6 @@
 
 use dns_transport::*;
 
-use crate::resolve::Nameserver;
-
 
 /// A **transport type** creates a `Transport` that determines which protocols
 /// should be used to send and receive DNS wire data over the network.
@@ -34,7 +32,7 @@ pub enum TransportType {
 impl TransportType {
 
     /// Creates a boxed `Transport` depending on the transport type.
-    pub fn make_transport(self, ns: Nameserver) -> Box<dyn Transport> {
+    pub fn make_transport(self, ns: String) -> Box<dyn Transport> {
         match self {
             Self::Automatic  => Box::new(AutoTransport::new(ns)),
             Self::UDP        => Box::new(UdpTransport::new(ns)),

+ 24 - 15
src/main.rs

@@ -107,22 +107,31 @@ fn run(Options { requests, format, measure_time }: Options) -> i32 {
     let timer = if measure_time { Some(Instant::now()) } else { None };
 
     let mut errored = false;
-    for (request, transport) in requests.generate() {
-        let result = transport.send(&request);
-
-        match result {
-            Ok(mut response) => {
-                if ! should_show_opt {
-                    response.answers.retain(dns::Answer::is_standard);
-                    response.authorities.retain(dns::Answer::is_standard);
-                    response.additionals.retain(dns::Answer::is_standard);
+    for (request_list, transport) in requests.generate() {
+        let request_list_len = request_list.len();
+        for (i, request) in request_list.into_iter().enumerate() {
+            let result = transport.send(&request);
+
+            match result {
+                Ok(mut response) => {
+                    if response.flags.error_code.is_some() && i != request_list_len - 1 {
+                        continue;
+                    }
+
+                    if ! should_show_opt {
+                        response.answers.retain(dns::Answer::is_standard);
+                        response.authorities.retain(dns::Answer::is_standard);
+                        response.additionals.retain(dns::Answer::is_standard);
+                    }
+
+                    responses.push(response);
+                    break;
+                }
+                Err(e) => {
+                    format.print_error(e);
+                    errored = true;
+                    break;
                 }
-
-                responses.push(response);
-            }
-            Err(e) => {
-                format.print_error(e);
-                errored = true;
             }
         }
     }

+ 9 - 9
src/options.rs

@@ -194,7 +194,7 @@ impl Inputs {
     }
 
     fn add_nameserver(&mut self, input: &str) -> Result<(), OptionsError> {
-        self.resolvers.push(Resolver::Specified(input.into()));
+        self.resolvers.push(Resolver::specified(input.into()));
         Ok(())
     }
 
@@ -259,7 +259,7 @@ impl Inputs {
         }
 
         if self.resolvers.is_empty() {
-            self.resolvers.push(Resolver::SystemDefault);
+            self.resolvers.push(Resolver::system_default());
         }
 
         if self.transport_types.is_empty() {
@@ -492,7 +492,7 @@ mod test {
                 domains:         vec![ /* No domains by default */ ],
                 types:           vec![ qtype!(A) ],
                 classes:         vec![ QClass::IN ],
-                resolvers:       vec![ Resolver::SystemDefault ],
+                resolvers:       vec![ Resolver::system_default() ],
                 transport_types: vec![ TransportType::Automatic ],
             }
         }
@@ -587,7 +587,7 @@ mod test {
         let options = Options::getopts(&[ "lookup.dog", "@1.1.1.1" ]).unwrap();
         assert_eq!(options.requests.inputs, Inputs {
             domains:    vec![ Labels::encode("lookup.dog").unwrap() ],
-            resolvers:  vec![ Resolver::Specified("1.1.1.1".into()) ],
+            resolvers:  vec![ Resolver::specified("1.1.1.1".into()) ],
             .. Inputs::fallbacks()
         });
     }
@@ -609,7 +609,7 @@ mod test {
             domains:    vec![ Labels::encode("lookup.dog").unwrap() ],
             classes:    vec![ QClass::CH ],
             types:      vec![ qtype!(NS) ],
-            resolvers:  vec![ Resolver::Specified("1.1.1.1".into()) ],
+            resolvers:  vec![ Resolver::specified("1.1.1.1".into()) ],
             .. Inputs::fallbacks()
         });
     }
@@ -621,7 +621,7 @@ mod test {
             domains:    vec![ Labels::encode("lookup.dog").unwrap() ],
             classes:    vec![ QClass::CH ],
             types:      vec![ qtype!(SOA) ],
-            resolvers:  vec![ Resolver::Specified("1.1.1.1".into()) ],
+            resolvers:  vec![ Resolver::specified("1.1.1.1".into()) ],
             .. Inputs::fallbacks()
         });
     }
@@ -653,7 +653,7 @@ mod test {
             domains:    vec![ Labels::encode("lookup.dog").unwrap() ],
             classes:    vec![ QClass::CH ],
             types:      vec![ qtype!(SOA) ],
-            resolvers:  vec![ Resolver::Specified("1.1.1.1".into()) ],
+            resolvers:  vec![ Resolver::specified("1.1.1.1".into()) ],
             .. Inputs::fallbacks()
         });
     }
@@ -674,8 +674,8 @@ mod test {
         let options = Options::getopts(&[ "lookup.dog", "--nameserver", "1.1.1.1", "--nameserver", "1.0.0.1" ]).unwrap();
         assert_eq!(options.requests.inputs, Inputs {
             domains:    vec![ Labels::encode("lookup.dog").unwrap() ],
-            resolvers:  vec![ Resolver::Specified("1.1.1.1".into()),
-                              Resolver::Specified("1.0.0.1".into()), ],
+            resolvers:  vec![ Resolver::specified("1.1.1.1".into()),
+                              Resolver::specified("1.0.0.1".into()), ],
             .. Inputs::fallbacks()
         });
     }

+ 14 - 13
src/requests.rs

@@ -81,21 +81,16 @@ pub enum UseEDNS {
 
 impl RequestGenerator {
 
-    /// Iterate through the inputs matrix, returning pairs of DNS requests and
-    /// the details of the transport to send them down.
-    pub fn generate(self) -> Vec<(dns::Request, Box<dyn dns_transport::Transport>)> {
-        let nameservers = self.inputs.resolvers.into_iter()
-                              .map(|e| e.lookup().expect("Failed to get nameserver").expect("No nameserver found"))
-                              .collect::<Vec<_>>();
-
+    /// Iterate through the inputs matrix, returning pairs of DNS request list
+    /// and the details of the transport to send them down.
+    pub fn generate(self) -> Vec<(Vec<dns::Request>, Box<dyn dns_transport::Transport>)> {
         let mut requests = Vec::new();
         for domain in &self.inputs.domains {
             for qtype in self.inputs.types.iter().copied() {
                 for qclass in self.inputs.classes.iter().copied() {
-                    for nameserver in &nameservers {
+                    for resolver in &self.inputs.resolvers {
                         for transport_type in &self.inputs.transport_types {
 
-                            let transaction_id = self.txid_generator.generate();
                             let mut flags = dns::Flags::query();
                             self.protocol_tweaks.set_request_flags(&mut flags);
 
@@ -106,11 +101,17 @@ impl RequestGenerator {
                                 additional = Some(opt);
                             }
 
-                            let query = dns::Query { qname: domain.clone(), qtype, qclass };
-                            let request = dns::Request { transaction_id, flags, query, additional };
+                            let nameserver = resolver.nameserver();
+                            let transport = transport_type.make_transport(nameserver);
 
-                            let transport = transport_type.make_transport(nameserver.clone());
-                            requests.push((request, transport));
+                            let mut request_list = Vec::new();
+                            for qname in resolver.name_list(domain) {
+                                let transaction_id = self.txid_generator.generate();
+                                let query = dns::Query { qname, qtype, qclass };
+                                let request = dns::Request { transaction_id, flags, query, additional: additional.clone() };
+                                request_list.push(request);
+                            }
+                            requests.push((request_list, transport));
                         }
                     }
                 }

+ 57 - 24
src/resolve.rs

@@ -4,31 +4,58 @@ use std::io;
 
 use log::*;
 
+use dns::Labels;
 
-/// A **resolver** is used to obtain the IP address of the server we should
-/// send DNS requests to.
+
+/// A **resolver** knows the address of the server we should
+/// send DNS requests to, and the search list for name lookup.
 #[derive(PartialEq, Debug)]
-pub enum Resolver {
+pub struct Resolver {
 
-    /// Read the list of nameservers from the system, and use that.
-    SystemDefault,
+    /// The address of the name server.
+    pub nameserver: String,
 
-    // Use a specific nameserver specified by the user.
-    Specified(Nameserver),
+    /// The search list for name lookup.
+    pub search_list: Vec<String>,
 }
 
-pub type Nameserver = String;
-
 impl Resolver {
 
-    /// Returns a nameserver that queries should be sent to, possibly by
-    /// obtaining one based on the system, returning an error if there was a
-    /// problem looking one up.
-    pub fn lookup(self) -> io::Result<Option<Nameserver>> {
-        match self {
-            Self::Specified(ns)  => Ok(Some(ns)),
-            Self::SystemDefault  => system_nameservers(),
+    /// Returns a resolver with the specified nameserver and an empty
+    /// search list.
+    pub fn specified(nameserver: String) -> Self {
+        let search_list = Vec::new();
+        Self { nameserver, search_list }
+    }
+
+    /// Returns a resolver that is default for the system.
+    pub fn system_default() -> Self {
+        let (nameserver_opt, search_list) = system_nameservers().expect("Failed to get nameserver");
+        let nameserver = nameserver_opt.expect("No nameserver found");
+        Self { nameserver, search_list }
+    }
+
+    /// Returns a nameserver that queries should be sent to.
+    pub fn nameserver(&self) -> String {
+        self.nameserver.clone()
+    }
+
+    /// Returns a sequence of names to be queried, taking into account
+    /// of the search list.
+    pub fn name_list(&self, name: &Labels) -> Vec<Labels> {
+        let mut list = Vec::new();
+        if name.len() > 1 {
+            list.push(name.clone());
+            return list;
+        }
+        for search in &self.search_list {
+            match Labels::encode(search) {
+                Ok(suffix) => list.push(name.extend(&suffix)),
+                Err(_) => panic!("Invalid search list {}", search),
+            }
         }
+        list.push(name.clone());
+        list
     }
 }
 
@@ -38,7 +65,7 @@ impl Resolver {
 /// Returns an error if there’s a problem reading the file, or `None` if no
 /// nameserver is specified in the file.
 #[cfg(unix)]
-fn system_nameservers() -> io::Result<Option<Nameserver>> {
+fn system_nameservers() -> io::Result<(Option<String>, Vec<String>)> {
     use std::io::{BufRead, BufReader};
     use std::fs::File;
 
@@ -46,6 +73,7 @@ fn system_nameservers() -> io::Result<Option<Nameserver>> {
     let reader = BufReader::new(f);
 
     let mut nameservers = Vec::new();
+    let mut search_list = Vec::new();
     for line in reader.lines() {
         let line = line?;
 
@@ -58,21 +86,26 @@ fn system_nameservers() -> io::Result<Option<Nameserver>> {
                 Err(e)  => warn!("Failed to parse nameserver line {:?}: {}", line, e),
             }
         }
+
+        if let Some(search_str) = line.strip_prefix("search ") {
+            search_list.clear();
+            search_list.extend(search_str.split_ascii_whitespace().map(|s| s.into()));
+        }
     }
 
-    Ok(nameservers.first().cloned())
+    Ok((nameservers.first().cloned(), search_list))
 }
 
 
 /// Looks up the system default nameserver on Windows, by iterating through
 /// the list of network adapters and returning the first nameserver it finds.
 #[cfg(windows)]
-fn system_nameservers() -> io::Result<Option<Nameserver>> {
+fn system_nameservers() -> io::Result<(Option<String>, Vec<String>)> {
     let adapters = match ipconfig::get_adapters() {
         Ok(a) => a,
         Err(e) => {
             warn!("Error getting network adapters: {}", e);
-            return Ok(None);
+            return Ok((None, Vec::new()));
         }
     };
 
@@ -83,20 +116,20 @@ fn system_nameservers() -> io::Result<Option<Nameserver>> {
             // TODO: This will need to be changed for IPv6 support.
             if dns_server.is_ipv4() {
                 debug!("Found first nameserver {:?}", dns_server);
-                return Ok(Some(dns_server.to_string()));
+                return Ok((Some(dns_server.to_string()), Vec::new()));
             }
         }
     }
 
     warn!("No nameservers available");
-    return Ok(None)
+    return Ok((None, Vec::new()))
 }
 
 
 /// The fall-back system default nameserver determinator that is not very
 /// determined as it returns nothing without actually checking anything.
 #[cfg(all(not(unix), not(windows)))]
-fn system_nameservers() -> io::Result<Option<Nameserver>> {
+fn system_nameservers() -> io::Result<(Option<String>, Vec<String>)> {
     warn!("Unable to fetch default nameservers on this platform.");
-    Ok(None)
+    Ok((None, Vec::new()))
 }