소스 검색

Rework and test UDP sockets.

Before, errors such as packets not fitting into a buffer would have
resulted in panics, and errors such as unbound sockets were
simply ignored.
whitequark 7 년 전
부모
커밋
492fe3e4b1
5개의 변경된 파일335개의 추가작업 그리고 68개의 파일을 삭제
  1. 1 1
      examples/server.rs
  2. 1 1
      src/socket/tcp.rs
  3. 225 29
      src/socket/udp.rs
  4. 100 37
      src/storage/ring_buffer.rs
  5. 8 0
      src/wire/ip.rs

+ 1 - 1
examples/server.rs

@@ -64,7 +64,7 @@ fn main() {
         {
             let socket: &mut UdpSocket = sockets.get_mut(udp_handle).as_socket();
             if !socket.endpoint().is_specified() {
-                socket.bind(6969)
+                socket.bind(6969).unwrap()
             }
 
             let client = match socket.recv() {

+ 1 - 1
src/socket/tcp.rs

@@ -353,9 +353,9 @@ impl<'a> TcpSocket<'a> {
     pub fn listen<T>(&mut self, local_endpoint: T) -> Result<()>
             where T: Into<IpEndpoint> {
         let local_endpoint = local_endpoint.into();
+        if local_endpoint.port == 0 { return Err(Error::Unaddressable) }
 
         if self.is_open() { return Err(Error::Illegal) }
-        if local_endpoint.port == 0 { return Err(Error::Unaddressable) }
 
         self.reset();
         self.listen_address  = local_endpoint.addr;

+ 225 - 29
src/socket/udp.rs

@@ -1,3 +1,4 @@
+use core::cmp::min;
 use managed::Managed;
 
 use {Error, Result};
@@ -15,13 +16,6 @@ pub struct PacketBuffer<'a> {
     payload:  Managed<'a, [u8]>
 }
 
-impl<'a> Resettable for PacketBuffer<'a> {
-    fn reset(&mut self) {
-        self.endpoint = Default::default();
-        self.size = 0;
-    }
-}
-
 impl<'a> PacketBuffer<'a> {
     /// Create a buffered packet.
     pub fn new<T>(payload: T) -> PacketBuffer<'a>
@@ -40,6 +34,22 @@ impl<'a> PacketBuffer<'a> {
     fn as_mut<'b>(&'b mut self) -> &'b mut [u8] {
         &mut self.payload[..self.size]
     }
+
+    fn resize<'b>(&'b mut self, size: usize) -> Result<&'b mut Self> {
+        if self.payload.len() >= size {
+            self.size = size;
+            Ok(self)
+        } else {
+            Err(Error::Truncated)
+        }
+    }
+}
+
+impl<'a> Resettable for PacketBuffer<'a> {
+    fn reset(&mut self) {
+        self.endpoint = Default::default();
+        self.size = 0;
+    }
 }
 
 /// An UDP packet ring buffer.
@@ -90,8 +100,17 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 
     /// Bind the socket to the given endpoint.
-    pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) {
-        self.endpoint = endpoint.into()
+    ///
+    /// Returns `Err(Error::Illegal)` if the socket is already bound,
+    /// and `Err(Error::Unaddressable)` if the port is unspecified.
+    pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<()> {
+        let endpoint = endpoint.into();
+        if endpoint.port == 0 { return Err(Error::Unaddressable) }
+
+        if self.endpoint.port != 0 { return Err(Error::Illegal) }
+
+        self.endpoint = endpoint;
+        Ok(())
     }
 
     /// Check whether the transmit buffer is full.
@@ -109,15 +128,18 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer
     /// to its payload.
     ///
-    /// This function returns `Err(Error::Exhausted)` if the size is greater than
-    /// the transmit packet buffer size.
+    /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
+    /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer
+    /// size, and `Err(Error::Unaddressable)` if local or remote port, or remote address,
+    /// are unspecified.
     pub fn send(&mut self, size: usize, endpoint: IpEndpoint) -> Result<&mut [u8]> {
-        let packet_buf = self.tx_buffer.enqueue()?;
+        if self.endpoint.port == 0 { return Err(Error::Unaddressable) }
+        if !endpoint.is_specified() { return Err(Error::Unaddressable) }
+
+        let packet_buf = self.tx_buffer.try_enqueue(|buf| buf.resize(size))?;
         packet_buf.endpoint = endpoint;
-        packet_buf.size = size;
         net_trace!("[{}]{}:{}: buffer to send {} octets",
-                   self.debug_id, self.endpoint,
-                   packet_buf.endpoint, packet_buf.size);
+                   self.debug_id, self.endpoint, packet_buf.endpoint, size);
         Ok(&mut packet_buf.as_mut()[..size])
     }
 
