ソースを参照

Factor out IpRepr into the wire module.

whitequark 8 年 前
コミット
72abe80df8
6 ファイル変更196 行追加123 行削除
  1. 13 53
      src/iface/ethernet.rs
  2. 6 16
      src/socket/mod.rs
  3. 30 35
      src/socket/tcp.rs
  4. 19 18
      src/socket/udp.rs
  5. 127 1
      src/wire/ip.rs
  6. 1 0
      src/wire/mod.rs

+ 13 - 53
src/iface/ethernet.rs

@@ -5,11 +5,11 @@ use Error;
 use phy::Device;
 use wire::{EthernetAddress, EthernetProtocol, EthernetFrame};
 use wire::{ArpPacket, ArpRepr, ArpOperation};
-use wire::{IpAddress, IpProtocol};
-use wire::{Ipv4Address, Ipv4Packet, Ipv4Repr};
+use wire::{Ipv4Packet, Ipv4Repr};
 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
+use wire::{IpAddress, IpProtocol, IpRepr};
 use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::{Socket, IpRepr};
+use socket::Socket;
 use super::{ArpCache};
 
 /// An Ethernet network interface.
@@ -213,13 +213,8 @@ impl<'a, 'b: 'a,
                     Ipv4Repr { src_addr, dst_addr, protocol } => {
                         let mut handled = false;
                         for socket in self.sockets.borrow_mut() {
-                            let ip_repr = IpRepr {
-                                src_addr: src_addr.into(),
-                                dst_addr: dst_addr.into(),
-                                protocol: protocol,
-                                payload:  ipv4_packet.payload()
-                            };
-                            match socket.collect(&ip_repr) {
+                            let ip_repr = IpRepr::Ipv4(ipv4_repr);
+                            match socket.collect(&ip_repr, ipv4_packet.payload()) {
                                 Ok(()) => {
                                     // The packet was valid and handled by socket.
                                     handled = true;
@@ -370,62 +365,27 @@ impl<'a, 'b: 'a,
 
         let mut nothing_to_transmit = true;
         for socket in self.sockets.borrow_mut() {
-            let result = socket.dispatch(&mut |repr| {
-                let src_addr =
-                    try!(match &repr.src_addr {
-                        &IpAddress::Unspecified |
-                        &IpAddress::Ipv4(Ipv4Address([0, _, _, _])) => {
-                            let mut assigned_addr = None;
-                            for addr in src_protocol_addrs {
-                                match addr {
-                                    addr @ &IpAddress::Ipv4(_) => {
-                                        assigned_addr = Some(addr);
-                                        break
-                                    }
-                                    _ => ()
-                                }
-                            }
-                            assigned_addr.ok_or(Error::Unaddressable)
-                        },
-                        addr => Ok(addr)
-                    });
-
-                let ipv4_repr =
-                    match (src_addr, &repr.dst_addr) {
-                        (&IpAddress::Ipv4(src_addr),
-                         &IpAddress::Ipv4(dst_addr)) => {
-                            Ipv4Repr {
-                                src_addr: src_addr,
-                                dst_addr: dst_addr,
-                                protocol: repr.protocol
-                            }
-                        },
-                        _ => unreachable!()
-                    };
+            let result = socket.dispatch(&mut |repr, payload| {
+                let repr = try!(repr.lower(src_protocol_addrs));
 
                 let dst_hardware_addr =
-                    match arp_cache.lookup(&repr.dst_addr) {
+                    match arp_cache.lookup(&repr.dst_addr()) {
                         None => return Err(Error::Unaddressable),
                         Some(hardware_addr) => hardware_addr
                     };
 
-                let tx_len = EthernetFrame::<&[u8]>::buffer_len(ipv4_repr.buffer_len() +
-                                                                repr.payload.buffer_len());
+                let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len() +
+                                                                payload.buffer_len());
                 let mut tx_buffer = try!(device.transmit(tx_len));
                 let mut frame = try!(EthernetFrame::new(&mut tx_buffer));
                 frame.set_src_addr(src_hardware_addr);
                 frame.set_dst_addr(dst_hardware_addr);
                 frame.set_ethertype(EthernetProtocol::Ipv4);
 
-                let mut ip_packet = try!(Ipv4Packet::new(frame.payload_mut()));
-                ipv4_repr.emit(&mut ip_packet, repr.payload.buffer_len());
+                repr.emit(frame.payload_mut(), payload.buffer_len());
 
-                repr.payload.emit(&mut IpRepr {
-                    src_addr: repr.src_addr,
-                    dst_addr: repr.dst_addr,
-                    protocol: repr.protocol,
-                    payload:  ip_packet.payload_mut()
-                });
+                let mut ip_packet = try!(Ipv4Packet::new(frame.payload_mut()));
+                payload.emit(&repr, ip_packet.payload_mut());
 
                 Ok(())
             });

+ 6 - 16
src/socket/mod.rs

@@ -11,7 +11,7 @@
 //! size for a buffer, allocate it, and let the networking stack use it.
 
 use Error;
-use wire::{IpAddress, IpProtocol};
+use wire::IpRepr;
 
 mod udp;
 mod tcp;
@@ -52,12 +52,12 @@ impl<'a, 'b> Socket<'a, 'b> {
     /// is returned.
     ///
     /// This function is used internally by the networking stack.
-    pub fn collect(&mut self, repr: &IpRepr<&[u8]>) -> Result<(), Error> {
+    pub fn collect(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
-                socket.collect(repr),
+                socket.collect(ip_repr, payload),
             &mut Socket::Tcp(ref mut socket) =>
-                socket.collect(repr),
+                socket.collect(ip_repr, payload),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
     }
@@ -70,7 +70,7 @@ impl<'a, 'b> Socket<'a, 'b> {
     ///
     /// This function is used internally by the networking stack.
     pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
-            where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> {
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
                 socket.dispatch(emit),
@@ -81,16 +81,6 @@ impl<'a, 'b> Socket<'a, 'b> {
     }
 }
 
-/// An IP packet representation.
-///
-/// This struct abstracts the various versions of IP packets.
-pub struct IpRepr<T> {
-    pub src_addr: IpAddress,
-    pub dst_addr: IpAddress,
-    pub protocol: IpProtocol,
-    pub payload:  T
-}
-
 /// An IP-encapsulated packet representation.
 ///
 /// This trait abstracts the various types of packets layered under the IP protocol,
@@ -100,7 +90,7 @@ pub trait IpPayload {
     fn buffer_len(&self) -> usize;
 
     /// Emit this high-level representation into a sequence of octets.
-    fn emit(&self, repr: &mut IpRepr<&mut [u8]>);
+    fn emit(&self, ip_repr: &IpRepr, payload: &mut [u8]);
 }
 
 /// A conversion trait for network sockets.

+ 30 - 35
src/socket/tcp.rs

@@ -241,27 +241,27 @@ impl<'a> TcpSocket<'a> {
     }
 
     /// See [Socket::collect](enum.Socket.html#method.collect).
-    pub fn collect(&mut self, ip_repr: &IpRepr<&[u8]>) -> Result<(), Error> {
-        if ip_repr.protocol != IpProtocol::Tcp { return Err(Error::Rejected) }
+    pub fn collect(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> {
+        if ip_repr.protocol() != IpProtocol::Tcp { return Err(Error::Rejected) }
 
-        let packet = try!(TcpPacket::new(ip_repr.payload));
-        let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr, &ip_repr.dst_addr));
+        let packet = try!(TcpPacket::new(payload));
+        let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr()));
 
         // Reject packets with a wrong destination.
         if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) }
         if !self.local_endpoint.addr.is_unspecified() &&
-           self.local_endpoint.addr != ip_repr.dst_addr { return Err(Error::Rejected) }
+           self.local_endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
 
         // Reject packets from a source to which we aren't connected.
         if self.remote_endpoint.port != 0 &&
            self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) }
         if !self.remote_endpoint.addr.is_unspecified() &&
-           self.remote_endpoint.addr != ip_repr.src_addr { return Err(Error::Rejected) }
+           self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) }
 
         // Reject packets addressed to a closed socket.
         if self.state == State::Closed {
             net_trace!("tcp:{}:{}:{}: packet sent to a closed socket",
-                       self.local_endpoint, ip_repr.src_addr, repr.src_port);
+                       self.local_endpoint, ip_repr.src_addr(), repr.src_port);
             return Err(Error::Malformed)
         }
 
@@ -315,8 +315,8 @@ impl<'a> TcpSocket<'a> {
             (State::Listen, TcpRepr {
                 src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
             }) => {
-                self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr, dst_port);
-                self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr, src_port);
+                self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), dst_port);
+                self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port);
                 self.local_seq_no    = -seq_number; // FIXME: use something more secure
                 self.remote_seq_no   = seq_number + 1;
                 self.set_state(State::SynReceived);
