瀏覽代碼

Make DhcpRepr use an Option<Vec<Ipv4Address>>

datdenkikniet 2 年之前
父節點
當前提交
ca4a98acc6
共有 2 個文件被更改,包括 73 次插入63 次删除
  1. 36 35
      src/socket/dhcpv4.rs
  2. 37 28
      src/wire/dhcpv4.rs

+ 36 - 35
src/socket/dhcpv4.rs

@@ -424,7 +424,6 @@ impl<'a> Socket<'a> {
             .dns_servers
             .iter()
             .flatten()
-            .flatten()
             .filter(|s| s.is_unicast())
             .for_each(|a| {
                 // This will never produce an error, as both the arrays and `dns_servers`
@@ -779,8 +778,6 @@ mod test {
     const DNS_IP_1: Ipv4Address = Ipv4Address([1, 1, 1, 1]);
     const DNS_IP_2: Ipv4Address = Ipv4Address([1, 1, 1, 2]);
     const DNS_IP_3: Ipv4Address = Ipv4Address([1, 1, 1, 3]);
-    const DNS_IPS_ARR: [Option<Ipv4Address>; DHCP_MAX_DNS_SERVER_COUNT] =
-        [Some(DNS_IP_1), Some(DNS_IP_2), Some(DNS_IP_3)];
     const DNS_IPS: &[Ipv4Address] = &[DNS_IP_1, DNS_IP_2, DNS_IP_3];
 
     const MASK_24: Ipv4Address = Ipv4Address([255, 255, 255, 0]);
@@ -858,19 +855,21 @@ mod test {
         ..DHCP_DEFAULT
     };
 
-    const DHCP_OFFER: DhcpRepr = DhcpRepr {
-        message_type: DhcpMessageType::Offer,
-        server_ip: SERVER_IP,
-        server_identifier: Some(SERVER_IP),
+    fn dhcp_offer() -> DhcpRepr<'static> {
+        DhcpRepr {
+            message_type: DhcpMessageType::Offer,
+            server_ip: SERVER_IP,
+            server_identifier: Some(SERVER_IP),
 
-        your_ip: MY_IP,
-        router: Some(SERVER_IP),
-        subnet_mask: Some(MASK_24),
-        dns_servers: Some(DNS_IPS_ARR),
-        lease_duration: Some(1000),
+            your_ip: MY_IP,
+            router: Some(SERVER_IP),
+            subnet_mask: Some(MASK_24),
+            dns_servers: Some(Vec::from_slice(DNS_IPS).unwrap()),
+            lease_duration: Some(1000),
 
-        ..DHCP_DEFAULT
-    };
+            ..DHCP_DEFAULT
+        }
+    }
 
     const DHCP_REQUEST: DhcpRepr = DhcpRepr {
         message_type: DhcpMessageType::Request,
@@ -883,19 +882,21 @@ mod test {
         ..DHCP_DEFAULT
     };
 
-    const DHCP_ACK: DhcpRepr = DhcpRepr {
-        message_type: DhcpMessageType::Ack,
-        server_ip: SERVER_IP,
-        server_identifier: Some(SERVER_IP),
+    fn dhcp_ack() -> DhcpRepr<'static> {
+        DhcpRepr {
+            message_type: DhcpMessageType::Ack,
+            server_ip: SERVER_IP,
+            server_identifier: Some(SERVER_IP),
 
-        your_ip: MY_IP,
-        router: Some(SERVER_IP),
-        subnet_mask: Some(MASK_24),
-        dns_servers: Some(DNS_IPS_ARR),
-        lease_duration: Some(1000),
+            your_ip: MY_IP,
+            router: Some(SERVER_IP),
+            subnet_mask: Some(MASK_24),
+            dns_servers: Some(Vec::from_slice(DNS_IPS).unwrap()),
+            lease_duration: Some(1000),
 
-        ..DHCP_DEFAULT
-    };
+            ..DHCP_DEFAULT
+        }
+    }
 
     const DHCP_NAK: DhcpRepr = DhcpRepr {
         message_type: DhcpMessageType::Nak,
@@ -954,11 +955,11 @@ mod test {
 
         recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
         assert_eq!(s.poll(), None);
-        send!(s, (IP_RECV, UDP_RECV, DHCP_OFFER));
+        send!(s, (IP_RECV, UDP_RECV, dhcp_offer()));
         assert_eq!(s.poll(), None);
         recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
         assert_eq!(s.poll(), None);
-        send!(s, (IP_RECV, UDP_RECV, DHCP_ACK));
+        send!(s, (IP_RECV, UDP_RECV, dhcp_ack()));
 
         assert_eq!(
             s.poll(),
@@ -994,7 +995,7 @@ mod test {
         recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
 
         // check after retransmits it still works
-        send!(s, time 20_000, (IP_RECV, UDP_RECV, DHCP_OFFER));
+        send!(s, time 20_000, (IP_RECV, UDP_RECV, dhcp_offer()));
         recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
     }
 
@@ -1003,7 +1004,7 @@ mod test {
         let mut s = socket();
 
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
-        send!(s, time 0, (IP_RECV, UDP_RECV, DHCP_OFFER));
+        send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer()));
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
         recv!(s, time 1_000, []);
         recv!(s, time 5_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
@@ -1013,7 +1014,7 @@ mod test {
         recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
 
         // check after retransmits it still works
-        send!(s, time 20_000, (IP_RECV, UDP_RECV, DHCP_ACK));
+        send!(s, time 20_000, (IP_RECV, UDP_RECV, dhcp_ack()));
 
         match &s.state {
             ClientState::Renewing(r) => {
@@ -1029,7 +1030,7 @@ mod test {
         let mut s = socket();
 
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
-        send!(s, time 0, (IP_RECV, UDP_RECV, DHCP_OFFER));
+        send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer()));
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
         recv!(s, time 5_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
         recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
@@ -1041,7 +1042,7 @@ mod test {
         recv!(s, time 70_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
 
         // check it still works
-        send!(s, time 60_000, (IP_RECV, UDP_RECV, DHCP_OFFER));
+        send!(s, time 60_000, (IP_RECV, UDP_RECV, dhcp_offer()));
         recv!(s, time 60_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
     }
 
@@ -1050,7 +1051,7 @@ mod test {
         let mut s = socket();
 
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
-        send!(s, time 0, (IP_RECV, UDP_RECV, DHCP_OFFER));
+        send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer()));
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]);
         send!(s, time 0, (IP_SERVER_BROADCAST, UDP_RECV, DHCP_NAK));
         recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]);
@@ -1074,7 +1075,7 @@ mod test {
             _ => panic!("Invalid state"),
         }
 
-        send!(s, time 500_000, (IP_RECV, UDP_RECV, DHCP_ACK));
+        send!(s, time 500_000, (IP_RECV, UDP_RECV, dhcp_ack()));
         assert_eq!(s.poll(), None);
 
         match &s.state {
@@ -1099,7 +1100,7 @@ mod test {
         recv!(s, time 875_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]);
 
         // check it still works
-        send!(s, time 875_000, (IP_RECV, UDP_RECV, DHCP_ACK));
+        send!(s, time 875_000, (IP_RECV, UDP_RECV, dhcp_ack()));
         match &s.state {
             ClientState::Renewing(r) => {
                 // NOW the expiration gets bumped

+ 37 - 28
src/wire/dhcpv4.rs

@@ -3,6 +3,7 @@
 use bitflags::bitflags;
 use byteorder::{ByteOrder, NetworkEndian};
 use core::iter;
+use heapless::Vec;
 
 use super::{Error, Result};
 use crate::wire::arp::Hardware;
@@ -582,7 +583,7 @@ impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> {
 /// length) is set to `6`.
 ///
 /// The `options` field has a variable length.
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub struct Repr<'a> {
     /// This field is also known as `op` in the RFC. It indicates the type of DHCP message this
@@ -645,7 +646,7 @@ pub struct Repr<'a> {
     /// the client is interested in.
     pub parameter_request_list: Option<&'a [u8]>,
     /// DNS servers
-    pub dns_servers: Option<[Option<Ipv4Address>; MAX_DNS_SERVER_COUNT]>,
+    pub dns_servers: Option<Vec<Ipv4Address, MAX_DNS_SERVER_COUNT>>,
     /// The maximum size dhcp packet the interface can receive
     pub max_size: Option<u16>,
     /// The DHCP IP lease duration, specified in seconds.
@@ -683,9 +684,9 @@ impl<'a> Repr<'a> {
         if self.lease_duration.is_some() {
             len += 6;
         }
-        if let Some(dns_servers) = self.dns_servers {
+        if let Some(dns_servers) = &self.dns_servers {
             len += 2;
-            len += dns_servers.iter().flatten().count() * core::mem::size_of::<u32>();
+            len += dns_servers.iter().count() * core::mem::size_of::<u32>();
         }
         if let Some(list) = self.parameter_request_list {
             len += list.len() + 2;
@@ -773,13 +774,13 @@ impl<'a> Repr<'a> {
                     parameter_request_list = Some(data);
                 }
                 (field::OPT_DOMAIN_NAME_SERVER, _) => {
-                    let mut servers = [None; MAX_DNS_SERVER_COUNT];
-                    let chunk_size = 4;
-                    for (server, chunk) in servers.iter_mut().zip(data.chunks(chunk_size)) {
-                        if chunk.len() != chunk_size {
-                            return Err(Error);
-                        }
-                        *server = Some(Ipv4Address::from_bytes(chunk));
+                    let mut servers = Vec::new();
+                    const IP_ADDR_BYTE_LEN: usize = 4;
+                    for chunk in data.chunks(IP_ADDR_BYTE_LEN) {
+                        // We ignore push failures because that will only happen
+                        // if we attempt to push more than 4 addresses, and the only
+                        // solution to that is to support more addresses.
+                        servers.push(Ipv4Address::from_bytes(chunk)).ok();
                     }
                     dns_servers = Some(servers);
                 }
@@ -901,13 +902,12 @@ impl<'a> Repr<'a> {
                 })?;
             }
 
-            if let Some(dns_servers) = self.dns_servers {
+            if let Some(dns_servers) = &self.dns_servers {
                 const IP_SIZE: usize = core::mem::size_of::<u32>();
                 let mut servers = [0; MAX_DNS_SERVER_COUNT * IP_SIZE];
 
                 let data_len = dns_servers
                     .iter()
-                    .flatten()
                     .enumerate()
                     .inspect(|(i, ip)| {
                         servers[(i * IP_SIZE)..((i + 1) * IP_SIZE)].copy_from_slice(ip.as_bytes());
@@ -1210,11 +1210,14 @@ mod test {
     fn test_emit_offer_dns() {
         let repr = {
             let mut repr = offer_repr();
-            repr.dns_servers = Some([
-                Some(Ipv4Address([163, 1, 74, 6])),
-                Some(Ipv4Address([163, 1, 74, 7])),
-                Some(Ipv4Address([163, 1, 74, 3])),
-            ]);
+            repr.dns_servers = Some(
+                Vec::from_slice(&[
+                    Ipv4Address([163, 1, 74, 6]),
+                    Ipv4Address([163, 1, 74, 7]),
+                    Ipv4Address([163, 1, 74, 3]),
+                ])
+                .unwrap(),
+            );
             repr
         };
         let mut bytes = vec![0xa5; repr.buffer_len()];
@@ -1226,11 +1229,14 @@ mod test {
 
         assert_eq!(
             repr_parsed.dns_servers,
-            Some([
-                Some(Ipv4Address([163, 1, 74, 6])),
-                Some(Ipv4Address([163, 1, 74, 7])),
-                Some(Ipv4Address([163, 1, 74, 3]))
-            ])
+            Some(
+                Vec::from_slice(&[
+                    Ipv4Address([163, 1, 74, 6]),
+                    Ipv4Address([163, 1, 74, 7]),
+                    Ipv4Address([163, 1, 74, 3]),
+                ])
+                .unwrap()
+            )
         );
     }
 
@@ -1263,11 +1269,14 @@ mod test {
         // length-3 array (see issue #305)
         assert_eq!(
             repr.dns_servers,
-            Some([
-                Some(Ipv4Address([163, 1, 74, 6])),
-                Some(Ipv4Address([163, 1, 74, 7])),
-                Some(Ipv4Address([163, 1, 74, 3]))
-            ])
+            Some(
+                Vec::from_slice(&[
+                    Ipv4Address([163, 1, 74, 6]),
+                    Ipv4Address([163, 1, 74, 7]),
+                    Ipv4Address([163, 1, 74, 3])
+                ])
+                .unwrap()
+            )
         );
     }