@@ -125,9 +147,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     ///
     /// See also [send](#method.send).
     pub fn send_slice(&mut self, data: &[u8], endpoint: IpEndpoint) -> Result<usize> {
-        let buffer = self.send(data.len(), endpoint)?;
-        let data = &data[..buffer.len()];
-        buffer.copy_from_slice(data);
+        self.send(data.len(), endpoint)?.copy_from_slice(data);
         Ok(data.len())
     }
 
@@ -140,7 +160,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         net_trace!("[{}]{}:{}: receive {} buffered octets",
                    self.debug_id, self.endpoint,
                    packet_buf.endpoint, packet_buf.size);
-        Ok((&packet_buf.as_ref()[..packet_buf.size], packet_buf.endpoint))
+        Ok((&packet_buf.as_ref(), packet_buf.endpoint))
     }
 
     /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
@@ -149,8 +169,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     /// See also [recv](#method.recv).
     pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {
         let (buffer, endpoint) = self.recv()?;
-        data[..buffer.len()].copy_from_slice(buffer);
-        Ok((buffer.len(), endpoint))
+        let length = min(data.len(), buffer.len());
+        data[..length].copy_from_slice(&buffer[..length]);
+        Ok((length, endpoint))
     }
 
     pub(crate) fn process(&mut self, _timestamp: u64, ip_repr: &IpRepr,
@@ -160,15 +181,12 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         let packet = UdpPacket::new_checked(&payload[..ip_repr.payload_len()])?;
         let repr = UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
 
-        if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) }
-        if !self.endpoint.addr.is_unspecified() {
-            if self.endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
-        }
+        let endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
+        if !self.endpoint.accepts(&endpoint) { return Err(Error::Rejected) }
 
-        let packet_buf = self.rx_buffer.enqueue()?;
-        packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
-        packet_buf.size = repr.payload.len();
-        packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload);
+        let packet_buf = self.rx_buffer.try_enqueue(|buf| buf.resize(repr.payload.len()))?;
+        packet_buf.as_mut().copy_from_slice(repr.payload);
+        packet_buf.endpoint = endpoint;
         net_trace!("[{}]{}:{}: receiving {} octets",
                    self.debug_id, self.endpoint,
                    packet_buf.endpoint, packet_buf.size);
@@ -182,6 +200,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         net_trace!("[{}]{}:{}: sending {} octets",
                    self.debug_id, self.endpoint,
                    packet_buf.endpoint, packet_buf.size);
+
         let repr = UdpRepr {
             src_port: self.endpoint.port,
             dst_port: packet_buf.endpoint.port,
@@ -207,3 +226,180 @@ impl<'a> IpPayload for UdpRepr<'a> {
         self.emit(&mut packet, &repr.src_addr(), &repr.dst_addr())
     }
 }
