瀏覽代碼

Add MandatedLength enum

It turns out that the "wrong length" error could either be "should be an exact length" or "should be at least a certain length".
Benjamin Sago 4 年之前
父節點
當前提交
a9f41dc5eb
共有 6 個文件被更改,包括 33 次插入16 次删除
  1. 1 1
      dns/src/lib.rs
  2. 4 4
      dns/src/record/a.rs
  3. 6 4
      dns/src/record/aaaa.rs
  4. 4 3
      dns/src/record/loc.rs
  5. 12 1
      dns/src/wire.rs
  6. 6 3
      src/output.rs

+ 1 - 1
dns/src/lib.rs

@@ -35,6 +35,6 @@ mod strings;
 pub use self::strings::Labels;
 
 mod wire;
-pub use self::wire::{Wire, WireError, find_qtype_number};
+pub use self::wire::{Wire, WireError, MandatedLength, find_qtype_number};
 
 pub mod record;

+ 4 - 4
dns/src/record/a.rs

@@ -35,7 +35,7 @@ impl Wire for A {
         }
         else {
             warn!("Length is incorrect (record length {:?}, but should be four)", stated_length);
-            Err(WireError::WrongRecordLength { stated_length, mandated_length: 4 })
+            Err(WireError::WrongRecordLength { stated_length, mandated_length: MandatedLength::Exactly(4) })
         }
     }
 }
@@ -63,7 +63,7 @@ mod test {
         ];
 
         assert_eq!(A::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongRecordLength { stated_length: 3, mandated_length: 4 }));
+                   Err(WireError::WrongRecordLength { stated_length: 3, mandated_length: MandatedLength::Exactly(4) }));
     }
 
     #[test]
@@ -74,13 +74,13 @@ mod test {
         ];
 
         assert_eq!(A::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongRecordLength { stated_length: 5, mandated_length: 4 }));
+                   Err(WireError::WrongRecordLength { stated_length: 5, mandated_length: MandatedLength::Exactly(4) }));
     }
 
     #[test]
     fn record_empty() {
         assert_eq!(A::read(0, &mut Cursor::new(&[])),
-                   Err(WireError::WrongRecordLength { stated_length: 0, mandated_length: 4 }));
+                   Err(WireError::WrongRecordLength { stated_length: 0, mandated_length: MandatedLength::Exactly(4) }));
     }
 
     #[test]

+ 6 - 4
dns/src/record/aaaa.rs

@@ -37,7 +37,9 @@ impl Wire for AAAA {
         }
         else {
             warn!("Length is incorrect (stated length {:?}, but should be sixteen)", stated_length);
-            Err(WireError::WrongRecordLength { stated_length, mandated_length: 16 })
+
+            let mandated_length = MandatedLength::Exactly(16);
+            Err(WireError::WrongRecordLength { stated_length, mandated_length })
         }
     }
 }
@@ -68,7 +70,7 @@ mod test {
         ];
 
         assert_eq!(AAAA::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongRecordLength { stated_length: 17, mandated_length: 16 }));
+                   Err(WireError::WrongRecordLength { stated_length: 17, mandated_length: MandatedLength::Exactly(16) }));
     }
 
     #[test]
@@ -78,13 +80,13 @@ mod test {
         ];
 
         assert_eq!(AAAA::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongRecordLength { stated_length: 5, mandated_length: 16 }));
+                   Err(WireError::WrongRecordLength { stated_length: 5, mandated_length: MandatedLength::Exactly(16) }));
     }
 
     #[test]
     fn record_empty() {
         assert_eq!(AAAA::read(0, &mut Cursor::new(&[])),
-                   Err(WireError::WrongRecordLength { stated_length: 0, mandated_length: 16 }));
+                   Err(WireError::WrongRecordLength { stated_length: 0, mandated_length: MandatedLength::Exactly(16) }));
     }
 
     #[test]

+ 4 - 3
dns/src/record/loc.rs

@@ -63,7 +63,8 @@ impl Wire for LOC {
         }
 
         if stated_length != 16 {
-            return Err(WireError::WrongRecordLength { stated_length, mandated_length: 16 });
+            let mandated_length = MandatedLength::Exactly(16);
+            return Err(WireError::WrongRecordLength { stated_length, mandated_length });
         }
 
         let size_bits = c.read_u8()?;
@@ -137,7 +138,7 @@ mod test {
         ];
 
         assert_eq!(LOC::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongRecordLength { stated_length: 2, mandated_length: 16 }));
+                   Err(WireError::WrongRecordLength { stated_length: 2, mandated_length: MandatedLength::Exactly(16) }));
     }
 
     #[test]
@@ -154,7 +155,7 @@ mod test {
         ];
 
         assert_eq!(LOC::read(buf.len() as _, &mut Cursor::new(buf)),
-                   Err(WireError::WrongRecordLength { stated_length: 19, mandated_length: 16 }));
+                   Err(WireError::WrongRecordLength { stated_length: 19, mandated_length: MandatedLength::Exactly(16) }));
     }
 
     #[test]

+ 12 - 1
dns/src/wire.rs

@@ -391,7 +391,7 @@ pub enum WireError {
         stated_length: u16,
 
         /// The length of the record that the DNS specification mandates.
-        mandated_length: u16,
+        mandated_length: MandatedLength,
     },
 
     /// When the length of this record as specified in the packet differs from
@@ -457,6 +457,17 @@ pub enum WireError {
     }
 }
 
+/// The rule for how long a record in a packet should be.
+#[derive(PartialEq, Debug, Copy, Clone)]
+pub enum MandatedLength {
+
+    /// The record should be exactly this many bytes in length.
+    Exactly(u16),
+
+    /// The record should be _at least_ this many bytes in length.
+    AtLeast(u16),
+}
+
 impl From<io::Error> for WireError {
     fn from(ioe: io::Error) -> Self {
         error!("IO error -> {:?}", ioe);

+ 6 - 3
src/output.rs

@@ -2,7 +2,7 @@
 
 use std::time::Duration;
 
-use dns::{Response, Query, Answer, ErrorCode, WireError};
+use dns::{Response, Query, Answer, ErrorCode, WireError, MandatedLength};
 use dns::record::{Record, OPT, UnknownQtype};
 use dns_transport::Error as TransportError;
 use serde_json::{json, Value as JsonValue};
@@ -516,8 +516,11 @@ fn wire_error_message(error: WireError) -> String {
         WireError::IO => {
             "Malformed packet: insufficient data".into()
         }
-        WireError::WrongRecordLength { stated_length, mandated_length } => {
-            format!("Malformed packet: record length should be {}, got {}", mandated_length, stated_length )
+        WireError::WrongRecordLength { stated_length, mandated_length: MandatedLength::Exactly(len) } => {
+            format!("Malformed packet: record length should be {}, got {}", len, stated_length )
+        }
+        WireError::WrongRecordLength { stated_length, mandated_length: MandatedLength::AtLeast(len) } => {
+            format!("Malformed packet: record length should be at least {}, got {}", len, stated_length )
         }
         WireError::WrongLabelLength { stated_length, length_after_labels } => {
             format!("Malformed packet: length {} was specified, but read {} bytes", stated_length, length_after_labels)