瀏覽代碼

Group IP header parts in the socket layer as struct IpRepr.

whitequark 8 年之前
父節點
當前提交
a713342120
共有 5 個文件被更改,包括 163 次插入135 次删除
  1. 25 25
      src/iface/arp_cache.rs
  2. 41 31
      src/iface/ethernet.rs
  3. 29 22
      src/socket/mod.rs
  4. 44 32
      src/socket/tcp.rs
  5. 24 25
      src/socket/udp.rs

+ 25 - 25
src/iface/arp_cache.rs

@@ -6,10 +6,10 @@ use wire::{EthernetAddress, IpAddress};
 /// This interface maps protocol addresses to hardware addresses.
 pub trait Cache {
     /// Update the cache to map given protocol address to given hardware address.
-    fn fill(&mut self, protocol_addr: IpAddress, hardware_addr: EthernetAddress);
+    fn fill(&mut self, protocol_addr: &IpAddress, hardware_addr: &EthernetAddress);
 
     /// Look up the hardware address corresponding for the given protocol address.
-    fn lookup(&mut self, protocol_addr: IpAddress) -> Option<EthernetAddress>;
+    fn lookup(&mut self, protocol_addr: &IpAddress) -> Option<EthernetAddress>;
 }
 
 /// An Address Resolution Protocol cache backed by a slice.
@@ -59,10 +59,10 @@ impl<'a> SliceCache<'a> {
     }
 
     /// Find an entry for the given protocol address, if any.
