浏览代码

Make incorrect label lengths a hard protocol error

This commit makes all record types properly check the length of the data after they have read it, returning an error if there is a mismatch. Now, no record types ignore their length argument. Some of them used to log a warning, but now they all do, as well as returning the error.

Doing this required changing the read_labels function so that it returns the number of bytes read without recursing, as this number is not necessarily the same as the number of bytes in the string plus one (for the null terminator).

For clarity, the WrongLength error case has been split in two, to tell the difference between where the data disagrees with the DNS protocol, and where the data disagrees with the length field in the packet.
Benjamin Sago 4 年之前
父节点
当前提交
6bd4bd1de4

+ 5 - 5
dns/src/record/a.rs

@@ -34,8 +34,8 @@ impl Wire for A {
             Ok(A { address })
         }
         else {
-            warn!("Received non-four length -> {:?}", buf.len());
-            Err(WireError::WrongLength { expected: 4, got: buf.len() as u16 })
+            warn!("Length is incorrect (record length {:?}, but should be four)", len);
+            Err(WireError::WrongRecordLength { expected: 4, got: buf.len() as u16 })
         }
     }
 }
@@ -62,7 +62,7 @@ mod test {
         ];
 
         assert_eq!(A::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongLength { expected: 4, got: 3 }));
+                   Err(WireError::WrongRecordLength { expected: 4, got: 3 }));
     }
 
     #[test]
@@ -73,12 +73,12 @@ mod test {
         ];
 
         assert_eq!(A::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongLength { expected: 4, got: 5 }));
+                   Err(WireError::WrongRecordLength { expected: 4, got: 5 }));
     }
 
     #[test]
     fn empty() {
         assert_eq!(A::read(0, &mut Cursor::new(&[])),
-                   Err(WireError::WrongLength { expected: 4, got: 0 }));
+                   Err(WireError::WrongRecordLength { expected: 4, got: 0 }));
     }
 }

+ 5 - 5
dns/src/record/aaaa.rs

@@ -36,8 +36,8 @@ impl Wire for AAAA {
             Ok(AAAA { address })
         }
         else {
-            warn!("Received non-sixteen length -> {:?}", buf.len());
-            Err(WireError::WrongLength { expected: 16, got: buf.len() as u16 })
+            warn!("Length is incorrect (record length {:?}, but should be sixteen)", len);
+            Err(WireError::WrongRecordLength { expected: 16, got: buf.len() as u16 })
         }
     }
 }
@@ -67,7 +67,7 @@ mod test {
         ];
 
         assert_eq!(AAAA::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongLength { expected: 16, got: 17 }));
+                   Err(WireError::WrongRecordLength { expected: 16, got: 17 }));
     }
 
     #[test]
@@ -77,12 +77,12 @@ mod test {
         ];
 
         assert_eq!(AAAA::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongLength { expected: 16, got: 5 }));
+                   Err(WireError::WrongRecordLength { expected: 16, got: 5 }));
     }
 
     #[test]
     fn empty() {
         assert_eq!(AAAA::read(0, &mut Cursor::new(&[])),
-                   Err(WireError::WrongLength { expected: 16, got: 0 }));
+                   Err(WireError::WrongRecordLength { expected: 16, got: 0 }));
     }
 }

+ 10 - 1
dns/src/record/caa.rs

@@ -56,7 +56,16 @@ impl Wire for CAA {
         let value = String::from_utf8_lossy(&value_buf).to_string();
         trace!("Parsed value -> {:?}", value);
 
-        Ok(CAA { critical, tag, value })
+        let got_len = 1 + 1 + u16::from(tag_length) + remaining_length;
+        if len == got_len {
+            // This one’s a little weird, because remaining_len is based on len
+            trace!("Length is correct");
+            Ok(CAA { critical, tag, value })
+        }
+        else {
+            warn!("Length is incorrect (record length {:?}, flags plus tag plus data length {:?}", len, got_len);
+            Err(WireError::WrongLabelLength { expected: len, got: got_len })
+        }
     }
 }
 

