Browse Source

Add `RawSocket`.

Egor Karavaev 7 years ago
parent
commit
ed08b74427
6 changed files with 314 additions and 2 deletions
  1. 11 1
      src/iface/ethernet.rs
  2. 38 0
      src/socket/mod.rs
  3. 244 0
      src/socket/raw.rs
  4. 2 0
      src/socket/set.rs
  5. 18 1
      src/wire/ip.rs
  6. 1 0
      src/wire/mod.rs

+ 11 - 1
src/iface/ethernet.rs

@@ -8,7 +8,7 @@ use wire::{Ipv4Packet, Ipv4Repr};
 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
 use wire::{IpAddress, IpProtocol, IpRepr};
 use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::SocketSet;
+use socket::{Socket, SocketSet, RawSocket, AsSocket};
 use super::ArpCache;
 
 /// An Ethernet network interface.
@@ -179,6 +179,16 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
                                         &eth_frame.src_addr());
                 }
 
+                // Pass every IP packet to all raw sockets we have registered.
+                for raw_socket in sockets.iter_mut().filter_map(
+                        <Socket as AsSocket<RawSocket>>::try_as_socket) {
+                    match raw_socket.process(timestamp, &IpRepr::Ipv4(ipv4_repr),
+                                             ipv4_packet.payload()) {
+                        Ok(()) | Err(Error::Rejected) => (),
+                        _ => unreachable!(),
+                    }
+                }
+
                 match ipv4_repr {
                     // Ignore IP packets not directed at us.
                     Ipv4Repr { dst_addr, .. } if !self.has_protocol_addr(dst_addr) => (),

+ 38 - 0
src/socket/mod.rs

@@ -14,10 +14,15 @@ use Error;
 use phy::DeviceLimits;
 use wire::IpRepr;
 
+mod raw;
 mod udp;
 mod tcp;
 mod set;
 
+pub use self::raw::PacketBuffer as RawPacketBuffer;
+pub use self::raw::SocketBuffer as RawSocketBuffer;
+pub use self::raw::RawSocket;
+
 pub use self::udp::PacketBuffer as UdpPacketBuffer;
 pub use self::udp::SocketBuffer as UdpSocketBuffer;
 pub use self::udp::UdpSocket;
@@ -44,6 +49,7 @@ pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut};
 /// since the lower layers treat the packet as an opaque octet sequence.
 #[derive(Debug)]
 pub enum Socket<'a, 'b: 'a> {