-    fn find(&self, protocol_addr: IpAddress) -> Option<usize> {
+    fn find(&self, protocol_addr: &IpAddress) -> Option<usize> {
         // The order of comparison is important: any valid IpAddress should
         // sort before IpAddress::Invalid.
-        self.storage.binary_search_by_key(&protocol_addr, |&(key, _, _)| key).ok()
+        self.storage.binary_search_by_key(protocol_addr, |&(key, _, _)| key).ok()
     }
 
     /// Sort entries in an order suitable for `find`.
@@ -94,16 +94,16 @@ impl<'a> SliceCache<'a> {
 }
 
 impl<'a> Cache for SliceCache<'a> {
-    fn fill(&mut self, protocol_addr: IpAddress, hardware_addr: EthernetAddress) {
+    fn fill(&mut self, protocol_addr: &IpAddress, hardware_addr: &EthernetAddress) {
         if let None = self.find(protocol_addr) {
             let lru_index = self.lru();
             self.storage[lru_index] =
-                (protocol_addr, hardware_addr, self.counter);
+                (*protocol_addr, *hardware_addr, self.counter);
             self.sort()
         }
     }
 
-    fn lookup(&mut self, protocol_addr: IpAddress) -> Option<EthernetAddress> {
+    fn lookup(&mut self, protocol_addr: &IpAddress) -> Option<EthernetAddress> {
         if let Some(index) = self.find(protocol_addr) {
             let (_protocol_addr, hardware_addr, ref mut counter) =
                 self.storage[index];
@@ -135,24 +135,24 @@ mod test {
         let mut cache_storage = [Default::default(); 3];
         let mut cache = SliceCache::new(&mut cache_storage[..]);
 
-        cache.fill(PADDR_A, HADDR_A);
-        assert_eq!(cache.lookup(PADDR_A), Some(HADDR_A));
-        assert_eq!(cache.lookup(PADDR_B), None);
-
-        cache.fill(PADDR_B, HADDR_B);
-        cache.fill(PADDR_C, HADDR_C);
-        assert_eq!(cache.lookup(PADDR_A), Some(HADDR_A));
-        assert_eq!(cache.lookup(PADDR_B), Some(HADDR_B));
-        assert_eq!(cache.lookup(PADDR_C), Some(HADDR_C));
-
-        cache.lookup(PADDR_B);
-        cache.lookup(PADDR_A);
-        cache.lookup(PADDR_C);
-        cache.fill(PADDR_D, HADDR_D);
-        assert_eq!(cache.lookup(PADDR_A), Some(HADDR_A));
-        assert_eq!(cache.lookup(PADDR_B), None);
-        assert_eq!(cache.lookup(PADDR_C), Some(HADDR_C));
-        assert_eq!(cache.lookup(PADDR_D), Some(HADDR_D));
+        cache.fill(&PADDR_A, &HADDR_A);
+        assert_eq!(cache.lookup(&PADDR_A), Some(HADDR_A));
+        assert_eq!(cache.lookup(&PADDR_B), None);
+
+        cache.fill(&PADDR_B, &HADDR_B);
+        cache.fill(&PADDR_C, &HADDR_C);
+        assert_eq!(cache.lookup(&PADDR_A), Some(HADDR_A));
+        assert_eq!(cache.lookup(&PADDR_B), Some(HADDR_B));
+        assert_eq!(cache.lookup(&PADDR_C), Some(HADDR_C));
+
+        cache.lookup(&PADDR_B);
+        cache.lookup(&PADDR_A);
+        cache.lookup(&PADDR_C);
+        cache.fill(&PADDR_D, &HADDR_D);
+        assert_eq!(cache.lookup(&PADDR_A), Some(HADDR_A));
+        assert_eq!(cache.lookup(&PADDR_B), None);
+        assert_eq!(cache.lookup(&PADDR_C), Some(HADDR_C));
+        assert_eq!(cache.lookup(&PADDR_D), Some(HADDR_D));
     }
 }
 

+ 41 - 31
src/iface/ethernet.rs

@@ -9,7 +9,7 @@ use wire::{IpAddress, IpProtocol};
 use wire::{Ipv4Address, Ipv4Packet, Ipv4Repr};
 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
 use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::Socket;
+use socket::{Socket, IpRepr};
 use super::{ArpCache};
 
 /// An Ethernet network interface.
@@ -142,7 +142,7 @@ impl<'a, 'b: 'a,
                         source_hardware_addr, source_protocol_addr,
                         target_protocol_addr, ..
                     } => {
-                        self.arp_cache.fill(source_protocol_addr.into(), source_hardware_addr);
+                        self.arp_cache.fill(&source_protocol_addr.into(), &source_hardware_addr);
 
                         if self.has_protocol_addr(target_protocol_addr) {
                             response = Response::Arp(ArpRepr::EthernetIpv4 {
@@ -160,7 +160,7 @@ impl<'a, 'b: 'a,
                         operation: ArpOperation::Reply,
                         source_hardware_addr, source_protocol_addr, ..
                     } => {
-                         self.arp_cache.fill(source_protocol_addr.into(), source_hardware_addr)
+                         self.arp_cache.fill(&source_protocol_addr.into(), &source_hardware_addr)
                     },
 
                     _ => return Err(Error::Unrecognized)
@@ -169,26 +169,26 @@ impl<'a, 'b: 'a,
 
             // Handle IP packets directed at us.
             EthernetProtocol::Ipv4 => {
-                let ip_packet = try!(Ipv4Packet::new(eth_frame.payload()));
-                let ip_repr = try!(Ipv4Repr::parse(&ip_packet));
+                let ipv4_packet = try!(Ipv4Packet::new(eth_frame.payload()));
+                let ipv4_repr = try!(Ipv4Repr::parse(&ipv4_packet));
 
                 // Fill the ARP cache from IP header.
-                self.arp_cache.fill(IpAddress::Ipv4(ip_repr.src_addr), eth_frame.src_addr());
+                self.arp_cache.fill(&IpAddress::Ipv4(ipv4_repr.src_addr), &eth_frame.src_addr());
 
-                match ip_repr {
+                match ipv4_repr {
                     // Ignore IP packets not directed at us.
                     Ipv4Repr { dst_addr, .. } if !self.has_protocol_addr(dst_addr) => (),
 
                     // Respond to ICMP packets.
                     Ipv4Repr { protocol: IpProtocol::Icmp, src_addr, dst_addr } => {
-                        let icmp_packet = try!(Icmpv4Packet::new(ip_packet.payload()));
+                        let icmp_packet = try!(Icmpv4Packet::new(ipv4_packet.payload()));
                         let icmp_repr = try!(Icmpv4Repr::parse(&icmp_packet));
                         match icmp_repr {
                             // Respond to echo requests.
                             Icmpv4Repr::EchoRequest {
                                 ident, seq_no, data
                             } => {
-                                let ip_reply_repr = Ipv4Repr {
+                                let ipv4_reply_repr = Ipv4Repr {
                                     src_addr: dst_addr,
                                     dst_addr: src_addr,
                                     protocol: IpProtocol::Icmp
@@ -198,7 +198,7 @@ impl<'a, 'b: 'a,
                                     seq_no: seq_no,
                                     data:   data
                                 };
-                                response = Response::Icmpv4(ip_reply_repr, icmp_reply_repr)
+                                response = Response::Icmpv4(ipv4_reply_repr, icmp_reply_repr)
                             }
 
                             // Ignore any echo replies.
@@ -213,8 +213,13 @@ impl<'a, 'b: 'a,
                     Ipv4Repr { src_addr, dst_addr, protocol } => {
                         let mut handled = false;
                         for socket in self.sockets.borrow_mut() {
-                            match socket.collect(&src_addr.into(), &dst_addr.into(),
-                                                 protocol, ip_packet.payload()) {
+                            let ip_repr = IpRepr {
+                                src_addr: src_addr.into(),
+                                dst_addr: dst_addr.into(),
+                                protocol: protocol,
+                                payload:  ipv4_packet.payload()
+                            };
+                            match socket.collect(&ip_repr) {
                                 Ok(()) => {
                                     // The packet was valid and handled by socket.
                                     handled = true;
@@ -236,9 +241,9 @@ impl<'a, 'b: 'a,
                         }
 
                         if !handled && protocol == IpProtocol::Tcp {
-                            let tcp_packet = try!(TcpPacket::new(ip_packet.payload()));
+                            let tcp_packet = try!(TcpPacket::new(ipv4_packet.payload()));
 
-                            let ip_reply_repr = Ipv4Repr {
+                            let ipv4_reply_repr = Ipv4Repr {
                                 src_addr: dst_addr,
                                 dst_addr: src_addr,
                                 protocol: IpProtocol::Tcp
@@ -253,7 +258,7 @@ impl<'a, 'b: 'a,
                                 window_len: 0,
                                 payload:    &[]
                             };
-                            response = Response::Tcpv4(ip_reply_repr, tcp_reply_repr);
+                            response = Response::Tcpv4(ipv4_reply_repr, tcp_reply_repr);
                         } else if !handled {
                             let reason;
                             if protocol == IpProtocol::Udp {
@@ -263,20 +268,20 @@ impl<'a, 'b: 'a,
                             }
 
                             let mut data = [0; 8];
-                            data.copy_from_slice(&ip_packet.payload()[0..8]);
+                            data.copy_from_slice(&ipv4_packet.payload()[0..8]);
 
-                            let ip_reply_repr = Ipv4Repr {
+                            let ipv4_reply_repr = Ipv4Repr {
                                 src_addr: dst_addr,
                                 dst_addr: src_addr,
                                 protocol: IpProtocol::Icmp
                             };
                             let icmp_reply_repr = Icmpv4Repr::DstUnreachable {
                                 reason:   reason,
-                                header:   ip_repr,
-                                length:   ip_packet.payload().len(),
+                                header:   ipv4_repr,
+                                length:   ipv4_packet.payload().len(),
                                 data:     data
                             };
-                            response = Response::Icmpv4(ip_reply_repr, icmp_reply_repr)
+                            response = Response::Icmpv4(ipv4_reply_repr, icmp_reply_repr)
                         }
                     },
                 }
@@ -289,7 +294,7 @@ impl<'a, 'b: 'a,
         macro_rules! ip_response {
             ($tx_buffer:ident, $frame:ident, $ip_repr:ident, $length:expr) => ({
                 let dst_hardware_addr =
-                    match self.arp_cache.lookup($ip_repr.dst_addr.into()) {
+                    match self.arp_cache.lookup(&$ip_repr.dst_addr.into()) {
                         None => return Err(Error::Unaddressable),
                         Some(hardware_addr) => hardware_addr
                     };
@@ -365,9 +370,9 @@ impl<'a, 'b: 'a,
 
         let mut nothing_to_transmit = true;
         for socket in self.sockets.borrow_mut() {
-            let result = socket.dispatch(&mut |src_addr, dst_addr, protocol, payload| {
+            let result = socket.dispatch(&mut |repr| {
                 let src_addr =
-                    try!(match src_addr {
+                    try!(match &repr.src_addr {
                         &IpAddress::Unspecified |
                         &IpAddress::Ipv4(Ipv4Address([0, _, _, _])) => {
                             let mut assigned_addr = None;
@@ -385,27 +390,27 @@ impl<'a, 'b: 'a,
                         addr => Ok(addr)
                     });
 
-                let ip_repr =
-                    match (src_addr, dst_addr) {
+                let ipv4_repr =
+                    match (src_addr, &repr.dst_addr) {
                         (&IpAddress::Ipv4(src_addr),
                          &IpAddress::Ipv4(dst_addr)) => {
                             Ipv4Repr {
                                 src_addr: src_addr,
                                 dst_addr: dst_addr,
-                                protocol: protocol
+                                protocol: repr.protocol
                             }
                         },
                         _ => unreachable!()
                     };
 
                 let dst_hardware_addr =
-                    match arp_cache.lookup(*dst_addr) {
+                    match arp_cache.lookup(&repr.dst_addr) {
                         None => return Err(Error::Unaddressable),
                         Some(hardware_addr) => hardware_addr
                     };
 
-                let tx_len = EthernetFrame::<&[u8]>::buffer_len(ip_repr.buffer_len() +
-                                                                payload.buffer_len());
+                let tx_len = EthernetFrame::<&[u8]>::buffer_len(ipv4_repr.buffer_len() +
+                                                                repr.payload.buffer_len());
                 let mut tx_buffer = try!(device.transmit(tx_len));
                 let mut frame = try!(EthernetFrame::new(&mut tx_buffer));
                 frame.set_src_addr(src_hardware_addr);
@@ -413,9 +418,14 @@ impl<'a, 'b: 'a,
                 frame.set_ethertype(EthernetProtocol::Ipv4);
 
                 let mut ip_packet = try!(Ipv4Packet::new(frame.payload_mut()));
-                ip_repr.emit(&mut ip_packet, payload.buffer_len());
+                ipv4_repr.emit(&mut ip_packet, repr.payload.buffer_len());
 
-                payload.emit(src_addr, dst_addr, ip_packet.payload_mut());
+                repr.payload.emit(&mut IpRepr {
+                    src_addr: repr.src_addr,
+                    dst_addr: repr.dst_addr,
+                    protocol: repr.protocol,
+                    payload:  ip_packet.payload_mut()
+                });
 
                 Ok(())
             });

+ 29 - 22
src/socket/mod.rs

@@ -24,18 +24,6 @@ pub use self::tcp::SocketBuffer as TcpSocketBuffer;
 pub use self::tcp::State as TcpState;
 pub use self::tcp::TcpSocket;
 
-/// A packet representation.
-///
-/// This interface abstracts the various types of packets layered under the IP protocol,
-/// and serves as an accessory to [trait Socket](trait.Socket.html).
-pub trait PacketRepr {
-    /// Return the length of the buffer required to serialize this high-level representation.
-    fn buffer_len(&self) -> usize;
-
-    /// Emit this high-level representation into a sequence of octets.
-    fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]);
-}
-
 /// A network socket.
 ///
 /// This enumeration abstracts the various types of sockets based on the IP protocol.
@@ -64,14 +52,12 @@ impl<'a, 'b> Socket<'a, 'b> {
     /// is returned.
     ///
     /// This function is used internally by the networking stack.
-    pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress,
-                   protocol: IpProtocol, payload: &[u8])
-            -> Result<(), Error> {
+    pub fn collect(&mut self, repr: &IpRepr<&[u8]>) -> Result<(), Error> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
-                socket.collect(src_addr, dst_addr, protocol, payload),
+                socket.collect(repr),
             &mut Socket::Tcp(ref mut socket) =>
-                socket.collect(src_addr, dst_addr, protocol, payload),
+                socket.collect(repr),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
     }
@@ -83,19 +69,40 @@ impl<'a, 'b> Socket<'a, 'b> {
     /// is returned.
     ///
     /// This function is used internally by the networking stack.
-    pub fn dispatch(&mut self, f: &mut FnMut(&IpAddress, &IpAddress,
-                                             IpProtocol, &PacketRepr) -> Result<(), Error>)
-            -> Result<(), Error> {
+    pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
+            where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
-                socket.dispatch(f),
+                socket.dispatch(emit),
             &mut Socket::Tcp(ref mut socket) =>
-                socket.dispatch(f),
+                socket.dispatch(emit),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
     }
 }
 
+/// An IP packet representation.
+///
+/// This struct abstracts the various versions of IP packets.
+pub struct IpRepr<T> {
+    pub src_addr: IpAddress,
+    pub dst_addr: IpAddress,
+    pub protocol: IpProtocol,
+    pub payload:  T
+}
+
+/// An IP-encapsulated packet representation.
+///
+/// This trait abstracts the various types of packets layered under the IP protocol,
+/// and serves as an accessory to [trait Socket](trait.Socket.html).
+pub trait IpPayload {
+    /// Return the length of the buffer required to serialize this high-level representation.
+    fn buffer_len(&self) -> usize;
+
+    /// Emit this high-level representation into a sequence of octets.
+    fn emit(&self, repr: &mut IpRepr<&mut [u8]>);
+}
+
 /// A conversion trait for network sockets.
 ///
 /// This trait is used to concisely downcast [Socket](trait.Socket.html) values to their

+ 44 - 32
src/socket/tcp.rs

@@ -2,9 +2,9 @@ use core::fmt;
 
 use Error;
 use Managed;
-use wire::{IpProtocol, IpAddress, IpEndpoint};
+use wire::{IpProtocol, IpEndpoint};
 use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::{Socket, PacketRepr};
+use socket::{Socket, IpRepr, IpPayload};
 
 /// A TCP stream ring buffer.
 #[derive(Debug)]
@@ -241,29 +241,27 @@ impl<'a> TcpSocket<'a> {
     }
 
     /// See [Socket::collect](enum.Socket.html#method.collect).
-    pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress,
-                   protocol: IpProtocol, payload: &[u8])
-            -> Result<(), Error> {
-        if protocol != IpProtocol::Tcp { return Err(Error::Rejected) }
+    pub fn collect(&mut self, ip_repr: &IpRepr<&[u8]>) -> Result<(), Error> {
+        if ip_repr.protocol != IpProtocol::Tcp { return Err(Error::Rejected) }
 
-        let packet = try!(TcpPacket::new(payload));
-        let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr));
+        let packet = try!(TcpPacket::new(ip_repr.payload));
+        let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr, &ip_repr.dst_addr));
 
         // Reject packets with a wrong destination.
         if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) }
         if !self.local_endpoint.addr.is_unspecified() &&
-           self.local_endpoint.addr != *dst_addr { return Err(Error::Rejected) }
+           self.local_endpoint.addr != ip_repr.dst_addr { return Err(Error::Rejected) }
 
         // Reject packets from a source to which we aren't connected.
         if self.remote_endpoint.port != 0 &&
            self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) }
         if !self.remote_endpoint.addr.is_unspecified() &&
-           self.remote_endpoint.addr != *src_addr { return Err(Error::Rejected) }
+           self.remote_endpoint.addr != ip_repr.src_addr { return Err(Error::Rejected) }
 
         // Reject packets addressed to a closed socket.
         if self.state == State::Closed {
             net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
-                       self.local_endpoint, src_addr, repr.src_port);
+                       self.local_endpoint, ip_repr.src_addr, repr.src_port);
             return Err(Error::Malformed)
         }
 
@@ -317,8 +315,8 @@ impl<'a> TcpSocket<'a> {
             (State::Listen, TcpRepr {
                 src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
             }) => {
-                self.local_endpoint  = IpEndpoint::new(*dst_addr, dst_port);
-                self.remote_endpoint = IpEndpoint::new(*src_addr, src_port);
+                self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr, dst_port);
+                self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr, src_port);
                 self.local_seq_no    = -seq_number; // FIXME: use something more secure
                 self.remote_seq_no   = seq_number + 1;
                 self.set_state(State::SynReceived);
@@ -370,9 +368,8 @@ impl<'a> TcpSocket<'a> {
     }
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
-    pub fn dispatch(&mut self, f: &mut FnMut(&IpAddress, &IpAddress,
-                                             IpProtocol, &PacketRepr) -> Result<(), Error>)
-            -> Result<(), Error> {
+    pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
+            where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> {
         let mut repr = TcpRepr {
             src_port:   self.local_endpoint.port,
             dst_port:   self.remote_endpoint.port,
@@ -416,23 +413,29 @@ impl<'a> TcpSocket<'a> {
             _ => unreachable!()
         }
 
-        f(&self.local_endpoint.addr, &self.remote_endpoint.addr, IpProtocol::Tcp, &repr)
+        emit(&IpRepr {
+            src_addr: self.local_endpoint.addr,
+            dst_addr: self.remote_endpoint.addr,
+            protocol: IpProtocol::Tcp,
+            payload:  &repr as &IpPayload
+        })
     }
 }
 
-impl<'a> PacketRepr for TcpRepr<'a> {
+impl<'a> IpPayload for TcpRepr<'a> {
     fn buffer_len(&self) -> usize {
         self.buffer_len()
     }
 
-    fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]) {
-        let mut packet = TcpPacket::new(payload).expect("undersized payload");
-        self.emit(&mut packet, src_addr, dst_addr)
+    fn emit(&self, repr: &mut IpRepr<&mut [u8]>) {
+        let mut packet = TcpPacket::new(&mut repr.payload).expect("undersized payload");
+        self.emit(&mut packet, &repr.src_addr, &repr.dst_addr)
     }
 }
 
 #[cfg(test)]
 mod test {
+    use wire::IpAddress;
     use super::*;
 
     #[test]
@@ -483,28 +486,37 @@ mod test {
             let mut buffer = vec![0; repr.buffer_len()];
             let mut packet = TcpPacket::new(&mut buffer).unwrap();
             repr.emit(&mut packet, &REMOTE_IP, &LOCAL_IP);
-            let result = $socket.collect(&REMOTE_IP, &LOCAL_IP, IpProtocol::Tcp,
-                                         &packet.into_inner()[..]);
+            let result = $socket.collect(&IpRepr {
+                src_addr: REMOTE_IP,
+                dst_addr: LOCAL_IP,
+                protocol: IpProtocol::Tcp,
+                payload:  &packet.into_inner()[..]
+            });
             result.expect("send error")
         })
     }
 
     macro_rules! recv {
         ($socket:ident, $expected:expr) => ({
-            let result = $socket.dispatch(&mut |src_addr, dst_addr, protocol, payload| {
-                assert_eq!(protocol, IpProtocol::Tcp);
-                assert_eq!(src_addr, &LOCAL_IP);
-                assert_eq!(dst_addr, &REMOTE_IP);
-
-                let mut buffer = vec![0; payload.buffer_len()];
-                payload.emit(src_addr, dst_addr, &mut buffer);
+            let result = $socket.dispatch(&mut |repr| {
+                assert_eq!(repr.protocol, IpProtocol::Tcp);
+                assert_eq!(repr.src_addr, LOCAL_IP);
+                assert_eq!(repr.dst_addr, REMOTE_IP);
+
+                let mut buffer = vec![0; repr.payload.buffer_len()];
+                repr.payload.emit(&mut IpRepr {
+                    src_addr: repr.src_addr,
+                    dst_addr: repr.dst_addr,
+                    protocol: repr.protocol,
+                    payload:  &mut buffer[..]
+                });
                 let packet = TcpPacket::new(&buffer[..]).unwrap();
-                let repr = TcpRepr::parse(&packet, src_addr, dst_addr).unwrap();
+                let repr = TcpRepr::parse(&packet, &repr.src_addr, &repr.dst_addr).unwrap();
                 assert_eq!(repr, $expected);
                 Ok(())
             });
             assert_eq!(result, Ok(()));
-            let result = $socket.dispatch(&mut |_src_addr, _dst_addr, _protocol, _payload| {
+            let result = $socket.dispatch(&mut |_repr| {
                 Ok(())
             });
             assert_eq!(result, Err(Error::Exhausted));

+ 24 - 25
src/socket/udp.rs

@@ -1,8 +1,8 @@
 use Error;
 use Managed;
-use wire::{IpAddress, IpProtocol, IpEndpoint};
+use wire::{IpProtocol, IpEndpoint};
 use wire::{UdpPacket, UdpRepr};
-use socket::{Socket, PacketRepr};
+use socket::{Socket, IpRepr, IpPayload};
 
 /// A buffered UDP packet.
 #[derive(Debug)]
@@ -168,21 +168,19 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 
     /// See [Socket::collect](enum.Socket.html#method.collect).
-    pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress,
-                   protocol: IpProtocol, payload: &[u8])
-            -> Result<(), Error> {
-        if protocol != IpProtocol::Udp { return Err(Error::Rejected) }
+    pub fn collect(&mut self, ip_repr: &IpRepr<&[u8]>) -> Result<(), Error> {
+        if ip_repr.protocol != IpProtocol::Udp { return Err(Error::Rejected) }
 
-        let packet = try!(UdpPacket::new(payload));
-        let repr = try!(UdpRepr::parse(&packet, src_addr, dst_addr));
+        let packet = try!(UdpPacket::new(ip_repr.payload));
+        let repr = try!(UdpRepr::parse(&packet, &ip_repr.src_addr, &ip_repr.dst_addr));
 
         if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) }
         if !self.endpoint.addr.is_unspecified() {
-            if self.endpoint.addr != *dst_addr { return Err(Error::Rejected) }
+            if self.endpoint.addr != ip_repr.dst_addr { return Err(Error::Rejected) }
         }
 
         let packet_buf = try!(self.rx_buffer.enqueue());
-        packet_buf.endpoint = IpEndpoint { addr: *src_addr, port: repr.src_port };
+        packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr, port: repr.src_port };
         packet_buf.size = repr.payload.len();
         packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload);
         net_trace!("udp:{}:{}: collect {} octets",
@@ -191,31 +189,32 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
-    pub fn dispatch(&mut self, f: &mut FnMut(&IpAddress, &IpAddress,
-                                             IpProtocol, &PacketRepr) -> Result<(), Error>)
-            -> Result<(), Error> {
+    pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
+            where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> {
         let packet_buf = try!(self.tx_buffer.dequeue());
         net_trace!("udp:{}:{}: dispatch {} octets",
                    self.endpoint, packet_buf.endpoint, packet_buf.size);
-        f(&self.endpoint.addr,
-          &packet_buf.endpoint.addr,
-          IpProtocol::Udp,
-          &UdpRepr {
-            src_port: self.endpoint.port,
-            dst_port: packet_buf.endpoint.port,
-            payload:  &packet_buf.as_ref()[..]
-          })
+        emit(&IpRepr {
+            src_addr: self.endpoint.addr,
+            dst_addr: packet_buf.endpoint.addr,
+            protocol: IpProtocol::Udp,
+            payload:  &UdpRepr {
+                src_port: self.endpoint.port,
+                dst_port: packet_buf.endpoint.port,
+                payload:  &packet_buf.as_ref()[..]
+            } as &IpPayload
+        })
     }
 }
 
-impl<'a> PacketRepr for UdpRepr<'a> {
+impl<'a> IpPayload for UdpRepr<'a> {
     fn buffer_len(&self) -> usize {
         self.buffer_len()
     }
 
-    fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]) {
-        let mut packet = UdpPacket::new(payload).expect("undersized payload");
-        self.emit(&mut packet, src_addr, dst_addr)
+    fn emit(&self, repr: &mut IpRepr<&mut [u8]>) {
+        let mut packet = UdpPacket::new(&mut repr.payload).expect("undersized payload");
+        self.emit(&mut packet, &repr.src_addr, &repr.dst_addr)
     }
 }