+
+#[cfg(test)]
+mod test {
+    use std::vec::Vec;
+    use wire::{IpAddress, Ipv4Address, IpRepr, Ipv4Repr, UdpRepr};
+    use socket::AsSocket;
+    use super::*;
+
+    fn buffer(packets: usize) -> SocketBuffer<'static, 'static> {
+        let mut storage = vec![];
+        for _ in 0..packets {
+            storage.push(PacketBuffer::new(vec![0; 16]))
+        }
+        SocketBuffer::new(storage)
+    }
+
+    fn socket(rx_buffer: SocketBuffer<'static, 'static>,
+              tx_buffer: SocketBuffer<'static, 'static>)
+            -> UdpSocket<'static, 'static> {
+        match UdpSocket::new(rx_buffer, tx_buffer) {
+            Socket::Udp(socket) => socket,
+            _ => unreachable!()
+        }
+    }
+
+    const LOCAL_IP:    IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
+    const REMOTE_IP:   IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
+    const LOCAL_PORT:  u16        = 53;
+    const REMOTE_PORT: u16        = 49500;
+    const LOCAL_END:   IpEndpoint = IpEndpoint { addr: LOCAL_IP,  port: LOCAL_PORT  };
+    const REMOTE_END:  IpEndpoint = IpEndpoint { addr: REMOTE_IP, port: REMOTE_PORT };
+
+    #[test]
+    fn test_bind_unaddressable() {
+        let mut socket = socket(buffer(0), buffer(0));
+        assert_eq!(socket.bind(0), Err(Error::Unaddressable));
+    }
+
+    #[test]
+    fn test_bind_twice() {
+        let mut socket = socket(buffer(0), buffer(0));
+        assert_eq!(socket.bind(1), Ok(()));
+        assert_eq!(socket.bind(2), Err(Error::Illegal));
+    }
+
+    const LOCAL_IP_REPR: IpRepr = IpRepr::Unspecified {
+        src_addr: LOCAL_IP,
+        dst_addr: REMOTE_IP,
+        protocol: IpProtocol::Udp,
+        payload_len: 8 + 6
+    };
+    const LOCAL_UDP_REPR: UdpRepr = UdpRepr {
+        src_port: LOCAL_PORT,
+        dst_port: REMOTE_PORT,
+        payload: b"abcdef"
+    };
+
+    #[test]
+    fn test_send_unaddressable() {
+        let mut socket = socket(buffer(0), buffer(1));
+        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Err(Error::Unaddressable));
+        socket.bind(LOCAL_PORT);
+        assert_eq!(socket.send_slice(b"abcdef",
+                                     IpEndpoint { addr: IpAddress::Unspecified, ..REMOTE_END }),
+                   Err(Error::Unaddressable));
+        assert_eq!(socket.send_slice(b"abcdef",
+                                     IpEndpoint { port: 0, ..REMOTE_END }),
+                   Err(Error::Unaddressable));
+        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(6));
+    }
+
+    #[test]
+    fn test_send_truncated() {
+        let mut socket = socket(buffer(0), buffer(1));
+        socket.bind(LOCAL_END);
+        assert_eq!(socket.send_slice(&[0; 32][..], REMOTE_END), Err(Error::Truncated));
+    }
+
+    #[test]
+    fn test_send_dispatch() {
+        let limits = DeviceLimits::default();
+
+        let mut socket = socket(buffer(0), buffer(1));
+        socket.bind(LOCAL_END);
+
+        assert!(socket.can_send());
+        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+            unreachable!()
+        }), Err(Error::Exhausted) as Result<()>);
+
+        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(6));
+        assert_eq!(socket.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted));
+        assert!(!socket.can_send());
+
+        macro_rules! assert_payload_eq {
+            ($ip_repr:expr, $ip_payload:expr, $expected:expr) => {{
+                let mut buffer = vec![0; $ip_payload.buffer_len()];
+                $ip_payload.emit($ip_repr, &mut buffer);
+                let udp_packet = UdpPacket::new_checked(&buffer).unwrap();
+                let udp_repr = UdpRepr::parse(&udp_packet, &LOCAL_IP, &REMOTE_IP).unwrap();
+                assert_eq!(&udp_repr, $expected)
+            }}
+        }
+
+        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+            assert_eq!(ip_repr, &LOCAL_IP_REPR);
+            assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
+            Err(Error::Unaddressable)
+        }), Err(Error::Unaddressable) as Result<()>);
+        /*assert!(!socket.can_send());*/
+
+        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+            assert_eq!(ip_repr, &LOCAL_IP_REPR);
+            assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
+            Ok(())
+        }), /*Ok(())*/ Err(Error::Exhausted));
+        assert!(socket.can_send());
+    }
+
+    const REMOTE_IP_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
+        src_addr: Ipv4Address([10, 0, 0, 2]),
+        dst_addr: Ipv4Address([10, 0, 0, 1]),
+        protocol: IpProtocol::Udp,
+        payload_len: 8 + 6
+    });
+    const REMOTE_UDP_REPR: UdpRepr = UdpRepr {
+        src_port: REMOTE_PORT,
+        dst_port: LOCAL_PORT,
+        payload: b"abcdef"
+    };
+
+    #[test]
+    fn test_recv_process() {
+        let mut socket = socket(buffer(1), buffer(0));
+        socket.bind(LOCAL_PORT);
+        assert!(!socket.can_recv());
+
+        let mut buffer = vec![0; REMOTE_UDP_REPR.buffer_len()];
+        REMOTE_UDP_REPR.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
+
+        assert_eq!(socket.recv(), Err(Error::Exhausted));
+        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
+                   Ok(()));
+        assert!(socket.can_recv());
+
+        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
+                   Err(Error::Exhausted));
+        assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END)));
+        assert!(!socket.can_recv());
+    }
+
+    #[test]
+    fn test_recv_truncated_slice() {
+        let mut socket = socket(buffer(1), buffer(0));
+        socket.bind(LOCAL_PORT);
+
+        let mut buffer = vec![0; REMOTE_UDP_REPR.buffer_len()];
+        REMOTE_UDP_REPR.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
+        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer), Ok(()));
+
+        let mut slice = [0; 4];
+        assert_eq!(socket.recv_slice(&mut slice[..]), Ok((4, REMOTE_END)));
+        assert_eq!(&slice, b"abcd");
+    }
+
+    #[test]
+    fn test_recv_truncated_packet() {
+        let mut socket = socket(buffer(1), buffer(0));
+        socket.bind(LOCAL_PORT);
+
+        let udp_repr = UdpRepr { payload: &[0; 100][..], ..REMOTE_UDP_REPR };
+        let mut buffer = vec![0; udp_repr.buffer_len()];
+        udp_repr.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
+        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
+                   Err(Error::Truncated));
+    }
+}

