ソースを参照

Use read_exact to improve A and AAAA parsing code

As sad as it is to see the best two lines of code I have ever written disappear, it turns out that there's a function in the standard library for this exact use case (fill an array with an exact number of bytes) that we can use instead of doing it ourselves.
Benjamin Sago 4 年 前
コミット
f381b85830
2 ファイル変更22 行追加26 行削除
  1. 12 12
      dns/src/record/a.rs
  2. 10 14
      dns/src/record/aaaa.rs

+ 12 - 12
dns/src/record/a.rs

@@ -1,3 +1,4 @@
+use std::io::Read;
 use std::net::Ipv4Addr;
 
 use log::*;
@@ -24,20 +25,19 @@ impl Wire for A {
 
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
     fn read(stated_length: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
-        let mut buf = Vec::new();
-        for _ in 0 .. stated_length {
-            buf.push(c.read_u8()?);
-        }
-
-        if let [a, b, c, d] = *buf {
-            let address = Ipv4Addr::new(a, b, c, d);
-            trace!("Parsed IPv4 address -> {:?}", address);
-            Ok(Self { address })
-        }
-        else {
+        if stated_length != 4 {
             warn!("Length is incorrect (record length {:?}, but should be four)", stated_length);
-            Err(WireError::WrongRecordLength { stated_length, mandated_length: MandatedLength::Exactly(4) })
+            let mandated_length = MandatedLength::Exactly(4);
+            return Err(WireError::WrongRecordLength { stated_length, mandated_length });
         }
+
+        let mut buf = [0_u8; 4];
+        c.read_exact(&mut buf)?;
+
+        let address = Ipv4Addr::from(buf);
+        trace!("Parsed IPv4 address -> {:?}", address);
+
+        Ok(Self { address })
     }
 }
 

+ 10 - 14
dns/src/record/aaaa.rs

@@ -1,3 +1,4 @@
+use std::io::Read;
 use std::net::Ipv6Addr;
 
 use log::*;
@@ -24,24 +25,19 @@ impl Wire for AAAA {
 
     #[cfg_attr(all(test, feature = "with_mutagen"), ::mutagen::mutate)]
     fn read(stated_length: u16, c: &mut Cursor<&[u8]>) -> Result<Self, WireError> {
-        let mut buf = Vec::new();
-        for _ in 0 .. stated_length {
-            buf.push(c.read_u8()?);
+        if stated_length != 16 {
+            warn!("Length is incorrect (stated length {:?}, but should be sixteen)", stated_length);
+            let mandated_length = MandatedLength::Exactly(16);
+            return Err(WireError::WrongRecordLength { stated_length, mandated_length });
         }
 
-        if let [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p] = *buf {
-            let address = Ipv6Addr::from([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p]);
-            // probably the best two lines of code I have ever written
+        let mut buf = [0_u8; 16];
+        c.read_exact(&mut buf)?;
 
-            trace!("Parsed IPv6 address -> {:?}", address);
-            Ok(Self { address })
-        }
-        else {
-            warn!("Length is incorrect (stated length {:?}, but should be sixteen)", stated_length);
+        let address = Ipv6Addr::from(buf);
+        trace!("Parsed IPv6 address -> {:?}", address);
 
-            let mandated_length = MandatedLength::Exactly(16);
-            Err(WireError::WrongRecordLength { stated_length, mandated_length })
-        }
+        Ok(Self { address })
     }
 }