Kaynağa Gözat

Generalize the TCP tests to accept multiple packets.

whitequark 8 yıl önce
ebeveyn
işleme
0ad1ac0ef2
3 değiştirilmiş dosya ile 60 ekleme ve 48 silme
  1. 2 2
      src/socket/mod.rs
  2. 56 44
      src/socket/tcp.rs
  3. 2 2
      src/socket/udp.rs

+ 2 - 2
src/socket/mod.rs

@@ -69,8 +69,8 @@ impl<'a, 'b> Socket<'a, 'b> {
     /// is returned.
     ///
     /// This function is used internally by the networking stack.
-    pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> {
+    pub fn dispatch<F, R>(&mut self, emit: &mut F) -> Result<R, Error>
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
                 socket.dispatch(emit),

+ 56 - 44
src/socket/tcp.rs

@@ -368,8 +368,8 @@ impl<'a> TcpSocket<'a> {
     }
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
-    pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> {
+    pub fn dispatch<F, R>(&mut self, emit: &mut F) -> Result<R, Error>
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
         let mut repr = TcpRepr {
             src_port:   self.local_endpoint.port,
             dst_port:   self.remote_endpoint.port,
@@ -480,42 +480,54 @@ mod test {
         window_len: 128, payload: &[]
     };
 
+    fn send(socket: &mut TcpSocket, repr: &TcpRepr) -> Result<(), Error> {
+        let mut buffer = vec![0; repr.buffer_len()];
+        let mut packet = TcpPacket::new(&mut buffer).unwrap();
+        repr.emit(&mut packet, &REMOTE_IP, &LOCAL_IP);
+        let ip_repr = IpRepr::Unspecified {
+            src_addr: REMOTE_IP,
+            dst_addr: LOCAL_IP,
+            protocol: IpProtocol::Tcp
+        };
+        socket.collect(&ip_repr, &packet.into_inner()[..])
+    }
+
+    fn recv<F>(socket: &mut TcpSocket, mut f: F)
+            where F: FnMut(Result<TcpRepr, Error>) {
+        let mut buffer = vec![];
+        let result = socket.dispatch(&mut |ip_repr, payload| {
+            assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
+            assert_eq!(ip_repr.src_addr(), LOCAL_IP);
+            assert_eq!(ip_repr.dst_addr(), REMOTE_IP);
+
+            buffer.resize(payload.buffer_len(), 0);
+            payload.emit(&ip_repr, &mut buffer[..]);
+            let packet = TcpPacket::new(&buffer[..]).unwrap();
+            let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr()));
+            Ok(f(Ok(repr)))
+        });
+        // Appease borrow checker.
+        match result {
+            Ok(()) => (),
+            Err(e) => f(Err(e))
+        }
+    }
+
     macro_rules! send {
-        ($socket:ident, $repr:expr) => ({
-            let repr = $repr;
-            let mut buffer = vec![0; repr.buffer_len()];
-            let mut packet = TcpPacket::new(&mut buffer).unwrap();
-            repr.emit(&mut packet, &REMOTE_IP, &LOCAL_IP);
-            let ip_repr = IpRepr::Unspecified {
-                src_addr: REMOTE_IP,
-                dst_addr: LOCAL_IP,
-                protocol: IpProtocol::Tcp
-            };
-            let result = $socket.collect(&ip_repr, &packet.into_inner()[..]);
-            result.expect("send error")
-        })
+        ($socket:ident, [$( $repr:expr )*]) => ({
+            $( send!($socket, $repr, Ok(())); )*
+        });
+        ($socket:ident, $repr:expr, $result:expr) =>
+            (assert_eq!(send(&mut $socket, &$repr), $result))
     }
 
     macro_rules! recv {
-        ($socket:ident, $expected:expr) => ({
-            let result = $socket.dispatch(&mut |ip_repr, payload| {
-                assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
-                assert_eq!(ip_repr.src_addr(), LOCAL_IP);
-                assert_eq!(ip_repr.dst_addr(), REMOTE_IP);
-
-                let mut buffer = vec![0; payload.buffer_len()];
-                payload.emit(&ip_repr, &mut buffer[..]);
-                let packet = TcpPacket::new(&buffer[..]).unwrap();
-                let repr = TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr());
-                assert_eq!(repr, Ok($expected));
-                Ok(())
-            });
-            assert_eq!(result, Ok(()));
-            let result = $socket.dispatch(&mut |_repr, _payload| {
-                Ok(())
-            });
-            assert_eq!(result, Err(Error::Exhausted));
-        })
+        ($socket:ident, [$( $repr:expr )*]) => ({
+            $( recv!($socket, Ok($repr)); )*
+            recv!($socket, Err(Error::Exhausted))
+        });
+        ($socket:ident, $result:expr) =>
+            (recv(&mut $socket, |repr| assert_eq!(repr, $result)))
     }
 
     fn init_logger() {
@@ -559,26 +571,26 @@ mod test {
         s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT));
         assert_eq!(s.state(), State::Listen);
 
-        send!(s, TcpRepr {
+        send!(s, [TcpRepr {
             control: TcpControl::Syn,
             seq_number: REMOTE_SEQ,
             ack_number: None,
             ..SEND_TEMPL
-        });
+        }]);
         assert_eq!(s.state(), State::SynReceived);
         assert_eq!(s.local_endpoint(), LOCAL_END);
         assert_eq!(s.remote_endpoint(), REMOTE_END);
-        recv!(s, TcpRepr {
+        recv!(s, [TcpRepr {
             control: TcpControl::Syn,
             seq_number: LOCAL_SEQ,
             ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
-        });
-        send!(s, TcpRepr {
+        }]);
+        send!(s, [TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1),
             ..SEND_TEMPL
-        });
+        }]);
         assert_eq!(s.state(), State::Established);
         assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);
         assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1);
@@ -593,18 +605,18 @@ mod test {
         s.local_seq_no    = LOCAL_SEQ + 1;
         s.remote_seq_no   = REMOTE_SEQ + 1;
 
-        send!(s, TcpRepr {
+        send!(s, [TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1),
             payload: &b"abcdef"[..],
             ..SEND_TEMPL
-        });
-        recv!(s, TcpRepr {
+        }]);
+        recv!(s, [TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 6),
             window_len: 122,
             ..RECV_TEMPL
-        });
+        }]);
         assert_eq!(s.rx_buffer.dequeue(6), &b"abcdef"[..]);
     }
 }

+ 2 - 2
src/socket/udp.rs

@@ -189,8 +189,8 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 
     /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
-    pub fn dispatch<F>(&mut self, emit: &mut F) -> Result<(), Error>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> {
+    pub fn dispatch<F, R>(&mut self, emit: &mut F) -> Result<R, Error>
+            where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
         let packet_buf = try!(self.tx_buffer.dequeue());
         net_trace!("udp:{}:{}: dispatch {} octets",
                    self.endpoint, packet_buf.endpoint, packet_buf.size);