@@ -369,7 +369,7 @@ impl<'a> TcpSocket<'a> {
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
     pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
-            where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> {
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> {
         let mut repr = TcpRepr {
             src_port:   self.local_endpoint.port,
             dst_port:   self.remote_endpoint.port,
@@ -413,12 +413,12 @@ impl<'a> TcpSocket<'a> {
             _ => unreachable!()
         }
 
-        emit(&IpRepr {
+        let ip_repr = IpRepr::Unspecified {
             src_addr: self.local_endpoint.addr,
             dst_addr: self.remote_endpoint.addr,
             protocol: IpProtocol::Tcp,
-            payload:  &repr as &IpPayload
-        })
+        };
+        emit(&ip_repr, &repr)
     }
 }
 
@@ -427,9 +427,9 @@ impl<'a> IpPayload for TcpRepr<'a> {
         self.buffer_len()
     }
 
-    fn emit(&self, repr: &mut IpRepr<&mut [u8]>) {
-        let mut packet = TcpPacket::new(&mut repr.payload).expect("undersized payload");
-        self.emit(&mut packet, &repr.src_addr, &repr.dst_addr)
+    fn emit(&self, ip_repr: &IpRepr, payload: &mut [u8]) {
+        let mut packet = TcpPacket::new(payload).expect("undersized payload");
+        self.emit(&mut packet, &ip_repr.src_addr(), &ip_repr.dst_addr())
     }
 }
 
@@ -486,37 +486,32 @@ mod test {
             let mut buffer = vec![0; repr.buffer_len()];
             let mut packet = TcpPacket::new(&mut buffer).unwrap();
             repr.emit(&mut packet, &REMOTE_IP, &LOCAL_IP);
-            let result = $socket.collect(&IpRepr {
+            let ip_repr = IpRepr::Unspecified {
                 src_addr: REMOTE_IP,
                 dst_addr: LOCAL_IP,
-                protocol: IpProtocol::Tcp,
-                payload:  &packet.into_inner()[..]
-            });
+                protocol: IpProtocol::Tcp
+            };
+            let result = $socket.collect(&ip_repr, &packet.into_inner()[..]);
             result.expect("send error")
         })
     }
 
     macro_rules! recv {
         ($socket:ident, $expected:expr) => ({
-            let result = $socket.dispatch(&mut |repr| {
-                assert_eq!(repr.protocol, IpProtocol::Tcp);
-                assert_eq!(repr.src_addr, LOCAL_IP);
-                assert_eq!(repr.dst_addr, REMOTE_IP);
-
-                let mut buffer = vec![0; repr.payload.buffer_len()];
-                repr.payload.emit(&mut IpRepr {
-                    src_addr: repr.src_addr,
-                    dst_addr: repr.dst_addr,
-                    protocol: repr.protocol,
-                    payload:  &mut buffer[..]
-                });
+            let result = $socket.dispatch(&mut |ip_repr, payload| {
+                assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
+                assert_eq!(ip_repr.src_addr(), LOCAL_IP);
+                assert_eq!(ip_repr.dst_addr(), REMOTE_IP);
+
+                let mut buffer = vec![0; payload.buffer_len()];
+                payload.emit(&ip_repr, &mut buffer[..]);
                 let packet = TcpPacket::new(&buffer[..]).unwrap();
-                let repr = TcpRepr::parse(&packet, &repr.src_addr, &repr.dst_addr).unwrap();
-                assert_eq!(repr, $expected);
+                let repr = TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr());
+                assert_eq!(repr, Ok($expected));
                 Ok(())
             });
             assert_eq!(result, Ok(()));
-            let result = $socket.dispatch(&mut |_repr| {
+            let result = $socket.dispatch(&mut |_repr, _payload| {
                 Ok(())
             });
             assert_eq!(result, Err(Error::Exhausted));

+ 19 - 18
src/socket/udp.rs

@@ -168,19 +168,19 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 
     /// See [Socket::collect](enum.Socket.html#method.collect).
-    pub fn collect(&mut self, ip_repr: &IpRepr<&[u8]>) -> Result<(), Error> {
-        if ip_repr.protocol != IpProtocol::Udp { return Err(Error::Rejected) }
+    pub fn collect(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> {
+        if ip_repr.protocol() != IpProtocol::Udp { return Err(Error::Rejected) }
 
-        let packet = try!(UdpPacket::new(ip_repr.payload));
-        let repr = try!(UdpRepr::parse(&packet, &ip_repr.src_addr, &ip_repr.dst_addr));
+        let packet = try!(UdpPacket::new(payload));
+        let repr = try!(UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr()));
 
         if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) }
         if !self.endpoint.addr.is_unspecified() {
-            if self.endpoint.addr != ip_repr.dst_addr { return Err(Error::Rejected) }
+            if self.endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
         }
 
         let packet_buf = try!(self.rx_buffer.enqueue());
-        packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr, port: repr.src_port };
+        packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
         packet_buf.size = repr.payload.len();
         packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload);
         net_trace!("udp:{}:{}: collect {} octets",
@@ -190,20 +190,21 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
     pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
-            where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> {
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> {
         let packet_buf = try!(self.tx_buffer.dequeue());
         net_trace!("udp:{}:{}: dispatch {} octets",
                    self.endpoint, packet_buf.endpoint, packet_buf.size);
-        emit(&IpRepr {
+        let ip_repr = IpRepr::Unspecified {
             src_addr: self.endpoint.addr,
             dst_addr: packet_buf.endpoint.addr,
-            protocol: IpProtocol::Udp,
-            payload:  &UdpRepr {
-                src_port: self.endpoint.port,
-                dst_port: packet_buf.endpoint.port,
-                payload:  &packet_buf.as_ref()[..]
-            } as &IpPayload
-        })
+            protocol: IpProtocol::Udp
+        };
+        let payload = UdpRepr {
+            src_port: self.endpoint.port,
+            dst_port: packet_buf.endpoint.port,
+            payload:  &packet_buf.as_ref()[..]
+        };
+        emit(&ip_repr, &payload)
     }
 }
 
@@ -212,9 +213,9 @@ impl<'a> IpPayload for UdpRepr<'a> {
         self.buffer_len()
     }
 
-    fn emit(&self, repr: &mut IpRepr<&mut [u8]>) {
-        let mut packet = UdpPacket::new(&mut repr.payload).expect("undersized payload");
-        self.emit(&mut packet, &repr.src_addr, &repr.dst_addr)
+    fn emit(&self, repr: &IpRepr, payload: &mut [u8]) {
+        let mut packet = UdpPacket::new(payload).expect("undersized payload");
+        self.emit(&mut packet, &repr.src_addr(), &repr.dst_addr())
     }
 }
 

+ 127 - 1
src/wire/ip.rs

@@ -1,6 +1,7 @@
 use core::fmt;
 
-use super::Ipv4Address;
+use Error;
+use super::{Ipv4Address, Ipv4Packet, Ipv4Repr};
 
 enum_with_unknown! {
     /// Internetworking protocol.
@@ -98,6 +99,131 @@ impl fmt::Display for Endpoint {
     }
 }
 
+/// An IP packet representation.
+///
+/// This enum abstracts the various versions of IP packets. It either contains a concrete
+/// high-level representation for some IP protocol version, or an unspecified representation,
+/// which permits the `IpAddress::Unspecified` addresses.
+#[derive(Debug, Clone)]
+pub enum IpRepr {
+    Unspecified {
+        src_addr: Address,
+        dst_addr: Address,
+        protocol: Protocol
+    },
+    Ipv4(Ipv4Repr),
+    #[doc(hidden)]
+    __Nonexhaustive
+}
+
+impl IpRepr {
+    /// Return the source address.
+    pub fn src_addr(&self) -> Address {
+        match self {
+            &IpRepr::Unspecified { src_addr, .. } => src_addr,
+            &IpRepr::Ipv4(repr) => Address::Ipv4(repr.src_addr),
+            &IpRepr::__Nonexhaustive => unreachable!()
+        }
+    }
+
+    /// Return the destination address.
+    pub fn dst_addr(&self) -> Address {
+        match self {
+            &IpRepr::Unspecified { dst_addr, .. } => dst_addr,
+            &IpRepr::Ipv4(repr) => Address::Ipv4(repr.dst_addr),
+            &IpRepr::__Nonexhaustive => unreachable!()
+        }
+    }
+
+    /// Return the protocol.
+    pub fn protocol(&self) -> Protocol {
+        match self {
+            &IpRepr::Unspecified { protocol, .. } => protocol,
+            &IpRepr::Ipv4(repr) => repr.protocol,
+            &IpRepr::__Nonexhaustive => unreachable!()
+        }
+    }
+
+    /// Convert an unspecified representation into a concrete one, or return
+    /// `Err(Error::Unaddressable)` if not possible.
+    ///
+    /// # Panics
+    /// This function panics if source and destination addresses belong to different families,
+    /// or the destination address is unspecified, since this indicates a logic error.
+    pub fn lower(&self, fallback_src_addrs: &[Address]) -> Result<IpRepr, Error> {
+        match self {
+            &IpRepr::Unspecified {
+                src_addr: Address::Ipv4(src_addr),
+                dst_addr: Address::Ipv4(dst_addr),
+                protocol
+            } => {
+                Ok(IpRepr::Ipv4(Ipv4Repr {
+                    src_addr: src_addr,
+                    dst_addr: dst_addr,
+                    protocol: protocol
+                }))
+            }
+
+            &IpRepr::Unspecified {
+                src_addr: Address::Unspecified,
+                dst_addr: Address::Ipv4(dst_addr),
+                protocol
+            } => {
+                let mut src_addr = None;
+                for addr in fallback_src_addrs {
+                    match addr {
+                        &Address::Ipv4(addr) => {
+                            src_addr = Some(addr);
+                            break
+                        }
+                        _ => ()
+                    }
+                }
+                Ok(IpRepr::Ipv4(Ipv4Repr {
+                    src_addr: try!(src_addr.ok_or(Error::Unaddressable)),
+                    dst_addr: dst_addr,
+                    protocol: protocol
+                }))
+            }
+
+            &IpRepr::Unspecified { dst_addr: Address::Unspecified, .. } =>
+                panic!("unspecified destination IP address"),
+            // &IpRepr::Unspecified { .. } =>
+            //     panic!("source and destination IP address families do not match"),
+
+            repr @ &IpRepr::Ipv4(_) => Ok(repr.clone()),
+            &IpRepr::__Nonexhaustive => unreachable!()
+        }
+    }
+
+    /// Return the length of a header that will be emitted from this high-level representation.
+    ///
+    /// # Panics
+    /// This function panics if invoked on an unspecified representation.
+    pub fn buffer_len(&self) -> usize {
+        match self {
+            &IpRepr::Unspecified { .. } => panic!("unspecified IP representation"),
+            &IpRepr::Ipv4(repr) => repr.buffer_len(),
+            &IpRepr::__Nonexhaustive => unreachable!()
+        }
+    }
+
+    /// Emit this high-level representation into a buffer.
+    ///
+    /// # Panics
+    /// This function panics if invoked on an unspecified representation.
+    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, buffer: T, payload_len: usize) {
+        match self {
+            &IpRepr::Unspecified { .. } => panic!("unspecified IP representation"),
+            &IpRepr::Ipv4(repr) => {
+                let mut packet = Ipv4Packet::new(buffer).expect("undersized buffer");
+                repr.emit(&mut packet, payload_len)
+            }
+            &IpRepr::__Nonexhaustive => unreachable!()
+        }
+    }
+}
+
 pub mod checksum {
     use byteorder::{ByteOrder, NetworkEndian};
 

+ 1 - 0
src/wire/mod.rs

@@ -102,6 +102,7 @@ pub use self::arp::Repr as ArpRepr;
 pub use self::ip::Protocol as IpProtocol;
 pub use self::ip::Address as IpAddress;
 pub use self::ip::Endpoint as IpEndpoint;
+pub use self::ip::IpRepr as IpRepr;
 
 pub use self::ipv4::Address as Ipv4Address;
 pub use self::ipv4::Packet as Ipv4Packet;