浏览代码

Rethink the buffering strategy with Managed<T>.

whitequark 8 年之前
父节点
当前提交
59dae01e9c
共有 7 个文件被更改,包括 144 次插入74 次删除
  1. 4 3
      examples/smoltcpserver.rs
  2. 16 23
      src/iface/arp_cache.rs
  3. 6 1
      src/iface/ethernet.rs
  4. 5 1
      src/lib.rs
  5. 80 0
      src/managed.rs
  6. 2 1
      src/socket/mod.rs
  7. 31 45
      src/socket/udp.rs

+ 4 - 3
examples/smoltcpserver.rs

@@ -20,13 +20,14 @@ fn main() {
     let listen_address = InternetAddress::ipv4([0, 0, 0, 0]);
     let endpoint = InternetEndpoint::new(listen_address, 6969);
 
-    let udp_rx_buffer = UdpBuffer::new([UdpBufferElem::new(vec![0; 2048])]);
-    let udp_tx_buffer = UdpBuffer::new([UdpBufferElem::new(vec![0; 2048])]);
+    let udp_rx_buffer = UdpBuffer::new(vec![UdpBufferElem::new(vec![0; 2048])]);
+    let udp_tx_buffer = UdpBuffer::new(vec![UdpBufferElem::new(vec![0; 2048])]);
     let mut udp_socket = UdpSocket::new(endpoint, udp_rx_buffer, udp_tx_buffer);
-    let mut sockets: [&mut Socket; 1] = [&mut udp_socket];
 
+    let mut sockets: [&mut Socket; 1] = [&mut udp_socket];
     let mut iface = EthernetInterface::new(device, arp_cache,
         hardware_addr, &mut protocol_addrs[..], &mut sockets[..]);
+
     loop {
         match iface.poll() {
             Ok(()) => (),

+ 16 - 23
src/iface/arp_cache.rs

@@ -1,5 +1,4 @@
-use core::borrow::BorrowMut;
-
+use Managed;
 use wire::{EthernetAddress, InternetAddress};
 
 /// An Address Resolution Protocol cache.
@@ -33,26 +32,24 @@ pub trait Cache {
 /// let mut arp_cache = SliceArpCache::new(&mut arp_cache_storage[..]);
 /// ```
 
-pub struct SliceCache<
-    StorageT: BorrowMut<[(InternetAddress, EthernetAddress, usize)]>
-> {
-    storage: StorageT,
+pub struct SliceCache<'a> {
+    storage: Managed<'a, [(InternetAddress, EthernetAddress, usize)]>,
     counter: usize
 }
 
-impl<
-    StorageT: BorrowMut<[(InternetAddress, EthernetAddress, usize)]>
-> SliceCache<StorageT> {
+impl<'a> SliceCache<'a> {
     /// Create a cache. The backing storage is cleared upon creation.
     ///
     /// # Panics
     /// This function panics if `storage.len() == 0`.
-    pub fn new(mut storage: StorageT) -> SliceCache<StorageT> {
-        if storage.borrow().len() == 0 {
+    pub fn new<T>(storage: T) -> SliceCache<'a>
+            where T: Into<Managed<'a, [(InternetAddress, EthernetAddress, usize)]>> {
+        let mut storage = storage.into();
+        if storage.len() == 0 {
             panic!("ARP slice cache created with empty storage")
         }
 
-        for elem in storage.borrow_mut().iter_mut() {
+        for elem in storage.iter_mut() {
             *elem = Default::default()
         }
         SliceCache {
@@ -65,30 +62,25 @@ impl<
     fn find(&self, protocol_addr: InternetAddress) -> Option<usize> {
         // The order of comparison is important: any valid InternetAddress should
         // sort before InternetAddress::Invalid.
-        let storage = self.storage.borrow();
-        storage.binary_search_by_key(&protocol_addr, |&(key, _, _)| key).ok()
+        self.storage.binary_search_by_key(&protocol_addr, |&(key, _, _)| key).ok()
     }
 
     /// Sort entries in an order suitable for `find`.
     fn sort(&mut self) {
-        let mut storage = self.storage.borrow_mut();
-        storage.sort_by_key(|&(key, _, _)| key)
+        self.storage.sort_by_key(|&(key, _, _)| key)
     }
 
     /// Find the least recently used entry.
     fn lru(&self) -> usize {
-        let storage = self.storage.borrow();
-        storage.iter().enumerate().min_by_key(|&(_, &(_, _, counter))| counter).unwrap().0
+        self.storage.iter().enumerate().min_by_key(|&(_, &(_, _, counter))| counter).unwrap().0
     }
 }
 
-impl<
-    StorageT: BorrowMut<[(InternetAddress, EthernetAddress, usize)]>
-> Cache for SliceCache<StorageT> {
+impl<'a> Cache for SliceCache<'a> {
     fn fill(&mut self, protocol_addr: InternetAddress, hardware_addr: EthernetAddress) {
         if let None = self.find(protocol_addr) {
             let lru_index = self.lru();
-            self.storage.borrow_mut()[lru_index] =
+            self.storage[lru_index] =
                 (protocol_addr, hardware_addr, self.counter);
             self.sort()
         }
@@ -97,7 +89,7 @@ impl<
     fn lookup(&mut self, protocol_addr: InternetAddress) -> Option<EthernetAddress> {
         if let Some(index) = self.find(protocol_addr) {
             let (_protocol_addr, hardware_addr, ref mut counter) =
-                self.storage.borrow_mut()[index];
+                self.storage[index];
             self.counter += 1;
             *counter = self.counter;
             Some(hardware_addr)
@@ -146,3 +138,4 @@ mod test {
         assert_eq!(cache.lookup(PADDR_D), Some(HADDR_D));
     }
 }
+

+ 6 - 1
src/iface/ethernet.rs

@@ -99,12 +99,17 @@ impl<'a,
         Self::check_protocol_addrs(self.protocol_addrs.borrow())
     }
 
-    /// Checks whether the interface has the given protocol address assigned.
+    /// Check whether the interface has the given protocol address assigned.
     pub fn has_protocol_addr<T: Into<InternetAddress>>(&self, addr: T) -> bool {
         let addr = addr.into();
         self.protocol_addrs.borrow().iter().any(|&probe| probe == addr)
     }
 
+    /// Get the set of sockets owned by the interface.
+    pub fn with_sockets<R, F: FnOnce(&mut [&'a mut Socket]) -> R>(&mut self, f: F) -> R {
+        f(self.sockets.borrow_mut())
+    }
+
     /// Receive and process a packet, if available.
     pub fn poll(&mut self) -> Result<(), Error> {
         enum Response<'a> {

+ 5 - 1
src/lib.rs

@@ -1,4 +1,4 @@
-#![feature(associated_consts, const_fn, step_by)]
+#![feature(associated_consts, const_fn, step_by, intrinsics)]
 #![no_std]
 
 extern crate byteorder;
@@ -11,11 +11,15 @@ extern crate libc;
 
 use core::fmt;
 
+mod managed;
+
 pub mod phy;
 pub mod wire;
 pub mod iface;
 pub mod socket;
 
+pub use managed::Managed;
+
 /// The error type for the networking stack.
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 pub enum Error {

+ 80 - 0
src/managed.rs

@@ -0,0 +1,80 @@
+use core::ops::{Deref, DerefMut};
+use core::borrow::BorrowMut;
+use core::fmt;
+
+#[cfg(feature = "std")]
+use std::boxed::Box;
+#[cfg(feature = "std")]
+use std::vec::Vec;
+
+/// A managed object.
+///
+/// This enum can be used to represent exclusive access to objects. In Rust, exclusive access
+/// to an object is obtained by either owning the object, or owning a mutable pointer
+/// to the object; hence, "managed".
+///
+/// The purpose of this enum is providing good ergonomics with `std` present while making
+/// it possible to avoid having a heap at all (which of course means that `std` is not present).
+/// To achieve this, the `Managed::Owned` variant is only available when the "std" feature
+/// is enabled.
+///
+/// A function that requires a managed object should be generic over an `Into<Managed<'a, T>>`
+/// argument; then, it will be possible to pass either a `Box<T>`, `Vec<T>`, or a `&'a mut T`
+/// without any conversion at the call site.
+pub enum Managed<'a, T: 'a + ?Sized> {
+    /// Borrowed variant, either a single element or a slice.
+    Borrowed(&'a mut T),
+    /// Owned variant, only available with `std` present.
+    #[cfg(feature = "std")]
+    Owned(Box<BorrowMut<T>>)
+}
+
+impl<'a, T: 'a + fmt::Debug + ?Sized> fmt::Debug for Managed<'a, T> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "Managed::from({:?})", self.deref())
+    }
+}
+
+impl<'a, T: 'a + ?Sized> From<&'a mut T> for Managed<'a, T> {
+    fn from(value: &'a mut T) -> Self {
+        Managed::Borrowed(value)
+    }
+}
+
+#[cfg(feature = "std")]
+impl<T, U: BorrowMut<T> + 'static> From<Box<U>> for Managed<'static, T> {
+    fn from(value: Box<U>) -> Self {
+        Managed::Owned(value)
+    }
+}
+
+#[cfg(feature = "std")]
+impl<T: 'static> From<Vec<T>> for Managed<'static, [T]> {
+    fn from(mut value: Vec<T>) -> Self {
+        value.shrink_to_fit();
+        Managed::Owned(Box::new(value))
+    }
+}
+
+impl<'a, T: 'a + ?Sized> Deref for Managed<'a, T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        match self {
+            &Managed::Borrowed(ref value) => value,
+            #[cfg(feature = "std")]
+            &Managed::Owned(ref value) => (**value).borrow()
+        }
+    }
+}
+
+impl<'a, T: 'a + ?Sized> DerefMut for Managed<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        match self {
+            &mut Managed::Borrowed(ref mut value) => value,
+            #[cfg(feature = "std")]
+            &mut Managed::Owned(ref mut value) => (**value).borrow_mut()
+        }
+    }
+}
+

+ 2 - 1
src/socket/mod.rs

@@ -35,7 +35,8 @@ pub trait PacketRepr {
 ///
 /// This interface abstracts the various types of sockets based on the IP protocol.
 /// It is necessarily implemented as a trait, and not as an enumeration, to allow using different
-/// buffering strategies in sockets assigned to the same interface.
+/// buffer types in sockets assigned to the same interface. To access a socket through this
+/// interface, cast it using `.as::<T>()`.
 ///
 /// The `collect` and `dispatch` functions are fundamentally asymmetric and thus differ in
 /// their use of the [trait PacketRepr](trait.PacketRepr.html). When `collect` is called,

+ 31 - 45
src/socket/udp.rs

@@ -1,50 +1,46 @@
-use core::marker::PhantomData;
 use core::borrow::BorrowMut;
 
 use Error;
+use Managed;
 use wire::{InternetAddress as Address, InternetProtocolType as ProtocolType};
 use wire::{InternetEndpoint as Endpoint};
 use wire::{UdpPacket, UdpRepr};
 use socket::{Socket, PacketRepr};
 
 /// A buffered UDP packet.
-#[derive(Debug, Default)]
-pub struct BufferElem<T: BorrowMut<[u8]>> {
+#[derive(Debug)]
+pub struct BufferElem<'a> {
     endpoint: Endpoint,
     size:     usize,
-    payload:  T
+    payload:  Managed<'a, [u8]>
 }
 
-impl<T: BorrowMut<[u8]>> BufferElem<T> {
+impl<'a> BufferElem<'a> {
     /// Create a buffered packet.
-    pub fn new(payload: T) -> BufferElem<T> {
+    pub fn new<T>(payload: T) -> BufferElem<'a>
+            where T: Into<Managed<'a, [u8]>> {
         BufferElem {
             endpoint: Endpoint::INVALID,
             size:     0,
-            payload:  payload
+            payload:  payload.into()
         }
     }
 }
 
 /// An UDP packet buffer.
 #[derive(Debug)]
-pub struct Buffer<
-    T: BorrowMut<[u8]>,
-    U: BorrowMut<[BufferElem<T>]>
-> {
-    storage: U,
+pub struct Buffer<'a> {
+    storage: Managed<'a, [BufferElem<'a>]>,
     read_at: usize,
-    length:  usize,
-    phantom: PhantomData<T>
+    length:  usize
 }
 
-impl<
-    T: BorrowMut<[u8]>,
-    U: BorrowMut<[BufferElem<T>]>
-> Buffer<T, U> {
+impl<'a> Buffer<'a> {
     /// Create a packet buffer with the given storage.
-    pub fn new(mut storage: U) -> Buffer<T, U> {
-        for elem in storage.borrow_mut() {
+    pub fn new<T>(storage: T) -> Buffer<'a>
+            where T: Into<Managed<'a, [BufferElem<'a>]>> {
+        let mut storage = storage.into();
+        for elem in storage.iter_mut() {
             elem.endpoint = Default::default();
             elem.size = 0;
         }
@@ -52,13 +48,12 @@ impl<
         Buffer {
             storage: storage,
             read_at: 0,
-            length:  0,
-            phantom: PhantomData
+            length:  0
         }
     }
 
     fn mask(&self, index: usize) -> usize {
-        index % self.storage.borrow().len()
+        index % self.storage.len()
     }
 
     fn incr(&self, index: usize) -> usize {
@@ -70,12 +65,12 @@ impl<
     }
 
     fn full(&self) -> bool {
-        self.length == self.storage.borrow().len()
+        self.length == self.storage.len()
     }
 
     /// Enqueue an element into the buffer, and return a pointer to it, or return
     /// `Err(Error::Exhausted)` if the buffer is full.
-    pub fn enqueue(&mut self) -> Result<&mut BufferElem<T>, Error> {
+    pub fn enqueue(&mut self) -> Result<&mut BufferElem<'a>, Error> {
         if self.full() {
             Err(Error::Exhausted)
         } else {
@@ -88,12 +83,12 @@ impl<
 
     /// Dequeue an element from the buffer, and return a pointer to it, or return
     /// `Err(Error::Exhausted)` if the buffer is empty.
-    pub fn dequeue(&mut self) -> Result<&BufferElem<T>, Error> {
+    pub fn dequeue(&mut self) -> Result<&BufferElem<'a>, Error> {
         if self.empty() {
             Err(Error::Exhausted)
         } else {
             self.length -= 1;
-            let result = &self.storage.borrow()[self.read_at];
+            let result = &self.storage[self.read_at];
             self.read_at = self.incr(self.read_at);
             Ok(result)
         }
@@ -104,22 +99,16 @@ impl<
 ///
 /// An UDP socket is bound to a specific endpoint, and owns transmit and receive
 /// packet buffers.
-pub struct UdpSocket<
-    T: BorrowMut<[u8]>,
-    U: BorrowMut<[BufferElem<T>]>
-> {
+pub struct UdpSocket<'a> {
     endpoint:  Endpoint,
-    rx_buffer: Buffer<T, U>,
-    tx_buffer: Buffer<T, U>
+    rx_buffer: Buffer<'a>,
+    tx_buffer: Buffer<'a>
 }
 
-impl<
-    T: BorrowMut<[u8]>,
-    U: BorrowMut<[BufferElem<T>]>
-> UdpSocket<T, U> {
+impl<'a> UdpSocket<'a> {
     /// Create an UDP socket with the given buffers.
-    pub fn new(endpoint: Endpoint, rx_buffer: Buffer<T, U>, tx_buffer: Buffer<T, U>)
-            -> UdpSocket<T, U> {
+    pub fn new(endpoint: Endpoint, rx_buffer: Buffer<'a>, tx_buffer: Buffer<'a>)
+            -> UdpSocket<'a> {
         UdpSocket {
             endpoint:  endpoint,
             rx_buffer: rx_buffer,
@@ -145,14 +134,11 @@ impl<
     /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
     pub fn recv<R, F>(&mut self) -> Result<(Endpoint, &[u8]), Error> {
         let packet_buf = try!(self.rx_buffer.dequeue());
-        Ok((packet_buf.endpoint, &packet_buf.payload.borrow()[..packet_buf.size]))
+        Ok((packet_buf.endpoint, &packet_buf.payload[..packet_buf.size]))
     }
 }
 
-impl<
-    T: BorrowMut<[u8]>,
-    U: BorrowMut<[BufferElem<T>]>
-> Socket for UdpSocket<T, U> {
+impl<'a> Socket for UdpSocket<'a> {
     fn collect(&mut self, src_addr: &Address, dst_addr: &Address,
                protocol: ProtocolType, payload: &[u8])
             -> Result<(), Error> {
@@ -183,7 +169,7 @@ impl<
           &UdpRepr {
             src_port: self.endpoint.port,
             dst_port: packet_buf.endpoint.port,
-            payload:  packet_buf.payload.borrow()
+            payload:  &packet_buf.payload[..]
           })
     }
 }