+    Raw(RawSocket<'a, 'b>),
     Udp(UdpSocket<'a, 'b>),
     Tcp(TcpSocket<'a>),
     #[doc(hidden)]
@@ -53,6 +59,7 @@ pub enum Socket<'a, 'b: 'a> {
 macro_rules! dispatch_socket {
     ($self_:expr, |$socket:ident [$( $mut_:tt )*]| $code:expr) => ({
         match $self_ {
+            &$( $mut_ )* Socket::Raw(ref $( $mut_ )* $socket) => $code,
             &$( $mut_ )* Socket::Udp(ref $( $mut_ )* $socket) => $code,
             &$( $mut_ )* Socket::Tcp(ref $( $mut_ )* $socket) => $code,
             &$( $mut_ )* Socket::__Nonexhaustive => unreachable!()
@@ -118,6 +125,23 @@ pub trait IpPayload {
 /// concrete types.
 pub trait AsSocket<T> {
     fn as_socket(&mut self) -> &mut T;
+    fn try_as_socket(&mut self) -> Option<&mut T>;
+}
+
+impl<'a, 'b> AsSocket<RawSocket<'a, 'b>> for Socket<'a, 'b> {
+    fn as_socket(&mut self) -> &mut RawSocket<'a, 'b> {
+        match self {
+            &mut Socket::Raw(ref mut socket) => socket,
+            _ => panic!(".as_socket::<RawSocket> called on wrong socket type")
+        }
+    }
+
+    fn try_as_socket(&mut self) -> Option<&mut RawSocket<'a, 'b>> {
+        match self {
+            &mut Socket::Raw(ref mut socket) => Some(socket),
+            _ => None,
+        }
+    }
 }
 
 impl<'a, 'b> AsSocket<UdpSocket<'a, 'b>> for Socket<'a, 'b> {
@@ -127,6 +151,13 @@ impl<'a, 'b> AsSocket<UdpSocket<'a, 'b>> for Socket<'a, 'b> {
             _ => panic!(".as_socket::<UdpSocket> called on wrong socket type")
         }
     }
+
+    fn try_as_socket(&mut self) -> Option<&mut UdpSocket<'a, 'b>> {
+        match self {
+            &mut Socket::Udp(ref mut socket) => Some(socket),
+            _ => None,
+        }
+    }
 }
 
 impl<'a, 'b> AsSocket<TcpSocket<'a>> for Socket<'a, 'b> {
@@ -136,4 +167,11 @@ impl<'a, 'b> AsSocket<TcpSocket<'a>> for Socket<'a, 'b> {
             _ => panic!(".as_socket::<TcpSocket> called on wrong socket type")
         }
     }
+
+    fn try_as_socket(&mut self) -> Option<&mut TcpSocket<'a>> {
+        match self {
+            &mut Socket::Tcp(ref mut socket) => Some(socket),
+            _ => None,
+        }
+    }
 }

+ 244 - 0
src/socket/raw.rs

@@ -0,0 +1,244 @@
+use managed::Managed;
+
+use Error;
+use phy::DeviceLimits;
+use wire::{IpVersion, IpProtocol, Ipv4Repr, Ipv4Packet};
+use socket::{IpRepr, IpPayload, Socket};
+use storage::{Resettable, RingBuffer};
+
+/// A buffered raw IP packet.
+#[derive(Debug)]
+pub struct PacketBuffer<'a> {
+    size:    usize,
+    payload: Managed<'a, [u8]>,
+}
+
+impl<'a> PacketBuffer<'a> {
+    /// Create a buffered packet.
+    pub fn new<T>(payload: T) -> PacketBuffer<'a>
+            where T: Into<Managed<'a, [u8]>> {
+        PacketBuffer {
+            size:    0,
+            payload: payload.into(),
+        }
+    }
+
+    fn as_ref<'b>(&'b self) -> &'b [u8] {
+        &self.payload[..self.size]
+    }
+
+    fn as_mut<'b>(&'b mut self) -> &'b mut [u8] {
+        &mut self.payload[..self.size]
+    }
+}
+
+impl<'a> Resettable for PacketBuffer<'a> {
+    fn reset(&mut self) {
+        self.size = 0;
+    }
+}
+
+/// A raw IP packet ring buffer.
+pub type SocketBuffer<'a, 'b: 'a> = RingBuffer<'a, PacketBuffer<'b>>;
+
+/// A raw IP socket.
+///
+/// A raw socket is bound to a specific IP protocol, and owns
+/// transmit and receive packet buffers.
+#[derive(Debug)]
+pub struct RawSocket<'a, 'b: 'a> {
+    ip_version:  IpVersion,
+    ip_protocol: IpProtocol,
+    rx_buffer:   SocketBuffer<'a, 'b>,
+    tx_buffer:   SocketBuffer<'a, 'b>,
+    debug_id:    usize,
+}
+
+impl<'a, 'b> RawSocket<'a, 'b> {
+    /// Create a raw IP socket bound to the given IP version and datagram protocol,
+    /// with the given buffers.
+    pub fn new(ip_version: IpVersion, ip_protocol: IpProtocol,
+               rx_buffer: SocketBuffer<'a, 'b>,
+               tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> {
+        Socket::Raw(RawSocket {
+            ip_version,
+            ip_protocol,
+            rx_buffer,
+            tx_buffer,
+            debug_id: 0,
+        })
+    }
+
+    /// Return the debug identifier.
+    pub fn debug_id(&self) -> usize {
+        self.debug_id
+    }
+
+    /// Set the debug identifier.
+    ///
+    /// The debug identifier is a number printed in socket trace messages.
+    /// It could as well be used by the user code.
+    pub fn set_debug_id(&mut self, id: usize) {
+        self.debug_id = id;
+    }
+
+    /// Return the IP version the socket is bound to.
+    pub fn ip_version(&self) -> IpVersion {
+        self.ip_version
+    }
+
+    /// Return the IP protocol the socket is bound to.
+    pub fn ip_protocol(&self) -> IpProtocol {
+        self.ip_protocol
+    }
+
+    /// Check whether the transmit buffer is full.
+    pub fn can_send(&self) -> bool {
+        !self.tx_buffer.full()
+    }
+
+    /// Check whether the receive buffer is not empty.
+    pub fn can_recv(&self) -> bool {
+        !self.rx_buffer.empty()
+    }
+
+    /// Enqueue a packet to send, and return a pointer to its payload.
+    ///
+    /// This function returns `Err(())` if the size is greater than what
+    /// the transmit buffer can accomodate.
+    pub fn send(&mut self, size: usize) -> Result<&mut [u8], ()> {
+        let packet_buf = self.tx_buffer.enqueue()?;
+        packet_buf.size = size;
+        net_trace!("[{}]:{}:{}: buffer to send {} octets",
+                   self.debug_id, self.ip_version, self.ip_protocol,
+                   packet_buf.size);
+        Ok(&mut packet_buf.as_mut()[..size])
+    }
+
+    /// Enqueue a packet to send, and fill it from a slice.
+    ///
+    /// See also [send](#method.send).
+    pub fn send_slice(&mut self, data: &[u8]) -> Result<usize, ()> {
+        let buffer = self.send(data.len())?;
+        let data = &data[..buffer.len()];
+        buffer.copy_from_slice(data);
+        Ok(data.len())
+    }
+
+    /// Dequeue a packet, and return a pointer to the payload.
+    ///
+    /// This function returns `Err(())` if the receive buffer is empty.
+    pub fn recv(&mut self) -> Result<&[u8], ()> {
+        let packet_buf = self.rx_buffer.dequeue()?;
+        net_trace!("[{}]:{}:{}: receive {} buffered octets",
+                   self.debug_id, self.ip_version, self.ip_protocol,
+                   packet_buf.size);
+        Ok(&packet_buf.as_ref()[..packet_buf.size])
+    }
+
+    /// Dequeue a packet, and copy the payload into the given slice.
+    ///
+    /// See also [recv](#method.recv).
+    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, ()> {
+        let buffer = self.recv()?;
+        data[..buffer.len()].copy_from_slice(buffer);
+        Ok(buffer.len())
+    }
+
+    /// See [Socket::process](enum.Socket.html#method.process).
+    pub fn process(&mut self, _timestamp: u64, ip_repr: &IpRepr,
+                   payload: &[u8]) -> Result<(), Error> {
+        match self.ip_version {
+            IpVersion::Ipv4 => {
+                if ip_repr.protocol() != self.ip_protocol {
+                    return Err(Error::Rejected);
+                }
+                let header_len = ip_repr.buffer_len();
+                let packet_buf = self.rx_buffer.enqueue().map_err(|()| Error::Exhausted)?;
+                packet_buf.size = header_len + payload.len();
+                ip_repr.emit(&mut packet_buf.as_mut()[..header_len]);
+                packet_buf.as_mut()[header_len..header_len + payload.len()]
+                    .copy_from_slice(payload);
+                net_trace!("[{}]:{}:{}: receiving {} octets",
+                           self.debug_id, self.ip_version, self.ip_protocol,
+                           packet_buf.size);
+                Ok(())
+            }
+            IpVersion::__Nonexhaustive => unreachable!()
+        }
+    }
+
+    /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
+    pub fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
+                          emit: &mut F) -> Result<R, Error>
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
+        let mut packet_buf = self.tx_buffer.dequeue_mut().map_err(|()| Error::Exhausted)?;
+        net_trace!("[{}]:{}:{}: sending {} octets",
+                   self.debug_id, self.ip_version, self.ip_protocol,
+                   packet_buf.size);
+
+        match self.ip_version {
+            IpVersion::Ipv4 => {
+                let mut ipv4_packet = Ipv4Packet::new(packet_buf.as_mut())?;
+                ipv4_packet.fill_checksum();
+
+                let ipv4_packet = Ipv4Packet::new(&*ipv4_packet.into_inner())?;
+                let raw_repr = RawRepr(ipv4_packet.payload());
+                let ipv4_repr = Ipv4Repr::parse(&ipv4_packet)?;
+                emit(&IpRepr::Ipv4(ipv4_repr), &raw_repr)
+            }
+            IpVersion::__Nonexhaustive => unreachable!()
+        }
+    }
+}
+
+struct RawRepr<'a>(&'a [u8]);
+
+impl<'a> IpPayload for RawRepr<'a> {
+    fn buffer_len(&self) -> usize {
+        self.0.len()
+    }
+
+    fn emit(&self, _repr: &IpRepr, payload: &mut [u8]) {
+        payload.copy_from_slice(self.0);
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    pub fn test_buffer() {
+        let mut storage = vec![];
+        for _ in 0..5 {
+            storage.push(PacketBuffer::new(vec![0]))
+        }
+        let mut buffer = SocketBuffer::new(&mut storage[..]);
+
+        assert_eq!(buffer.empty(), true);
+        assert_eq!(buffer.full(), false);
+        buffer.enqueue().unwrap().size = 1;
+        assert_eq!(buffer.empty(), false);
+        assert_eq!(buffer.full(), false);
+        buffer.enqueue().unwrap().size = 2;
+        buffer.enqueue().unwrap().size = 3;
+        assert_eq!(buffer.dequeue().unwrap().size, 1);
+        assert_eq!(buffer.dequeue().unwrap().size, 2);
+        buffer.enqueue().unwrap().size = 4;
+        buffer.enqueue().unwrap().size = 5;
+        buffer.enqueue().unwrap().size = 6;
+        buffer.enqueue().unwrap().size = 7;
+        assert_eq!(buffer.enqueue().unwrap_err(), ());
+        assert_eq!(buffer.empty(), false);
+        assert_eq!(buffer.full(), true);
+        assert_eq!(buffer.dequeue().unwrap().size, 3);
+        assert_eq!(buffer.dequeue().unwrap().size, 4);
+        assert_eq!(buffer.dequeue().unwrap().size, 5);
+        assert_eq!(buffer.dequeue().unwrap().size, 6);
+        assert_eq!(buffer.dequeue().unwrap().size, 7);
+        assert_eq!(buffer.dequeue().unwrap_err(), ());
+        assert_eq!(buffer.empty(), true);
+        assert_eq!(buffer.full(), false);
+    }
+}

+ 2 - 0
src/socket/set.rs

@@ -139,6 +139,8 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
             let mut may_remove = false;
             if let &mut Some(Item { refs: 0, ref mut socket }) = item {
                 match socket {
+                    &mut Socket::Raw(_) =>
+                        may_remove = true,
                     &mut Socket::Udp(_) =>
                         may_remove = true,
                     &mut Socket::Tcp(ref mut socket) =>

+ 18 - 1
src/wire/ip.rs

@@ -3,8 +3,25 @@ use core::fmt;
 use Error;
 use super::{Ipv4Address, Ipv4Packet, Ipv4Repr};
 
+/// Internet protocol version.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Version {
+    Ipv4,
+    #[doc(hidden)]
+    __Nonexhaustive,
+}
+
+impl fmt::Display for Version {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            &Version::Ipv4 => write!(f, "IPv4"),
+            &Version::__Nonexhaustive => unreachable!()
+        }
+    }
+}
+
 enum_with_unknown! {
-    /// Internetworking protocol.
+    /// IP datagram encapsulated protocol.
     pub enum Protocol(u8) {
         Icmp = 0x01,
         Tcp  = 0x06,

+ 1 - 0
src/wire/mod.rs

@@ -126,6 +126,7 @@ pub use self::arp::Operation as ArpOperation;
 pub use self::arp::Packet as ArpPacket;
 pub use self::arp::Repr as ArpRepr;
 
+pub use self::ip::Version as IpVersion;
 pub use self::ip::Protocol as IpProtocol;
 pub use self::ip::Address as IpAddress;
 pub use self::ip::Endpoint as IpEndpoint;