+ 100 - 37
src/storage/ring_buffer.rs

@@ -50,28 +50,55 @@ impl<'a, T: 'a> RingBuffer<'a, T> {
 
     /// Enqueue an element into the buffer, and return a pointer to it, or return
     /// `Err(Error::Exhausted)` if the buffer is full.
-    pub fn enqueue(&mut self) -> Result<&mut T> {
-        if self.full() {
-            Err(Error::Exhausted)
-        } else {
-            let index = self.mask(self.read_at + self.length);
-            let result = &mut self.storage[index];
-            self.length += 1;
-            Ok(result)
+    pub fn enqueue<'b>(&'b mut self) -> Result<&'b mut T> {
+        if self.full() { return Err(Error::Exhausted) }
+
+        let index = self.mask(self.read_at + self.length);
+        self.length += 1;
+        Ok(&mut self.storage[index])
+    }
+
+    /// Call `f` with a buffer element, and enqueue the element if `f` returns successfully, or
+    /// return `Err(Error::Exhausted)` if the buffer is full.
+    pub fn try_enqueue<'b, R, F>(&'b mut self, f: F) -> Result<R>
+            where F: Fn(&'b mut T) -> Result<R> {
+        if self.full() { return Err(Error::Exhausted) }
+
+        let index = self.mask(self.read_at + self.length);
+        match f(&mut self.storage[index]) {
+            Ok(result) => {
+                self.length += 1;
+                Ok(result)
+            }
+            Err(error) => Err(error)
         }
     }
 
     /// Dequeue an element from the buffer, and return a mutable reference to it, or return
     /// `Err(Error::Exhausted)` if the buffer is empty.
     pub fn dequeue(&mut self) -> Result<&mut T> {
-        if self.empty() {
-            Err(Error::Exhausted)
-        } else {
-            self.length -= 1;
-            let read_at = self.read_at;
-            self.read_at = self.incr(self.read_at);
-            let result = &mut self.storage[read_at];
-            Ok(result)
+        if self.empty() { return Err(Error::Exhausted) }
+
+        let read_at = self.read_at;
+        self.length -= 1;
+        self.read_at = self.incr(self.read_at);
+        Ok(&mut self.storage[read_at])
+    }
+
+    /// Call `f` with a buffer element, and dequeue the element if `f` returns successfully, or
+    /// return `Err(Error::Exhausted)` if the buffer is empty.
+    pub fn try_dequeue<'b, R, F>(&'b mut self, f: F) -> Result<R>
+            where F: Fn(&'b mut T) -> Result<R> {
+        if self.empty() { return Err(Error::Exhausted) }
+
+        let next_at = self.incr(self.read_at);
+        match f(&mut self.storage[self.read_at]) {
+            Ok(result) => {
+                self.length -= 1;
+                self.read_at = next_at;
+                Ok(result)
+            }
+            Err(error) => Err(error)
         }
     }
 }
