Эх сурвалжийг харах

support hardware based checksum settings in during packet send/recv

- makes sure the checksum is zeroed when not emitted by software
  (This is required by some implementations such as STM32 to work properly)
Steffen Butzer 7 жил өмнө
parent
commit
d5147efb82

+ 6 - 4
examples/ping.rs

@@ -10,6 +10,7 @@ mod utils;
 use std::str::FromStr;
 use std::time::Instant;
 use std::os::unix::io::AsRawFd;
+use smoltcp::phy::Device;
 use smoltcp::phy::wait as phy_wait;
 use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress,
                     Ipv4Address, Ipv4Packet, Ipv4Repr,
@@ -56,6 +57,7 @@ fn main() {
                                     raw_rx_buffer, raw_tx_buffer);
 
     let hardware_addr  = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
+    let caps = device.capabilities();
     let mut iface = EthernetInterface::new(
         Box::new(device), Box::new(arp_cache) as Box<ArpCache>,
         hardware_addr, [IpAddress::from(local_addr)]);
@@ -98,9 +100,9 @@ fn main() {
                     .unwrap();
 
                 let mut ipv4_packet = Ipv4Packet::new(raw_payload);
-                ipv4_repr.emit(&mut ipv4_packet);
+                ipv4_repr.emit(&mut ipv4_packet, &caps.checksum);
                 let mut icmp_packet = Icmpv4Packet::new(ipv4_packet.payload_mut());
-                icmp_repr.emit(&mut icmp_packet);
+                icmp_repr.emit(&mut icmp_packet, &caps.checksum);
 
                 waiting_queue.insert(seq_no, timestamp);
                 seq_no += 1;
@@ -110,11 +112,11 @@ fn main() {
             if socket.can_recv() {
                 let payload = socket.recv().unwrap();
                 let ipv4_packet = Ipv4Packet::new(payload);
-                let ipv4_repr = Ipv4Repr::parse(&ipv4_packet).unwrap();
+                let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &caps.checksum).unwrap();
 
                 if ipv4_repr.src_addr == remote_addr && ipv4_repr.dst_addr == local_addr {
                     let icmp_packet = Icmpv4Packet::new(ipv4_packet.payload());
-                    let icmp_repr = Icmpv4Repr::parse(&icmp_packet);
+                    let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &caps.checksum);
 
                     if let Ok(Icmpv4Repr::EchoReply { seq_no, data, .. }) = icmp_repr {
                         if let Some(_) = waiting_queue.get(&seq_no) {

+ 26 - 18
src/iface/ethernet.rs

@@ -181,7 +181,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                         socket.dispatch(|response| {
                             device_result = self.dispatch(timestamp, Packet::Raw(response));
                             device_result
-                        }),
+                        }, &caps.checksum),
                     #[cfg(feature = "socket-udp")]
                     &mut Socket::Udp(ref mut socket) =>
                         socket.dispatch(|response| {
@@ -190,7 +190,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                         }),
                     #[cfg(feature = "socket-tcp")]
                     &mut Socket::Tcp(ref mut socket) =>
-                        socket.dispatch(timestamp, &limits, |response| {
+                        socket.dispatch(timestamp, &caps, |response| {
                             device_result = self.dispatch(timestamp, Packet::Tcp(response));
                             device_result
                         }),
@@ -278,7 +278,8 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                     eth_frame: &EthernetFrame<&'frame T>) ->
                    Result<Packet<'frame>> {
         let ipv4_packet = Ipv4Packet::new_checked(eth_frame.payload())?;
-        let ipv4_repr = Ipv4Repr::parse(&ipv4_packet)?;
+        let checksum_caps = self.device.capabilities().checksum;
+        let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &checksum_caps)?;
 
         if !ipv4_repr.src_addr.is_unicast() {
             // Discard packets with non-unicast source addresses.
@@ -304,7 +305,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                 <Socket as AsSocket<RawSocket>>::try_as_socket) {
             if !raw_socket.accepts(&ip_repr) { continue }
 
-            match raw_socket.process(&ip_repr, ip_payload) {
+            match raw_socket.process(&ip_repr, ip_payload, &checksum_caps) {
                 // The packet is valid and handled by socket.
                 Ok(()) => handled_by_raw_socket = true,
                 // The socket buffer is full.
@@ -321,15 +322,15 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
         match ipv4_repr.protocol {
             IpProtocol::Icmp =>
-                Self::process_icmpv4(ipv4_repr, ip_payload),
+                self.process_icmpv4(ipv4_repr, ip_payload),
 
             #[cfg(feature = "socket-udp")]
             IpProtocol::Udp =>
-                Self::process_udp(sockets, ip_repr, ip_payload),
+                self.process_udp(sockets, ip_repr, ip_payload),
 
             #[cfg(feature = "socket-tcp")]
             IpProtocol::Tcp =>
-                Self::process_tcp(sockets, _timestamp, ip_repr, ip_payload),
+                self.process_tcp(sockets, _timestamp, ip_repr, ip_payload),
 
             #[cfg(feature = "socket-raw")]
             _ if handled_by_raw_socket =>
@@ -352,10 +353,11 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         }
     }
 
-    fn process_icmpv4<'frame>(ipv4_repr: Ipv4Repr, ip_payload: &'frame [u8]) ->
-                             Result<Packet<'frame>> {
+    fn process_icmpv4<'frame>(&self, ipv4_repr: Ipv4Repr, 
+                              ip_payload: &'frame [u8]) -> Result<Packet<'frame>> {
         let icmp_packet = Icmpv4Packet::new_checked(ip_payload)?;
-        let icmp_repr = Icmpv4Repr::parse(&icmp_packet)?;
+        let checksum_caps = self.device.capabilities().checksum;
+        let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &checksum_caps)?;
 
         match icmp_repr {
             // Respond to echo requests.
@@ -383,12 +385,13 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     }
 
     #[cfg(feature = "socket-udp")]
-    fn process_udp<'frame>(sockets: &mut SocketSet,
+    fn process_udp<'frame>(&self, sockets: &mut SocketSet,
                            ip_repr: IpRepr, ip_payload: &'frame [u8]) ->
                           Result<Packet<'frame>> {
         let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr());
         let udp_packet = UdpPacket::new_checked(ip_payload)?;
-        let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr)?;
+        let checksum_caps = self.device.capabilities().checksum;
+        let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?;
 
         for udp_socket in sockets.iter_mut().filter_map(
                 <Socket as AsSocket<UdpSocket>>::try_as_socket) {
@@ -425,12 +428,13 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     }
 
     #[cfg(feature = "socket-tcp")]
-    fn process_tcp<'frame>(sockets: &mut SocketSet, timestamp: u64,
+    fn process_tcp<'frame>(&self, sockets: &mut SocketSet, timestamp: u64,
                            ip_repr: IpRepr, ip_payload: &'frame [u8]) ->
                           Result<Packet<'frame>> {
         let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr());
         let tcp_packet = TcpPacket::new_checked(ip_payload)?;
-        let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr)?;
+        let checksum_caps = self.device.capabilities().checksum;
+        let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &checksum_caps)?;
 
         for tcp_socket in sockets.iter_mut().filter_map(
                 <Socket as AsSocket<TcpSocket>>::try_as_socket) {
@@ -455,6 +459,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     }
 
     fn dispatch(&mut self, timestamp: u64, packet: Packet) -> Result<()> {
+        let checksum_caps = self.device.capabilities().checksum;
         match packet {
             Packet::Arp(arp_repr) => {
                 let dst_hardware_addr =
@@ -473,7 +478,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
             },
             Packet::Icmpv4(ipv4_repr, icmpv4_repr) => {
                 self.dispatch_ip(timestamp, IpRepr::Ipv4(ipv4_repr), |_ip_repr, payload| {
-                    icmpv4_repr.emit(&mut Icmpv4Packet::new(payload));
+                    icmpv4_repr.emit(&mut Icmpv4Packet::new(payload), &checksum_caps);
                 })
             }
             #[cfg(feature = "socket-raw")]
@@ -486,7 +491,8 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
             Packet::Udp((ip_repr, udp_repr)) => {
                 self.dispatch_ip(timestamp, ip_repr, |ip_repr, payload| {
                     udp_repr.emit(&mut UdpPacket::new(payload),
-                                  &ip_repr.src_addr(), &ip_repr.dst_addr());
+                                  &ip_repr.src_addr(), &ip_repr.dst_addr(), 
+                                  &checksum_caps);
                 })
             }
             #[cfg(feature = "socket-tcp")]
@@ -513,7 +519,8 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                     }
 
                     tcp_repr.emit(&mut TcpPacket::new(payload),
-                                  &ip_repr.src_addr(), &ip_repr.dst_addr());
+                                  &ip_repr.src_addr(), &ip_repr.dst_addr(),
+                                  &checksum_caps);
                 })
             }
             Packet::None => Ok(())
