Sfoglia il codice sorgente

Get rid of IpPayload and indirection in Socket::dispatch.

This was just completely pointless, and only served to obfuscate
the data path and make testing harder.
whitequark 7 anni fa
parent
commit
ab0eccd213
5 ha cambiato i file con 88 aggiunte e 160 eliminazioni
  1. 38 23
      src/iface/ethernet.rs
  2. 0 18
      src/socket/mod.rs
  3. 23 47
      src/socket/raw.rs
  4. 13 36
      src/socket/tcp.rs
  5. 14 36
      src/socket/udp.rs

+ 38 - 23
src/iface/ethernet.rs

@@ -7,8 +7,8 @@ use wire::{ArpPacket, ArpRepr, ArpOperation};
 use wire::{Ipv4Packet, Ipv4Repr};
 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
 use wire::{IpAddress, IpProtocol, IpRepr};
-use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::{Socket, SocketSet, RawSocket, TcpSocket, UdpSocket, AsSocket, IpPayload};
+use wire::{UdpPacket, UdpRepr, TcpPacket, TcpRepr, TcpControl};
+use socket::{Socket, SocketSet, RawSocket, TcpSocket, UdpSocket, AsSocket};
 use super::ArpCache;
 
 /// An Ethernet network interface.
@@ -27,8 +27,9 @@ enum Response<'a> {
     Nop,
     Arp(ArpRepr),
     Icmpv4(Ipv4Repr, Icmpv4Repr<'a>),
-    Tcp((IpRepr, TcpRepr<'a>)),
-    Payload(IpRepr, &'a IpPayload)
+    Raw((IpRepr, &'a [u8])),
+    Udp((IpRepr, UdpRepr<'a>)),
+    Tcp((IpRepr, TcpRepr<'a>))
 }
 
 impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
@@ -341,9 +342,18 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
         let mut nothing_to_transmit = true;
         for socket in sockets.iter_mut() {
-            let result = socket.dispatch(timestamp, &limits, |repr, payload| {
-                self.dispatch_response(timestamp, Response::Payload(repr.clone(), payload))
-            });
+            let result = match socket {
+                &mut Socket::Raw(ref mut socket) =>
+                    socket.dispatch(timestamp, &limits, |response|
+                        self.dispatch_response(timestamp, Response::Raw(response))),
+                &mut Socket::Udp(ref mut socket) =>
+                    socket.dispatch(timestamp, &limits, |response|
+                        self.dispatch_response(timestamp, Response::Udp(response))),
+                &mut Socket::Tcp(ref mut socket) =>
+                    socket.dispatch(timestamp, &limits, |response|
+                        self.dispatch_response(timestamp, Response::Tcp(response))),
+                &mut Socket::__Nonexhaustive => unreachable!()
+            };
 
             match result {
                 Ok(()) => {
@@ -373,12 +383,12 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                 Ok(())
             });
 
-            (Ip, $ip_repr:expr, |$payload:ident| $code:stmt) => ({
-                let ip_repr = $ip_repr.lower(&self.protocol_addrs)?;
+            (Ip, $ip_repr_unspec:expr, |$ip_repr:ident, $payload:ident| $code:stmt) => ({
+                let $ip_repr = $ip_repr_unspec.lower(&self.protocol_addrs)?;
 
-                match self.arp_cache.lookup(&ip_repr.dst_addr()) {
+                match self.arp_cache.lookup(&$ip_repr.dst_addr()) {
                     None => {
-                        match (ip_repr.src_addr(), ip_repr.dst_addr()) {
+                        match ($ip_repr.src_addr(), $ip_repr.dst_addr()) {
                             (IpAddress::Ipv4(src_addr), IpAddress::Ipv4(dst_addr)) => {
                                 net_debug!("address {} not in ARP cache, sending request",
                                            dst_addr);
@@ -402,16 +412,16 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                         }
                     },
                     Some(dst_hardware_addr) => {
-                        emit_packet!(Ethernet, ip_repr.total_len(), |frame| {
+                        emit_packet!(Ethernet, $ip_repr.total_len(), |frame| {
                             frame.set_dst_addr(dst_hardware_addr);
-                            match ip_repr {
+                            match $ip_repr {
                                 IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4),
                                 _ => unreachable!()
                             }
 
-                            ip_repr.emit(frame.payload_mut());
+                            $ip_repr.emit(frame.payload_mut());
 
-                            let $payload = &mut frame.payload_mut()[ip_repr.buffer_len()..];
+                            let $payload = &mut frame.payload_mut()[$ip_repr.buffer_len()..];
                             $code
                         })
                     }
@@ -436,20 +446,25 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                 })
             },
             Response::Icmpv4(ipv4_repr, icmpv4_repr) => {
-                emit_packet!(Ip, IpRepr::Ipv4(ipv4_repr), |payload| {
+                emit_packet!(Ip, IpRepr::Ipv4(ipv4_repr), |ip_repr, payload| {
                     icmpv4_repr.emit(&mut Icmpv4Packet::new(payload));
                 })
             }
-            Response::Tcp((ip_repr, tcp_repr)) => {
-                emit_packet!(Ip, ip_repr, |payload| {
-                    tcp_repr.emit(&mut TcpPacket::new(payload),
+            Response::Raw((ip_repr, raw_packet)) => {
+                emit_packet!(Ip, ip_repr, |ip_repr, payload| {
+                    payload.copy_from_slice(raw_packet);
+                })
+            }
+            Response::Udp((ip_repr, udp_repr)) => {
+                emit_packet!(Ip, ip_repr, |ip_repr, payload| {
+                    udp_repr.emit(&mut UdpPacket::new(payload),
                                   &ip_repr.src_addr(), &ip_repr.dst_addr());
                 })
             }
-            Response::Payload(ip_repr, ip_payload) => {
-                let ip_repr = ip_repr.lower(&self.protocol_addrs)?;
-                emit_packet!(Ip, ip_repr, |payload| {
-                    ip_payload.emit(&ip_repr, payload);
+            Response::Tcp((ip_repr, tcp_repr)) => {
+                emit_packet!(Ip, ip_repr, |ip_repr, payload| {
+                    tcp_repr.emit(&mut TcpPacket::new(payload),
+                                  &ip_repr.src_addr(), &ip_repr.dst_addr());
                 })
             }
             Response::Nop => Ok(())

+ 0 - 18
src/socket/mod.rs

@@ -80,24 +80,6 @@ impl<'a, 'b> Socket<'a, 'b> {
     pub fn set_debug_id(&mut self, id: usize) {
         dispatch_socket!(self, |socket [mut]| socket.set_debug_id(id))
     }
-
-    pub(crate) fn dispatch<F, R>(&mut self, timestamp: u64, limits: &DeviceLimits,
-                                 emit: F) -> Result<R>
-            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
-        dispatch_socket!(self, |socket [mut]| socket.dispatch(timestamp, limits, emit))
-    }
-}
-
-/// An IP-encapsulated packet representation.
-///
-/// This trait abstracts the various types of packets layered under the IP protocol,
-/// and serves as an accessory to [trait Socket](trait.Socket.html).
-pub trait IpPayload {
-    /// Return the length of the buffer required to serialize this high-level representation.
-    fn buffer_len(&self) -> usize;
-
-    /// Emit this high-level representation into a sequence of octets.
-    fn emit(&self, ip_repr: &IpRepr, payload: &mut [u8]);
 }
 
 /// A conversion trait for network sockets.

+ 23 - 47
src/socket/raw.rs

@@ -4,7 +4,7 @@ use managed::Managed;
 use {Error, Result};
 use phy::DeviceLimits;
 use wire::{IpVersion, IpProtocol, Ipv4Repr, Ipv4Packet};
-use socket::{IpRepr, IpPayload, Socket};
+use socket::{IpRepr, Socket};
 use storage::{Resettable, RingBuffer};
 
 /// A buffered raw IP packet.
@@ -183,10 +183,10 @@ impl<'a, 'b> RawSocket<'a, 'b> {
         Ok(())
     }
 
-    pub(crate) fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
-                                 emit: F) -> Result<R>
-            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
-        fn prepare(protocol: IpProtocol, buffer: &mut [u8]) -> Result<(IpRepr, RawRepr)> {
+    pub(crate) fn dispatch<F>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
+                              emit: F) -> Result<()>
+            where F: FnOnce((IpRepr, &[u8])) -> Result<()> {
+        fn prepare(protocol: IpProtocol, buffer: &mut [u8]) -> Result<(IpRepr, &[u8])> {
             match IpVersion::of_packet(buffer.as_ref())? {
                 IpVersion::Ipv4 => {
                     let mut packet = Ipv4Packet::new_checked(buffer.as_mut())?;
@@ -195,8 +195,7 @@ impl<'a, 'b> RawSocket<'a, 'b> {
 
                     let packet = Ipv4Packet::new(&*packet.into_inner());
                     let ipv4_repr = Ipv4Repr::parse(&packet)?;
-                    let raw_repr = RawRepr(packet.payload());
-                    Ok((IpRepr::Ipv4(ipv4_repr), raw_repr))
+                    Ok((IpRepr::Ipv4(ipv4_repr), packet.payload()))
                 }
                 IpVersion::Unspecified => unreachable!(),
                 IpVersion::__Nonexhaustive => unreachable!()
@@ -205,14 +204,14 @@ impl<'a, 'b> RawSocket<'a, 'b> {
 
         let mut packet_buf = self.tx_buffer.dequeue()?;
         match prepare(self.ip_protocol, packet_buf.as_mut()) {
-            Ok((ip_repr, raw_repr)) => {
+            Ok((ip_repr, raw_packet)) => {
                 net_trace!("[{}]:{}:{}: sending {} octets",
                            self.debug_id, self.ip_version, self.ip_protocol,
-                           ip_repr.buffer_len() + raw_repr.buffer_len());
-                emit(&ip_repr, &raw_repr)
+                           ip_repr.buffer_len() + raw_packet.len());
+                emit((ip_repr, raw_packet))
             }
             Err(error) => {
-                net_trace!("[{}]:{}:{}: dropping outgoing packet ({})",
+                net_debug!("[{}]:{}:{}: dropping outgoing packet ({})",
                            self.debug_id, self.ip_version, self.ip_protocol,
                            error);
                 // This case is a bit special because in every other socket, no matter what data
@@ -225,18 +224,6 @@ impl<'a, 'b> RawSocket<'a, 'b> {
     }
 }
 
-struct RawRepr<'a>(&'a [u8]);
-
-impl<'a> IpPayload for RawRepr<'a> {
-    fn buffer_len(&self) -> usize {
-        self.0.len()
-    }
-
-    fn emit(&self, _repr: &IpRepr, payload: &mut [u8]) {
-        payload.copy_from_slice(self.0);
-    }
-}
-
 #[cfg(test)]
 mod test {
     use wire::{Ipv4Address, IpRepr, Ipv4Repr};
@@ -291,32 +278,23 @@ mod test {
         let mut socket = socket(buffer(0), buffer(1));
 
         assert!(socket.can_send());
-        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
-            unreachable!()
-        }), Err(Error::Exhausted) as Result<()>);
+        assert_eq!(socket.dispatch(0, &limits, |_| unreachable!()),
+                   Err(Error::Exhausted));
 
         assert_eq!(socket.send_slice(&PACKET_BYTES[..]), Ok(()));
         assert_eq!(socket.send_slice(b""), Err(Error::Exhausted));
         assert!(!socket.can_send());
 
-        macro_rules! assert_payload_eq {
-            ($ip_repr:expr, $ip_payload:expr, $expected:expr) => {{
-                let mut buffer = vec![0; $ip_payload.buffer_len()];
-                $ip_payload.emit(&$ip_repr, &mut buffer);
-                assert_eq!(&buffer[..], &$expected[$ip_repr.buffer_len()..]);
-            }}
-        }
-
-        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
-            assert_eq!(ip_repr, &HEADER_REPR);
-            assert_payload_eq!(ip_repr, ip_payload, PACKET_BYTES);
+        assert_eq!(socket.dispatch(0, &limits, |(ip_repr, ip_payload)| {
+            assert_eq!(ip_repr, HEADER_REPR);
+            assert_eq!(ip_payload, &PACKET_PAYLOAD);
             Err(Error::Unaddressable)
-        }), Err(Error::Unaddressable) as Result<()>);
+        }), Err(Error::Unaddressable));
         /*assert!(!socket.can_send());*/
 
-        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
-            assert_eq!(ip_repr, &HEADER_REPR);
-            assert_payload_eq!(ip_repr, ip_payload, PACKET_BYTES);
+        assert_eq!(socket.dispatch(0, &limits, |(ip_repr, ip_payload)| {
+            assert_eq!(ip_repr, HEADER_REPR);
+            assert_eq!(ip_payload, &PACKET_PAYLOAD);
             Ok(())
         }), /*Ok(())*/ Err(Error::Exhausted));
         assert!(socket.can_send());
@@ -332,17 +310,15 @@ mod test {
         Ipv4Packet::new(&mut wrong_version).set_version(5);
 
         assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
-            unreachable!()
-        }), Err(Error::Rejected) as Result<()>);
+        assert_eq!(socket.dispatch(0, &limits, |_| unreachable!()),
+                   Err(Error::Rejected));
 
         let mut wrong_protocol = PACKET_BYTES.clone();
         Ipv4Packet::new(&mut wrong_protocol).set_protocol(IpProtocol::Tcp);
 
         assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
-            unreachable!()
-        }), Err(Error::Rejected) as Result<()>);
+        assert_eq!(socket.dispatch(0, &limits, |_| unreachable!()),
+                   Err(Error::Rejected));
     }
 
     #[test]

+ 13 - 36
src/socket/tcp.rs

@@ -8,7 +8,7 @@ use {Error, Result};
 use phy::DeviceLimits;
 use wire::{IpProtocol, IpAddress, IpEndpoint};
 use wire::{TcpSeqNumber, TcpPacket, TcpRepr, TcpControl};
-use socket::{Socket, IpRepr, IpPayload};
+use socket::{Socket, IpRepr};
 
 /// A TCP stream ring buffer.
 #[derive(Debug)]
@@ -1090,9 +1090,9 @@ impl<'a> TcpSocket<'a> {
         }
     }
 
-    pub(crate) fn dispatch<F, R>(&mut self, timestamp: u64, limits: &DeviceLimits,
-                                 emit: F) -> Result<R>
-            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
+    pub(crate) fn dispatch<F>(&mut self, timestamp: u64, limits: &DeviceLimits,
+                              emit: F) -> Result<()>
+            where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> {
         if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) }
 
         if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) {
@@ -1245,7 +1245,7 @@ impl<'a> TcpSocket<'a> {
             }
         }
 
-        let result = emit(&ip_repr, &repr)?;
+        emit((ip_repr, repr))?;
 
         // We've sent a packet successfully, so we can update the internal state now.
         self.remote_next_seq = repr.seq_number + repr.segment_len();
@@ -1265,7 +1265,7 @@ impl<'a> TcpSocket<'a> {
             self.remote_endpoint = IpEndpoint::default();
         }
 
-        Ok(result)
+        Ok(())
     }
 }
 
@@ -1280,17 +1280,6 @@ impl<'a> fmt::Write for TcpSocket<'a> {
     }
 }
 
-impl<'a> IpPayload for TcpRepr<'a> {
-    fn buffer_len(&self) -> usize {
-        self.buffer_len()
-    }
-
-    fn emit(&self, ip_repr: &IpRepr, payload: &mut [u8]) {
-        let mut packet = TcpPacket::new(payload);
-        self.emit(&mut packet, &ip_repr.src_addr(), &ip_repr.dst_addr())
-    }
-}
-
 #[cfg(test)]
 mod test {
     use wire::{IpAddress, Ipv4Address};
@@ -1389,24 +1378,18 @@ mod test {
 
     fn recv<F>(socket: &mut TcpSocket, timestamp: u64, mut f: F)
             where F: FnMut(Result<TcpRepr>) {
-        let mut buffer = vec![];
         let mut limits = DeviceLimits::default();
         limits.max_transmission_unit = 1520;
-        let result = socket.dispatch(timestamp, &limits, |ip_repr, payload| {
+        let result = socket.dispatch(timestamp, &limits, |(ip_repr, tcp_repr)| {
             let ip_repr = ip_repr.lower(&[LOCAL_END.addr.into()]).unwrap();
 
             assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
             assert_eq!(ip_repr.src_addr(), LOCAL_IP);
             assert_eq!(ip_repr.dst_addr(), REMOTE_IP);
 
-            buffer.resize(payload.buffer_len(), 0);
-            payload.emit(&ip_repr, &mut buffer[..]);
-            let packet = TcpPacket::new(&buffer[..]);
-            let repr = TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
-            trace!("recv: {}", repr);
-            Ok(f(Ok(repr)))
+            trace!("recv: {}", tcp_repr);
+            Ok(f(Ok(tcp_repr)))
         });
-        // Appease borrow checker.
         match result {
             Ok(()) => (),
             Err(e) => f(Err(e))
@@ -2869,21 +2852,15 @@ mod test {
 
         limits.max_burst_size = None;
         s.send_slice(b"abcdef").unwrap();
-        s.dispatch(0, &limits, |ip_repr, payload| {
-            let mut buffer = vec![0; payload.buffer_len()];
-            payload.emit(&ip_repr, &mut buffer[..]);
-            let packet = TcpPacket::new(&buffer[..]);
-            assert_eq!(packet.window_len(), 32767);
+        s.dispatch(0, &limits, |(ip_repr, tcp_repr)| {
+            assert_eq!(tcp_repr.window_len, 32767);
             Ok(())
         }).unwrap();
 
         limits.max_burst_size = Some(4);
         s.send_slice(b"abcdef").unwrap();
-        s.dispatch(0, &limits, |ip_repr, payload| {
-            let mut buffer = vec![0; payload.buffer_len()];
-            payload.emit(&ip_repr, &mut buffer[..]);
-            let packet = TcpPacket::new(&buffer[..]);
-            assert_eq!(packet.window_len(), 5920);
+        s.dispatch(0, &limits, |(ip_repr, tcp_repr)| {
+            assert_eq!(tcp_repr.window_len, 5920);
             Ok(())
         }).unwrap();
     }

+ 14 - 36
src/socket/udp.rs

@@ -5,7 +5,7 @@ use {Error, Result};
 use phy::DeviceLimits;
 use wire::{IpProtocol, IpEndpoint};
 use wire::{UdpPacket, UdpRepr};
-use socket::{Socket, IpRepr, IpPayload};
+use socket::{Socket, IpRepr};
 use storage::{Resettable, RingBuffer};
 
 /// A buffered UDP packet.
@@ -202,9 +202,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         Ok(())
     }
 
-    pub(crate) fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
-                                 emit: F) -> Result<R>
-            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
+    pub(crate) fn dispatch<F>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
+                              emit: F) -> Result<()>
+            where F: FnOnce((IpRepr, UdpRepr)) -> Result<()> {
         let packet_buf = self.tx_buffer.dequeue()?;
         net_trace!("[{}]{}:{}: sending {} octets",
                    self.debug_id, self.endpoint,
@@ -221,18 +221,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
             protocol:    IpProtocol::Udp,
             payload_len: repr.buffer_len()
         };
-        emit(&ip_repr, &repr)
-    }
-}
-
-impl<'a> IpPayload for UdpRepr<'a> {
-    fn buffer_len(&self) -> usize {
-        self.buffer_len()
-    }
-
-    fn emit(&self, repr: &IpRepr, payload: &mut [u8]) {
-        let mut packet = UdpPacket::new(payload);
-        self.emit(&mut packet, &repr.src_addr(), &repr.dst_addr())
+        emit((ip_repr, repr))
     }
 }
 
@@ -320,34 +309,23 @@ mod test {
         assert_eq!(socket.bind(LOCAL_END), Ok(()));
 
         assert!(socket.can_send());
-        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
-            unreachable!()
-        }), Err(Error::Exhausted) as Result<()>);
+        assert_eq!(socket.dispatch(0, &limits, |_| unreachable!()),
+                   Err(Error::Exhausted));
 
         assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
         assert_eq!(socket.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted));
         assert!(!socket.can_send());
 
-        macro_rules! assert_payload_eq {
-            ($ip_repr:expr, $ip_payload:expr, $expected:expr) => {{
-                let mut buffer = vec![0; $ip_payload.buffer_len()];
-                $ip_payload.emit($ip_repr, &mut buffer);
-                let udp_packet = UdpPacket::new_checked(&buffer).unwrap();
-                let udp_repr = UdpRepr::parse(&udp_packet, &LOCAL_IP, &REMOTE_IP).unwrap();
-                assert_eq!(&udp_repr, $expected)
-            }}
-        }
-
-        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
-            assert_eq!(ip_repr, &LOCAL_IP_REPR);
-            assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
+        assert_eq!(socket.dispatch(0, &limits, |(ip_repr, udp_repr)| {
+            assert_eq!(ip_repr, LOCAL_IP_REPR);
+            assert_eq!(udp_repr, LOCAL_UDP_REPR);
             Err(Error::Unaddressable)
-        }), Err(Error::Unaddressable) as Result<()>);
+        }), Err(Error::Unaddressable));
         /*assert!(!socket.can_send());*/
 
-        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
-            assert_eq!(ip_repr, &LOCAL_IP_REPR);
-            assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
+        assert_eq!(socket.dispatch(0, &limits, |(ip_repr, udp_repr)| {
+            assert_eq!(ip_repr, LOCAL_IP_REPR);
+            assert_eq!(udp_repr, LOCAL_UDP_REPR);
             Ok(())
         }), /*Ok(())*/ Err(Error::Exhausted));
         assert!(socket.can_send());