+ 11 - 3
dns/src/record/cname.rs

@@ -21,10 +21,18 @@ impl Wire for CNAME {
     const RR_TYPE: u16 = 5;
 
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
-    fn read(_len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
-        let domain = c.read_labels()?;
+    fn read(len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
+        let (domain, domain_len) = c.read_labels()?;
         trace!("Parsed domain -> {:?}", domain);
-        Ok(CNAME { domain })
+
+        if len == domain_len {
+            trace!("Length is correct");
+            Ok(CNAME { domain })
+        }
+        else {
+            warn!("Length is incorrect (record length {:?}, domain length {:?})", len, domain_len);
+            Err(WireError::WrongLabelLength { expected: len, got: domain_len })
+        }
     }
 }
 

+ 7 - 6
dns/src/record/mx.rs

@@ -30,17 +30,18 @@ impl Wire for MX {
         let preference = c.read_u16::<BigEndian>()?;
         trace!("Parsed preference -> {:?}", preference);
 
-        let exchange = c.read_labels()?;
+        let (exchange, exchange_len) = c.read_labels()?;
         trace!("Parsed exchange -> {:?}", exchange);
 
-        if 2 + exchange.len() + 1 == len as usize {
-            debug!("Length {} is correct", len);
+        let got_len = 2 + exchange_len;
+        if len == got_len {
+            trace!("Length is correct");
+            Ok(MX { preference, exchange })
         }
         else {
-            warn!("Expected length {} but read {} bytes", len, 2 + exchange.len() + 1);
+            warn!("Length is incorrect (record length {:?}, preference plus exchange length {:?}", len, got_len);
+            Err(WireError::WrongLabelLength { expected: len, got: got_len })
         }
-
-        Ok(MX { preference, exchange })
     }
 }
 

+ 6 - 6
dns/src/record/ns.rs

@@ -23,17 +23,17 @@ impl Wire for NS {
 
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
     fn read(len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
-        let nameserver = c.read_labels()?;
+        let (nameserver, nameserver_len) = c.read_labels()?;
         trace!("Parsed nameserver -> {:?}", nameserver);
 
-        if nameserver.len() + 1 == len as usize {
-            debug!("Length {} is correct", nameserver.len() + 1);
+        if len == nameserver_len {
+            trace!("Length is correct");
+            Ok(NS { nameserver })
         }
         else {
-            warn!("Expected length {} but read {} bytes", len, nameserver.len() + 1);
+            warn!("Length is incorrect (record length {:?}, nameserver length {:?}", len, nameserver_len);
+            Err(WireError::WrongLabelLength { expected: len, got: nameserver_len })
         }
-
-        Ok(NS { nameserver })
     }
 }
 

+ 11 - 3
dns/src/record/ptr.rs

@@ -27,10 +27,18 @@ impl Wire for PTR {
     const RR_TYPE: u16 = 12;
 
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
-    fn read(_len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
-        let cname = c.read_labels()?;
+    fn read(len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
+        let (cname, cname_len) = c.read_labels()?;
         trace!("Parsed cname -> {:?}", cname);
-        Ok(PTR { cname })
+
+        if len == cname_len {
+            trace!("Length is correct");
+            Ok(PTR { cname })
+        }
+        else {
+            warn!("Length is incorrect (record length {:?}, cname length {:?}", len, cname_len);
+            Err(WireError::WrongLabelLength { expected: len, got: cname_len })
+        }
     }
 }
 

+ 11 - 11
dns/src/record/soa.rs

@@ -47,10 +47,10 @@ impl Wire for SOA {
 
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
     fn read(len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
-        let mname = c.read_labels()?;
+        let (mname, mname_len) = c.read_labels()?;
         trace!("Parsed mname -> {:?}", mname);
 
-        let rname = c.read_labels()?;
+        let (rname, rname_len) = c.read_labels()?;
         trace!("Parsed rname -> {:?}", rname);
 
         let serial = c.read_u32::<BigEndian>()?;
@@ -68,18 +68,18 @@ impl Wire for SOA {
         let minimum_ttl = c.read_u32::<BigEndian>()?;
         trace!("Parsed minimum TTL -> {:?}", minimum_ttl);
 
-        let got_length = mname.len() + rname.len() + 4 * 5 + 2;
-        if got_length == len as usize {
-            debug!("Length {} is correct", len);
+        let got_len = 4 * 5 + mname_len + rname_len;
+        if len == got_len {
+            trace!("Length is correct");
+            Ok(SOA {
+                mname, rname, serial, refresh_interval,
+                retry_interval, expire_limit, minimum_ttl,
+            })
         }
         else {
-            warn!("Expected length {} but got {}", len, got_length);
+            warn!("Length is incorrect (record length {:?}, mname plus rname plus fields length {:?})", len, got_len);
+            Err(WireError::WrongLabelLength { expected: len, got: got_len })
         }
-
-        Ok(SOA {
-            mname, rname, serial, refresh_interval,
-            retry_interval, expire_limit, minimum_ttl,
-        })
     }
 }
 

+ 7 - 7
dns/src/record/srv.rs

@@ -43,18 +43,18 @@ impl Wire for SRV {
         let port = c.read_u16::<BigEndian>()?;
         trace!("Parsed port -> {:?}", port);
 
-        let target = c.read_labels()?;
+        let (target, target_len) = c.read_labels()?;
         trace!("Parsed target -> {:?}", target);
 
-        let got_length = 3 * 2 + target.len() + 1;
-        if got_length == len as usize {
-            debug!("Length {} is correct", len);
+        let got_len = 3 * 2 + target_len;
+        if len == got_len {
+            trace!("Length is correct");
+            Ok(SRV { priority, weight, port, target })
         }
         else {
-            warn!("Expected length {} but got {}", len, got_length);
+            warn!("Length is incorrect (record length {:?}, fields plus target length {:?})", len, got_len);
+            Err(WireError::WrongLabelLength { expected: len, got: got_len })
         }
-
-        Ok(SRV { priority, weight, port, target })
     }
 }
 

+ 11 - 4
dns/src/record/txt.rs

@@ -27,11 +27,11 @@ impl Wire for TXT {
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
     fn read(len: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
         let mut buf = Vec::new();
-        let mut total_len = 0_usize;
+        let mut total_len = 0_u16;
 
         loop {
             let next_len = c.read_u8()?;
-            total_len += next_len as usize + 1;
+            total_len += u16::from(next_len) + 1;
             trace!("Parsed slice length -> {:?} (total so far {:?})", next_len, total_len);
 
             for _ in 0 .. next_len {
@@ -46,7 +46,7 @@ impl Wire for TXT {
             }
         }
 
-        if total_len == len as usize {
+        if len == total_len {
             debug!("Length matches expected");
         }
         else {
@@ -56,7 +56,14 @@ impl Wire for TXT {
         let message = String::from_utf8_lossy(&buf).to_string();
         trace!("Parsed message -> {:?}", message);
 
-        Ok(TXT { message })
+        if len == total_len {
+            trace!("Length is correct");
+            Ok(TXT { message })
+        }
+        else {
+            warn!("Length is incorrect (record length {:?}, message length {:?})", len, total_len);
+            Err(WireError::WrongLabelLength { expected: len, got: total_len })
+        }
     }
 }
 

+ 16 - 6
dns/src/strings.rs

@@ -13,14 +13,15 @@ use crate::wire::*;
 pub(crate) trait ReadLabels {
 
     /// Read and expand a compressed domain name.
-    fn read_labels(&mut self) -> Result<String, WireError>;
+    fn read_labels(&mut self) -> Result<(String, u16), WireError>;
 }
 
 impl ReadLabels for Cursor<&[u8]> {
-    fn read_labels(&mut self) -> Result<String, WireError> {
+    fn read_labels(&mut self) -> Result<(String, u16), WireError> {
         let mut name_buf = Vec::new();
-        read_string_recursive(&mut name_buf, self, &mut Vec::new())?;
-        Ok(String::from_utf8_lossy(&*name_buf).to_string())
+        let bytes_read = read_string_recursive(&mut name_buf, self, &mut Vec::new())?;
+        let string = String::from_utf8_lossy(&*name_buf).to_string();
+        Ok((string, bytes_read))
     }
 }
 
@@ -57,10 +58,17 @@ impl<W: Write> WriteLabels for W {
 
 const RECURSION_LIMIT: usize = 8;
 
+/// Reads bytes from the given cursor into the given buffer, using the list of
+/// recursions to track backtracking positions. Returns the count of bytes
+/// that had to be read to produce the string, including the bytes to signify
+/// backtracking, but not including the bytes read _during_ backtracking.
 #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
-fn read_string_recursive(name_buf: &mut Vec<u8>, c: &mut Cursor<&[u8]>, recursions: &mut Vec<u16>) -> Result<(), WireError> {
+fn read_string_recursive(name_buf: &mut Vec<u8>, c: &mut Cursor<&[u8]>, recursions: &mut Vec<u16>) -> Result<u16, WireError> {
+    let mut bytes_read = 0;
+
     loop {
         let byte = c.read_u8()?;
+        bytes_read += 1;
 
         if byte == 0 {
             break;
@@ -73,6 +81,7 @@ fn read_string_recursive(name_buf: &mut Vec<u8>, c: &mut Cursor<&[u8]>, recursio
 
             let name_one = byte - 0b1100_0000;
             let name_two = c.read_u8()?;
+            bytes_read += 1;
             let offset = u16::from_be_bytes([name_one, name_two]);
 
             trace!("Backtracking to offset {}", offset);
@@ -93,6 +102,7 @@ fn read_string_recursive(name_buf: &mut Vec<u8>, c: &mut Cursor<&[u8]>, recursio
         else {
             for _ in 0 .. byte {
                 let c = c.read_u8()?;
+                bytes_read += 1;
                 name_buf.push(c);
             }
 
@@ -100,5 +110,5 @@ fn read_string_recursive(name_buf: &mut Vec<u8>, c: &mut Cursor<&[u8]>, recursio
         }
     }
 
-    Ok(())
+    Ok(bytes_read)
 }

+ 51 - 7
dns/src/wire.rs

@@ -76,28 +76,28 @@ impl Response {
         let mut queries = Vec::new();
         debug!("Reading {}x query from response", query_count);
         for _ in 0 .. query_count {
-            let qname = c.read_labels()?;
+            let (qname, _) = c.read_labels()?;
             queries.push(Query::from_bytes(qname, &mut c)?);
         }
 
         let mut answers = Vec::new();
         debug!("Reading {}x answer from response", answer_count);
         for _ in 0 .. answer_count {
-            let qname = c.read_labels()?;
+            let (qname, _) = c.read_labels()?;
             answers.push(Answer::from_bytes(qname, &mut c)?);
         }
 
         let mut authorities = Vec::new();
         debug!("Reading {}x authority from response", authority_count);
         for _ in 0 .. authority_count {
-            let qname = c.read_labels()?;
+            let (qname, _) = c.read_labels()?;
             authorities.push(Answer::from_bytes(qname, &mut c)?);
         }
 
         let mut additionals = Vec::new();
         debug!("Reading {}x additional answer from response", additional_count);
         for _ in 0 .. additional_count {
-            let qname = c.read_labels()?;
+            let (qname, _) = c.read_labels()?;
             additionals.push(Answer::from_bytes(qname, &mut c)?);
         }
 
@@ -352,9 +352,53 @@ pub enum WireError {
     IO,
     // (io::Error is not PartialEq so we don’t propagate it)
 
-    /// When this record expected the data to be a certain size, but it was
-    /// a different one.
-    WrongLength {
+    /// When the DNS standard requires records of this type to have a certain
+    /// fixed length, but the response specified a different length.
+    ///
+    /// This error should be returned regardless of the _content_ of the
+    /// record, whatever it is.
+    WrongRecordLength {
+
+        /// The expected size.
+        expected: u16,
+
+        /// The size that was actually received.
+        got: u16,
+    },
+
+    /// When the length of this record as specified in the packet differs from
+    /// the computed length, as determined by reading labels.
+    ///
+    /// There are two ways, in general, to read arbitrary-length data from a
+    /// stream of bytes: length-prefixed (read the length, then read that many
+    /// bytes) or sentinel-terminated (keep reading bytes until you read a
+    /// certain value, usually zero). The DNS protocol uses both: each
+    /// record’s size is specified up-front in the packet, but inside the
+    /// record, there exist arbitrary-length strings that must be read until a
+    /// zero is read, indicating there is no more string.
+    ///
+    /// Consider the case of a packet, with a specified length, containing a
+    /// string of arbitrary length (such as the CNAME or TXT records). A DNS
+    /// client has to deal with this in one of two ways:
+    ///
+    /// 1. Read exactly the specified length of bytes from the record, raising
+    ///    an error if the contents are too short or a string keeps going past
+    ///    the length (assume the length is correct but the contents are wrong).
+    ///
+    /// 2. Read as many bytes from the record as the string requests, raising
+    ///    an error if the number of bytes read at the end differs from the
+    ///    expected length of the record (assume the length is wrong but the
+    ///    contents are correct).
+    ///
+    /// Note that no matter which way is picked, the record will still be
+    /// incorrect — it only impacts the parsing of records that occur after it
+    /// in the packet. Knowing which method should be used requires knowing
+    /// what caused the DNS packet to be erroneous, which we cannot know.
+    ///
+    /// dog picks the second way. If a record ends up reading more or fewer
+    /// bytes than it is ‘supposed’ to, it will raise this error, but _after_
+    /// having read a different number of bytes than the specified length.
+    WrongLabelLength {
 
         /// The expected size.
         expected: u16,

+ 11 - 8
src/output.rs

@@ -179,17 +179,20 @@ fn error_message(error: TransportError) -> String {
 		TransportError::HttpError(e)     => e.to_string(),
 		TransportError::TlsError(e)      => e.to_string(),
 		TransportError::BadRequest       => "Nameserver returned HTTP 400 Bad Request".into(),
-		TransportError::WireError(e)     => {
-			match e {
-				WireError::IO                             => "Malformed packet: insufficient data".into(),
-				WireError::WrongLength { expected, got }  => format!("Malformed packet: expected length {}, got {}", expected, got),
-				WireError::TooMuchRecursion(indices)      => format!("Malformed packet: too much recursion: {:?}", indices),
-				WireError::OutOfBounds(index)             => format!("Malformed packet: out of bounds ({})", index),
-			}
-		}
+		TransportError::WireError(e)     => wire_error_message(e),
 	}
 }
 
+fn wire_error_message(error: WireError) -> String {
+    match error {
+        WireError::IO                                   => "Malformed packet: insufficient data".into(),
+        WireError::WrongRecordLength { expected, got }  => format!("Malformed packet: expected length {}, got {}", expected, got),
+        WireError::WrongLabelLength { expected, got }   => format!("Malformed packet: expected length {}, got {}", expected, got),
+        WireError::TooMuchRecursion(indices)            => format!("Malformed packet: too much recursion: {:?}", indices),
+        WireError::OutOfBounds(index)                   => format!("Malformed packet: out of bounds ({})", index),
+    }
+}
+
 
 impl TextFormat {
     pub fn record_payload_summary(self, record: &Record) -> String {