@@ -86,33 +113,69 @@ mod test {
         }
     }
 
-    #[test]
-    pub fn test_buffer() {
-        const TEST_BUFFER_SIZE: usize = 5;
+    const SIZE: usize = 5;
+
+    fn buffer() -> RingBuffer<'static, usize> {
         let mut storage = vec![];
-        for i in 0..TEST_BUFFER_SIZE {
+        for i in 0..SIZE {
             storage.push(i + 10);
         }
 
-        let mut ring_buffer = RingBuffer::new(&mut storage[..]);
-        assert!(ring_buffer.empty());
-        assert!(!ring_buffer.full());
-        assert_eq!(ring_buffer.dequeue(), Err(Error::Exhausted));
-        ring_buffer.enqueue().unwrap();
-        assert!(!ring_buffer.empty());
-        assert!(!ring_buffer.full());
-        for i in 1..TEST_BUFFER_SIZE {
-            *ring_buffer.enqueue().unwrap() = i;
-            assert!(!ring_buffer.empty());
+        RingBuffer::new(storage)
+    }
+
+    #[test]
+    pub fn test_buffer() {
+        let mut buf = buffer();
+        assert!(buf.empty());
+        assert!(!buf.full());
+        assert_eq!(buf.dequeue(), Err(Error::Exhausted));
+
+        buf.enqueue().unwrap();
+        assert!(!buf.empty());
+        assert!(!buf.full());
+
+        for i in 1..SIZE {
+            *buf.enqueue().unwrap() = i;
+            assert!(!buf.empty());
+        }
+        assert!(buf.full());
+        assert_eq!(buf.enqueue(), Err(Error::Exhausted));
+
+        for i in 0..SIZE {
+            assert_eq!(*buf.dequeue().unwrap(), i);
+            assert!(!buf.full());
+        }
+        assert_eq!(buf.dequeue(), Err(Error::Exhausted));
+        assert!(buf.empty());
+    }
+
+    #[test]
+    pub fn test_buffer_try() {
+        let mut buf = buffer();
+        assert!(buf.empty());
+        assert!(!buf.full());
+        assert_eq!(buf.try_dequeue(|_| unreachable!()) as Result<()>,
+                   Err(Error::Exhausted));
+
+        buf.try_enqueue(|e| Ok(e)).unwrap();
+        assert!(!buf.empty());
+        assert!(!buf.full());
+
+        for i in 1..SIZE {
+            buf.try_enqueue(|e| Ok(*e = i)).unwrap();
+            assert!(!buf.empty());
         }
-        assert!(ring_buffer.full());
-        assert_eq!(ring_buffer.enqueue(), Err(Error::Exhausted));
+        assert!(buf.full());
+        assert_eq!(buf.try_enqueue(|_| unreachable!()) as Result<()>,
+                   Err(Error::Exhausted));
 
-        for i in 0..TEST_BUFFER_SIZE {
-            assert_eq!(*ring_buffer.dequeue().unwrap(), i);
-            assert!(!ring_buffer.full());
+        for i in 0..SIZE {
+            assert_eq!(buf.try_dequeue(|e| Ok(*e)).unwrap(), i);
+            assert!(!buf.full());
         }
-        assert_eq!(ring_buffer.dequeue(), Err(Error::Exhausted));
-        assert!(ring_buffer.empty());
+        assert_eq!(buf.try_dequeue(|_| unreachable!()) as Result<()>,
+                   Err(Error::Exhausted));
+        assert!(buf.empty());
     }
 }

+ 8 - 0
src/wire/ip.rs

@@ -124,6 +124,14 @@ impl Endpoint {
     pub fn is_specified(&self) -> bool {
         !self.addr.is_unspecified() && self.port != 0
     }
+
+    /// Query whether `self` should accept packets from `other`.
+    ///
+    /// Returns `true` if `other` is a specified endpoint and `self` either
+    /// has an unspecified address, or the addresses are equal.
+    pub fn accepts(&self, other: &Endpoint) -> bool {
+        other.is_specified() && (self.addr == other.addr || self.addr.is_unspecified())
+    }
 }
 
 impl fmt::Display for Endpoint {