@@ -574,6 +581,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     fn dispatch_ip<F>(&mut self, timestamp: u64, ip_repr: IpRepr, f: F) -> Result<()>
             where F: FnOnce(IpRepr, &mut [u8]) {
         let ip_repr = ip_repr.lower(&self.protocol_addrs)?;
+        let checksum_caps = self.device.capabilities().checksum;
 
         let dst_hardware_addr =
             self.lookup_hardware_addr(timestamp, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
@@ -585,7 +593,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                 _ => unreachable!()
             }
 
-            ip_repr.emit(frame.payload_mut());
+            ip_repr.emit(frame.payload_mut(), &checksum_caps);
 
             let payload = &mut frame.payload_mut()[ip_repr.buffer_len()..];
             f(ip_repr, payload)

+ 49 - 0
src/phy/mod.rs

@@ -135,6 +135,52 @@ pub use self::tap_interface::TapInterface;
 /// A tracer device for Ethernet frames.
 pub type EthernetTracer<T> = Tracer<T, super::wire::EthernetFrame<&'static [u8]>>;
 
+/// The checksum configuration for a device
+#[derive(Debug, Clone, Copy)]
+pub enum Checksum {
+    /// Validate checksum when receiving and supply checksum when sending
+    Both,
+    /// Validate checksum when receiving
+    Rx,
+    /// Supply checksum before sending
+    Tx,
+    /// Ignore checksum
+    None,
+}
+
+impl Default for Checksum {
+    fn default() -> Checksum {
+        Checksum::Both
+    }
+}
+
+impl Checksum {
+    pub(crate) fn rx(&self) -> bool {
+        match *self {
+            Checksum::Both | Checksum::Rx => true,
+            _ => false
+        }
+    }
+
+    pub(crate) fn tx(&self) -> bool {
+        match *self {
+            Checksum::Both | Checksum::Tx => true,
+            _ => false
+        }
+    }
+}
+
+/// Configuration of checksum capabilities for each applicable protocol
+#[derive(Debug, Clone, Default)]
+pub struct ChecksumCapabilities {
+    pub ipv4: Checksum,
+    pub udpv4: Checksum,
+    pub udpv6: Checksum,
+    pub tcpv4: Checksum,
+    pub icmpv4: Checksum,
+    dummy: (),
+}
+
 /// A description of device capabilities.
 ///
 /// Higher-level protocols may achieve higher throughput or lower latency if they consider
@@ -158,6 +204,9 @@ pub struct DeviceCapabilities {
     /// dynamically allocated.
     pub max_burst_size: Option<usize>,
 
+    /// Checksum capabilities for the current device
+    pub checksum: ChecksumCapabilities,
+
     /// Only present to prevent people from trying to initialize every field of DeviceLimits,
     /// which would not let us add new fields in the future.
     dummy: ()

+ 24 - 16
src/socket/raw.rs

@@ -2,6 +2,7 @@ use core::cmp::min;
 use managed::Managed;
 
 use {Error, Result};
+use phy::ChecksumCapabilities;
 use wire::{IpVersion, IpProtocol, Ipv4Repr, Ipv4Packet};
 use socket::{IpRepr, Socket};
 use storage::{Resettable, RingBuffer};
@@ -173,13 +174,14 @@ impl<'a, 'b> RawSocket<'a, 'b> {
         true
     }
 
-    pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<()> {
+    pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8], 
+                          checksum_caps: &ChecksumCapabilities) -> Result<()> {
         debug_assert!(self.accepts(ip_repr));
 
         let header_len = ip_repr.buffer_len();
         let total_len = header_len + payload.len();
         let packet_buf = self.rx_buffer.enqueue_one_with(|buf| buf.resize(total_len))?;
-        ip_repr.emit(&mut packet_buf.as_mut()[..header_len]);
+        ip_repr.emit(&mut packet_buf.as_mut()[..header_len], &checksum_caps);
         packet_buf.as_mut()[header_len..].copy_from_slice(payload);
         net_trace!("[{}]:{}:{}: receiving {} octets",
                    self.debug_id, self.ip_version, self.ip_protocol,
@@ -187,17 +189,23 @@ impl<'a, 'b> RawSocket<'a, 'b> {
         Ok(())
     }
 
-    pub(crate) fn dispatch<F>(&mut self, emit: F) -> Result<()>
+    pub(crate) fn dispatch<F>(&mut self, emit: F, checksum_caps: &ChecksumCapabilities) -> Result<()>
             where F: FnOnce((IpRepr, &[u8])) -> Result<()> {
-        fn prepare(protocol: IpProtocol, buffer: &mut [u8]) -> Result<(IpRepr, &[u8])> {
+        fn prepare<'a>(protocol: IpProtocol, buffer: &'a mut [u8], 
+                   checksum_caps: &ChecksumCapabilities) -> Result<(IpRepr, &'a [u8])> {
             match IpVersion::of_packet(buffer.as_ref())? {
                 IpVersion::Ipv4 => {
                     let mut packet = Ipv4Packet::new_checked(buffer.as_mut())?;
                     if packet.protocol() != protocol { return Err(Error::Unaddressable) }
-                    packet.fill_checksum();
+                    if checksum_caps.ipv4.tx() {
+                        packet.fill_checksum();
+                    } else {
+                        // make sure we get a consistently zeroed checksum, since implementations might rely on it
+                        packet.set_checksum(0);
+                    }
 
                     let packet = Ipv4Packet::new(&*packet.into_inner());
-                    let ipv4_repr = Ipv4Repr::parse(&packet)?;
+                    let ipv4_repr = Ipv4Repr::parse(&packet, checksum_caps)?;
                     Ok((IpRepr::Ipv4(ipv4_repr), packet.payload()))
                 }
                 IpVersion::Unspecified => unreachable!(),
@@ -209,7 +217,7 @@ impl<'a, 'b> RawSocket<'a, 'b> {
         let ip_protocol = self.ip_protocol;
         let ip_version  = self.ip_version;
         self.tx_buffer.dequeue_one_with(|packet_buf| {
-            match prepare(ip_protocol, packet_buf.as_mut()) {
+            match prepare(ip_protocol, packet_buf.as_mut(), &checksum_caps) {
                 Ok((ip_repr, raw_packet)) => {
                     net_trace!("[{}]:{}:{}: sending {} octets",
                                debug_id, ip_version, ip_protocol,
@@ -289,7 +297,7 @@ mod test {
         let mut socket = socket(buffer(0), buffer(1));
 
         assert!(socket.can_send());
-        assert_eq!(socket.dispatch(|_| unreachable!()),
+        assert_eq!(socket.dispatch(|_| unreachable!(), &ChecksumCapabilities::default()),
                    Err(Error::Exhausted));
 
         assert_eq!(socket.send_slice(&PACKET_BYTES[..]), Ok(()));
@@ -300,14 +308,14 @@ mod test {
             assert_eq!(ip_repr, HEADER_REPR);
             assert_eq!(ip_payload, &PACKET_PAYLOAD);
             Err(Error::Unaddressable)
-        }), Err(Error::Unaddressable));
+        }, &ChecksumCapabilities::default()), Err(Error::Unaddressable));
         assert!(!socket.can_send());
 
         assert_eq!(socket.dispatch(|(ip_repr, ip_payload)| {
             assert_eq!(ip_repr, HEADER_REPR);
             assert_eq!(ip_payload, &PACKET_PAYLOAD);
             Ok(())
-        }), Ok(()));
+        }, &ChecksumCapabilities::default()), Ok(()));
         assert!(socket.can_send());
     }
 
@@ -319,14 +327,14 @@ mod test {
         Ipv4Packet::new(&mut wrong_version).set_version(5);
 
         assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-        assert_eq!(socket.dispatch(|_| unreachable!()),
+        assert_eq!(socket.dispatch(|_| unreachable!(), &ChecksumCapabilities::default()),
                    Ok(()));
 
         let mut wrong_protocol = PACKET_BYTES.clone();
         Ipv4Packet::new(&mut wrong_protocol).set_protocol(IpProtocol::Tcp);
 
         assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-        assert_eq!(socket.dispatch(|_| unreachable!()),
+        assert_eq!(socket.dispatch(|_| unreachable!(), &ChecksumCapabilities::default()),
                    Ok(()));
     }
 
@@ -340,12 +348,12 @@ mod test {
 
         assert_eq!(socket.recv(), Err(Error::Exhausted));
         assert!(socket.accepts(&HEADER_REPR));
-        assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
+        assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD, &ChecksumCapabilities::default()),
                    Ok(()));
         assert!(socket.can_recv());
 
         assert!(socket.accepts(&HEADER_REPR));
-        assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
+        assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD, &ChecksumCapabilities::default()),
                    Err(Error::Exhausted));
         assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
         assert!(!socket.can_recv());
@@ -356,7 +364,7 @@ mod test {
         let mut socket = socket(buffer(1), buffer(0));
 
         assert!(socket.accepts(&HEADER_REPR));
-        assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
+        assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD, &ChecksumCapabilities::default()),
                    Ok(()));
 
         let mut slice = [0; 4];
@@ -372,7 +380,7 @@ mod test {
         buffer[..PACKET_BYTES.len()].copy_from_slice(&PACKET_BYTES[..]);
 
         assert!(socket.accepts(&HEADER_REPR));
-        assert_eq!(socket.process(&HEADER_REPR, &buffer),
+        assert_eq!(socket.process(&HEADER_REPR, &buffer, &ChecksumCapabilities::default()),
                    Err(Error::Truncated));
     }
 

+ 20 - 7
src/wire/icmpv4.rs

@@ -2,6 +2,7 @@ use core::{cmp, fmt};
 use byteorder::{ByteOrder, NetworkEndian};
 
 use {Error, Result};
+use phy::ChecksumCapabilities;
 use super::ip::checksum;
 use super::{Ipv4Packet, Ipv4Repr};
 
@@ -384,7 +385,11 @@ pub enum Repr<'a> {
 impl<'a> Repr<'a> {
     /// Parse an Internet Control Message Protocol version 4 packet and return
     /// a high-level representation.
-    pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&'a T>) -> Result<Repr<'a>> {
+    pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&'a T>, checksum_caps: &ChecksumCapabilities) -> Result<Repr<'a>> {
+        if checksum_caps.icmpv4.rx() && !packet.verify_checksum() { 
+            return Err(Error::Checksum) 
+        }
+
         match (packet.msg_type(), packet.msg_code()) {
             (Message::EchoRequest, 0) => {
                 Ok(Repr::EchoRequest {
@@ -441,7 +446,9 @@ impl<'a> Repr<'a> {
 
     /// Emit a high-level representation into an Internet Control Message Protocol version 4
     /// packet.
-    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, packet: &mut Packet<&mut T>) {
+    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, 
+                                                       packet: &mut Packet<&mut T>, 
+                                                       checksum_caps: &ChecksumCapabilities) {
         packet.set_msg_code(0);
         match self {
             &Repr::EchoRequest { ident, seq_no, data } => {
@@ -467,20 +474,26 @@ impl<'a> Repr<'a> {
                 packet.set_msg_code(reason.into());
 
                 let mut ip_packet = Ipv4Packet::new(packet.data_mut());
-                header.emit(&mut ip_packet);
+                header.emit(&mut ip_packet, checksum_caps);
                 let payload = &mut ip_packet.into_inner()[header.buffer_len()..];
                 payload.copy_from_slice(&data[..])
             }
 
             &Repr::__Nonexhaustive => unreachable!()
         }
-        packet.fill_checksum()
+
+        if checksum_caps.icmpv4.tx() {
+            packet.fill_checksum()
+        } else {
+            // make sure we get a consistently zeroed checksum, since implementations might rely on it
+            packet.set_checksum(0);
+        }
     }
 }
 
 impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match Repr::parse(self) {
+        match Repr::parse(self, &ChecksumCapabilities::default()) {
             Ok(repr) => write!(f, "{}", repr),
             Err(err) => {
                 write!(f, "ICMPv4 ({})", err)?;
@@ -580,7 +593,7 @@ mod test {
     #[test]
     fn test_echo_parse() {
         let packet = Packet::new(&ECHO_PACKET_BYTES[..]);
-        let repr = Repr::parse(&packet).unwrap();
+        let repr = Repr::parse(&packet, &ChecksumCapabilities::default()).unwrap();
         assert_eq!(repr, echo_packet_repr());
     }
 
@@ -589,7 +602,7 @@ mod test {
         let repr = echo_packet_repr();
         let mut bytes = vec![0xa5; repr.buffer_len()];
         let mut packet = Packet::new(&mut bytes);
-        repr.emit(&mut packet);
+        repr.emit(&mut packet, &ChecksumCapabilities::default());
         assert_eq!(&packet.into_inner()[..], &ECHO_PACKET_BYTES[..]);
     }
 }

+ 3 - 2
src/wire/ip.rs

@@ -1,6 +1,7 @@
 use core::fmt;
 
 use {Error, Result};
+use phy::ChecksumCapabilities;
 use super::{Ipv4Address, Ipv4Packet, Ipv4Repr};
 
 /// Internet protocol version.
@@ -338,12 +339,12 @@ impl IpRepr {
     ///
     /// # Panics
     /// This function panics if invoked on an unspecified representation.
-    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, buffer: T) {
+    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, buffer: T, checksum_caps: &ChecksumCapabilities) {
         match self {
             &IpRepr::Unspecified { .. } =>
                 panic!("unspecified IP representation"),
             &IpRepr::Ipv4(repr) =>
-                repr.emit(&mut Ipv4Packet::new(buffer)),
+                repr.emit(&mut Ipv4Packet::new(buffer), &checksum_caps),
             &IpRepr::__Nonexhaustive =>
                 unreachable!()
         }

+ 23 - 11
src/wire/ipv4.rs

@@ -2,6 +2,7 @@ use core::fmt;
 use byteorder::{ByteOrder, NetworkEndian};
 
 use {Error, Result};
+use phy::ChecksumCapabilities;
 use super::ip::checksum;
 use super::IpAddress;
 
@@ -401,11 +402,14 @@ pub struct Repr {
 
 impl Repr {
     /// Parse an Internet Protocol version 4 packet and return a high-level representation.
-    pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&T>) -> Result<Repr> {
+    pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&T>, 
+                                          checksum_caps: &ChecksumCapabilities) -> Result<Repr> {
         // Version 4 is expected.
         if packet.version() != 4 { return Err(Error::Malformed) }
         // Valid checksum is expected.
-        if !packet.verify_checksum() { return Err(Error::Checksum) }
+        if checksum_caps.ipv4.rx() {
+            if !packet.verify_checksum() { return Err(Error::Checksum) }
+        }
         // We do not support fragmentation.
         if packet.more_frags() || packet.frag_offset() != 0 { return Err(Error::Fragmented) }
         // Total length may not be less than header length.
@@ -432,7 +436,7 @@ impl Repr {
     }
 
     /// Emit a high-level representation into an Internet Protocol version 4 packet.
-    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) {
+    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>, checksum_caps: &ChecksumCapabilities) {
         packet.set_version(4);
         packet.set_header_len(field::DST_ADDR.end as u8);
         packet.set_dscp(0);
@@ -448,13 +452,19 @@ impl Repr {
         packet.set_protocol(self.protocol);
         packet.set_src_addr(self.src_addr);
         packet.set_dst_addr(self.dst_addr);
-        packet.fill_checksum();
+
+        if checksum_caps.ipv4.tx() {
+            packet.fill_checksum();
+        } else {
+            // make sure we get a consistently zeroed checksum, since implementations might rely on it
+            packet.set_checksum(0);
+        }
     }
 }
 
 impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match Repr::parse(self) {
+        match Repr::parse(self, &ChecksumCapabilities::default()) {
             Ok(repr) => write!(f, "{}", repr),
             Err(err) => {
                 write!(f, "IPv4 ({})", err)?;
@@ -507,7 +517,7 @@ impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
             Err(err) => return write!(f, "{}({})\n", indent, err),
             Ok(ip_packet) => {
                 write!(f, "{}{}\n", indent, ip_packet)?;
-                match Repr::parse(&ip_packet) {
+                match Repr::parse(&ip_packet, &ChecksumCapabilities::default()) {
                     Err(_) => return Ok(()),
                     Ok(ip_repr) => (ip_repr, &ip_packet.payload()[..ip_repr.payload_len])
                 }
@@ -524,7 +534,8 @@ impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
                     Ok(udp_packet) => {
                         match super::UdpRepr::parse(&udp_packet,
                                                     &IpAddress::from(ip_repr.src_addr),
-                                                    &IpAddress::from(ip_repr.dst_addr)) {
+                                                    &IpAddress::from(ip_repr.dst_addr),
+                                                    &ChecksumCapabilities::default()) {
                             Err(err) => write!(f, "{}{} ({})\n", indent, udp_packet, err),
                             Ok(udp_repr) => write!(f, "{}{}\n", indent, udp_repr)
                         }
@@ -537,7 +548,8 @@ impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
                     Ok(tcp_packet) => {
                         match super::TcpRepr::parse(&tcp_packet,
                                                     &IpAddress::from(ip_repr.src_addr),
-                                                    &IpAddress::from(ip_repr.dst_addr)) {
+                                                    &IpAddress::from(ip_repr.dst_addr),
+                                                    &ChecksumCapabilities::default()) {
                             Err(err) => write!(f, "{}{} ({})\n", indent, tcp_packet, err),
                             Ok(tcp_repr) => write!(f, "{}{}\n", indent, tcp_repr)
                         }
@@ -647,7 +659,7 @@ mod test {
     #[test]
     fn test_parse() {
         let packet = Packet::new(&REPR_PACKET_BYTES[..]);
-        let repr = Repr::parse(&packet).unwrap();
+        let repr = Repr::parse(&packet, &ChecksumCapabilities::default()).unwrap();
         assert_eq!(repr, packet_repr());
     }
 
@@ -659,7 +671,7 @@ mod test {
         packet.set_total_len(10);
         packet.fill_checksum();
         let packet = Packet::new(&*packet.into_inner());
-        assert_eq!(Repr::parse(&packet), Err(Error::Malformed));
+        assert_eq!(Repr::parse(&packet, &ChecksumCapabilities::default()), Err(Error::Malformed));
     }
 
     #[test]
@@ -667,7 +679,7 @@ mod test {
         let repr = packet_repr();
         let mut bytes = vec![0xa5; repr.buffer_len() + REPR_PAYLOAD_BYTES.len()];
         let mut packet = Packet::new(&mut bytes);
-        repr.emit(&mut packet);
+        repr.emit(&mut packet, &ChecksumCapabilities::default());
         packet.payload_mut().copy_from_slice(&REPR_PAYLOAD_BYTES);
         assert_eq!(&packet.into_inner()[..], &REPR_PACKET_BYTES[..]);
     }

+ 3 - 2
src/wire/mod.rs

@@ -46,6 +46,7 @@
 //!
 /*!
 ```rust
+use smoltcp::phy::ChecksumCapabilities;
 use smoltcp::wire::*;
 let repr = Ipv4Repr {
     src_addr:    Ipv4Address::new(10, 0, 0, 1),
@@ -56,12 +57,12 @@ let repr = Ipv4Repr {
 let mut buffer = vec![0; repr.buffer_len() + repr.payload_len];
 { // emission
     let mut packet = Ipv4Packet::new(&mut buffer);
-    repr.emit(&mut packet);
+    repr.emit(&mut packet, &ChecksumCapabilities::default());
 }
 { // parsing
     let packet = Ipv4Packet::new_checked(&buffer)
                             .expect("truncated packet");
-    let parsed = Ipv4Repr::parse(&packet)
+    let parsed = Ipv4Repr::parse(&packet, &ChecksumCapabilities::default())
                           .expect("malformed packet");
     assert_eq!(repr, parsed);
 }

+ 19 - 6
src/wire/tcp.rs

@@ -2,6 +2,7 @@ use core::{i32, ops, cmp, fmt};
 use byteorder::{ByteOrder, NetworkEndian};
 
 use {Error, Result};
+use phy::ChecksumCapabilities;
 use super::{IpProtocol, IpAddress};
 use super::ip::checksum;
 
@@ -644,13 +645,17 @@ impl<'a> Repr<'a> {
     /// Parse a Transmission Control Protocol packet and return a high-level representation.
     pub fn parse<T: ?Sized>(packet: &Packet<&'a T>,
                             src_addr: &IpAddress,
-                            dst_addr: &IpAddress) -> Result<Repr<'a>>
+                            dst_addr: &IpAddress,
+                            checksum_caps: &ChecksumCapabilities) -> Result<Repr<'a>>
             where T: AsRef<[u8]> {
         // Source and destination ports must be present.
         if packet.src_port() == 0 { return Err(Error::Malformed) }
         if packet.dst_port() == 0 { return Err(Error::Malformed) }
+        
         // Valid checksum is expected...
-        if !packet.verify_checksum(src_addr, dst_addr) { return Err(Error::Checksum) }
+        if checksum_caps.tcpv4.rx() && !packet.verify_checksum(src_addr, dst_addr) { 
+            return Err(Error::Checksum) 
+        }
 
         let control =
             match (packet.syn(), packet.fin(), packet.rst(), packet.psh()) {
@@ -713,7 +718,9 @@ impl<'a> Repr<'a> {
 
     /// Emit a high-level representation into a Transmission Control Protocol packet.
     pub fn emit<T>(&self, packet: &mut Packet<&mut T>,
-                   src_addr: &IpAddress, dst_addr: &IpAddress)
+                          src_addr: &IpAddress, 
+                          dst_addr: &IpAddress,
+                          checksum_caps: &ChecksumCapabilities)
             where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized {
         packet.set_src_port(self.src_port);
         packet.set_dst_port(self.dst_port);
@@ -741,7 +748,13 @@ impl<'a> Repr<'a> {
         }
         packet.set_urgent_at(0);
         packet.payload_mut().copy_from_slice(self.payload);
-        packet.fill_checksum(src_addr, dst_addr)
+        
+        if checksum_caps.tcpv4.tx() {
+            packet.fill_checksum(src_addr, dst_addr)
+        } else {
+            // make sure we get a consistently zeroed checksum, since implementations might rely on it
+            packet.set_checksum(0);
+        }
     }
 
     /// Return the length of the segment, in terms of sequence space.
@@ -948,7 +961,7 @@ mod test {
     #[test]
     fn test_parse() {
         let packet = Packet::new(&SYN_PACKET_BYTES[..]);
-        let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into()).unwrap();
+        let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(), &ChecksumCapabilities::default()).unwrap();
         assert_eq!(repr, packet_repr());
     }
 
@@ -957,7 +970,7 @@ mod test {
         let repr = packet_repr();
         let mut bytes = vec![0xa5; repr.buffer_len()];
         let mut packet = Packet::new(&mut bytes);
-        repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into());
+        repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into(), &ChecksumCapabilities::default());
         assert_eq!(&packet.into_inner()[..], &SYN_PACKET_BYTES[..]);
     }
 

+ 16 - 6
src/wire/udp.rs

@@ -2,6 +2,7 @@ use core::fmt;
 use byteorder::{ByteOrder, NetworkEndian};
 
 use {Error, Result};
+use phy::ChecksumCapabilities;
 use super::{IpProtocol, IpAddress};
 use super::ip::checksum;
 
@@ -203,12 +204,14 @@ impl<'a> Repr<'a> {
     /// Parse an User Datagram Protocol packet and return a high-level representation.
     pub fn parse<T: ?Sized>(packet: &Packet<&'a T>,
                             src_addr: &IpAddress,
-                            dst_addr: &IpAddress) -> Result<Repr<'a>>
+                            dst_addr: &IpAddress,
+                            checksum_caps: &ChecksumCapabilities) -> Result<Repr<'a>>
             where T: AsRef<[u8]> {
         // Destination port cannot be omitted (but source port can be).
         if packet.dst_port() == 0 { return Err(Error::Malformed) }
+
         // Valid checksum is expected...
-        if !packet.verify_checksum(src_addr, dst_addr) {
+        if checksum_caps.udpv4.rx() && !packet.verify_checksum(src_addr, dst_addr) {
             match (src_addr, dst_addr) {
                 (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_))
                         if packet.checksum() != 0 => {
@@ -236,13 +239,20 @@ impl<'a> Repr<'a> {
     /// Emit a high-level representation into an User Datagram Protocol packet.
     pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
                            src_addr: &IpAddress,
-                           dst_addr: &IpAddress)
+                           dst_addr: &IpAddress,
+                           checksum_caps: &ChecksumCapabilities)
             where T: AsRef<[u8]> + AsMut<[u8]> {
         packet.set_src_port(self.src_port);
         packet.set_dst_port(self.dst_port);
         packet.set_len((field::CHECKSUM.end + self.payload.len()) as u16);
         packet.payload_mut().copy_from_slice(self.payload);
-        packet.fill_checksum(src_addr, dst_addr)
+
+        if checksum_caps.udpv4.tx() {
+            packet.fill_checksum(src_addr, dst_addr)
+        } else {
+            // make sure we get a consistently zeroed checksum, since implementations might rely on it
+            packet.set_checksum(0);
+        }
     }
 }
 
@@ -343,7 +353,7 @@ mod test {
     #[test]
     fn test_parse() {
         let packet = Packet::new(&PACKET_BYTES[..]);
-        let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into()).unwrap();
+        let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(), &ChecksumCapabilities::default()).unwrap();
         assert_eq!(repr, packet_repr());
     }
 
@@ -352,7 +362,7 @@ mod test {
         let repr = packet_repr();
         let mut bytes = vec![0xa5; repr.buffer_len()];
         let mut packet = Packet::new(&mut bytes);
-        repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into());
+        repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into(), &ChecksumCapabilities::default());
         assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
     }
 }