Browse Source

Rework TcpSocket::{send,recv} to remove need for precomputing size.

Now, these functions give you the largest contiguous slice they can
grab, and you return however much you took from it.
whitequark 7 years ago
parent
commit
0091191cce
5 changed files with 121 additions and 78 deletions
  1. 4 4
      examples/client.rs
  2. 3 1
      examples/loopback.rs
  3. 12 10
      examples/server.rs
  4. 100 61
      src/socket/tcp.rs
  5. 2 2
      src/socket/udp.rs

+ 4 - 4
examples/client.rs

@@ -66,8 +66,8 @@ fn main() {
             tcp_active = socket.is_active();
 
             if socket.may_recv() {
-                let data = {
-                    let mut data = socket.recv(128).unwrap().to_owned();
+                let data = socket.recv(|data| {
+                    let mut data = data.to_owned();
                     if data.len() > 0 {
                         debug!("recv data: {:?}",
                                str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@@ -75,8 +75,8 @@ fn main() {
                         data.reverse();
                         data.extend(b"\n");
                     }
-                    data
-                };
+                    (data.len(), data)
+                }).unwrap();
                 if socket.can_send() && data.len() > 0 {
                     debug!("send data: {:?}",
                            str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));

+ 3 - 1
examples/loopback.rs

@@ -133,7 +133,9 @@ fn main() {
             }
 
             if socket.can_recv() {
-                debug!("got {:?}", str::from_utf8(socket.recv(32).unwrap()).unwrap());
+                debug!("got {:?}", socket.recv(|buffer| {
+                    (buffer.len(), str::from_utf8(buffer).unwrap())
+                }));
                 socket.close();
                 done = true;
             }

+ 12 - 10
examples/server.rs

@@ -121,8 +121,8 @@ fn main() {
             tcp_6970_active = socket.is_active();
 
             if socket.may_recv() {
-                let data = {
-                    let mut data = socket.recv(128).unwrap().to_owned();
+                let data = socket.recv(|buffer| {
+                    let mut data = buffer.to_owned();
                     if data.len() > 0 {
                         debug!("tcp:6970 recv data: {:?}",
                                str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@@ -130,8 +130,8 @@ fn main() {
                         data.reverse();
                         data.extend(b"\n");
                     }
-                    data
-                };
+                    (data.len(), data)
+                }).unwrap();
                 if socket.can_send() && data.len() > 0 {
                     debug!("tcp:6970 send data: {:?}",
                            str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@@ -153,11 +153,12 @@ fn main() {
             }
 
             if socket.may_recv() {
-                if let Ok(data) = socket.recv(65535) {
-                    if data.len() > 0 {
-                        debug!("tcp:6971 recv {:?} octets", data.len());
+                socket.recv(|buffer| {
+                    if buffer.len() > 0 {
+                        debug!("tcp:6971 recv {:?} octets", buffer.len());
                     }
-                }
+                    (buffer.len(), ())
+                }).unwrap();
             } else if socket.may_send() {
                 socket.close();
             }
@@ -171,14 +172,15 @@ fn main() {
             }
 
             if socket.may_send() {
-                if let Ok(data) = socket.send(65535) {
+                socket.send(|data| {
                     if data.len() > 0 {
                         debug!("tcp:6972 send {:?} octets", data.len());
                         for (i, b) in data.iter_mut().enumerate() {
                             *b = (i % 256) as u8;
                         }
                     }
-                }
+                    (data.len(), ())
+                }).unwrap();
             }
         }
 

+ 100 - 61
src/socket/tcp.rs

@@ -593,15 +593,8 @@ impl<'a> TcpSocket<'a> {
         !self.rx_buffer.is_empty()
     }
 
-    /// Enqueue a sequence of octets to be sent, and return a pointer to it.
-    ///
-    /// This function may return a slice smaller than the requested size in case
-    /// there is not enough contiguous free space in the transmit buffer, down to
-    /// an empty slice.
-    ///
-    /// This function returns `Err(Error::Illegal) if the transmit half of
-    /// the connection is not open; see [may_send](#method.may_send).
-    pub fn send(&mut self, size: usize) -> Result<&mut [u8]> {
+    fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
+            where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
         if !self.may_send() { return Err(Error::Illegal) }
 
         // The connection might have been idle for a long time, and so remote_last_ts
@@ -610,14 +603,26 @@ impl<'a> TcpSocket<'a> {
         if self.tx_buffer.is_empty() { self.remote_last_ts = None }
 
         let _old_length = self.tx_buffer.len();
-        let buffer = self.tx_buffer.enqueue_many(size);
-        if buffer.len() > 0 {
+        let (size, result) = f(&mut self.tx_buffer);
+        if size > 0 {
             #[cfg(any(test, feature = "verbose"))]
             net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})",
                        self.handle, self.local_endpoint, self.remote_endpoint,
-                       buffer.len(), _old_length + buffer.len());
+                       size, _old_length + size);
         }
-        Ok(buffer)
+        Ok(result)
+    }
+
+    /// Call `f` with the largest contiguous slice of octets in the transmit buffer,
+    /// and enqueue the amount of elements returned by `f`.
+    ///
+    /// This function returns `Err(Error::Illegal) if the transmit half of
+    /// the connection is not open; see [may_send](#method.may_send).
+    pub fn send<'b, F, R>(&'b mut self, f: F) -> Result<R>
+            where F: FnOnce(&'b mut [u8]) -> (usize, R) {
+        self.send_impl(|tx_buffer| {
+            tx_buffer.enqueue_many_with(f)
+        })
     }
 
     /// Enqueue a sequence of octets to be sent, and fill it from a slice.
@@ -627,46 +632,42 @@ impl<'a> TcpSocket<'a> {
     ///
     /// See also [send](#method.send).
     pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
-        if !self.may_send() { return Err(Error::Illegal) }
-
-        // See above.
-        if self.tx_buffer.is_empty() { self.remote_last_ts = None }
-
-        let _old_length = self.tx_buffer.len();
-        let enqueued = self.tx_buffer.enqueue_slice(data);
-        if enqueued != 0 {
-            #[cfg(any(test, feature = "verbose"))]
-            net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})",
-                       self.handle, self.local_endpoint, self.remote_endpoint,
-                       enqueued, _old_length + enqueued);
-        }
-        Ok(enqueued)
+        self.send_impl(|tx_buffer| {
+            let size = tx_buffer.enqueue_slice(data);
+            (size, size)
+        })
     }
 
-    /// Dequeue a sequence of received octets, and return a pointer to it.
-    ///
-    /// This function may return a slice smaller than the requested size in case
-    /// there are not enough octets queued in the receive buffer, down to
-    /// an empty slice.
-    ///
-    /// This function returns `Err(Error::Illegal) if the receive half of
-    /// the connection is not open; see [may_recv](#method.may_recv).
-    pub fn recv(&mut self, size: usize) -> Result<&[u8]> {
+    pub fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
+            where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
         // We may have received some data inside the initial SYN, but until the connection
         // is fully open we must not dequeue any data, as it may be overwritten by e.g.
-        // another (stale) SYN.
+        // another (stale) SYN. (We do not support TCP Fast Open.)
         if !self.may_recv() { return Err(Error::Illegal) }
 
         let _old_length = self.rx_buffer.len();
-        let buffer = self.rx_buffer.dequeue_many(size);
-        self.remote_seq_no += buffer.len();
-        if buffer.len() > 0 {
+        let (size, result) = f(&mut self.rx_buffer);
+        self.remote_seq_no += size;
+        if size > 0 {
             #[cfg(any(test, feature = "verbose"))]
             net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})",
                        self.handle, self.local_endpoint, self.remote_endpoint,
-                       buffer.len(), _old_length - buffer.len());
+                       size, _old_length - size);
         }
-        Ok(buffer)
+        Ok(result)
+    }
+
+
+    /// Call `f` with the largest contiguous slice of octets in the receive buffer,
+    /// and dequeue the amount of elements returned by `f`.
+    ///
+    /// This function returns `Err(Error::Illegal) if the receive half of
+    /// the connection is not open; see [may_recv](#method.may_recv).
+    pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R>
+            where F: FnOnce(&'b mut [u8]) -> (usize, R) {
+        self.recv_impl(|rx_buffer| {
+            rx_buffer.dequeue_many_with(f)
+        })
     }
 
     /// Dequeue a sequence of received octets, and fill a slice from it.
@@ -676,19 +677,10 @@ impl<'a> TcpSocket<'a> {
     ///
     /// See also [recv](#method.recv).
     pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
-        // See recv() above.
-        if !self.may_recv() { return Err(Error::Illegal) }
-
-        let _old_length = self.rx_buffer.len();
-        let dequeued = self.rx_buffer.dequeue_slice(data);
-        self.remote_seq_no += dequeued;
-        if dequeued > 0 {
-            #[cfg(any(test, feature = "verbose"))]
-            net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})",
-                       self.handle, self.local_endpoint, self.remote_endpoint,
-                       dequeued, _old_length - dequeued);
-        }
-        Ok(dequeued)
+        self.recv_impl(|rx_buffer| {
+            let size = rx_buffer.dequeue_slice(data);
+            (size, size)
+        })
     }
 
     /// Peek at a sequence of received octets without removing them from
@@ -3145,7 +3137,10 @@ mod test {
             ..RECV_TEMPL
         }]);
         recv!(s, time 0, Err(Error::Exhausted));
-        assert_eq!(s.recv(3), Ok(&b"abc"[..]));
+        s.recv(|buffer| {
+            assert_eq!(&buffer[..3], b"abc");
+            (3, ())
+        }).unwrap();
         recv!(s, time 0, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 6),
@@ -3153,7 +3148,10 @@ mod test {
             ..RECV_TEMPL
         }));
         recv!(s, time 0, Err(Error::Exhausted));
-        assert_eq!(s.recv(3), Ok(&b"def"[..]));
+        s.recv(|buffer| {
+            assert_eq!(buffer, b"def");
+            (buffer.len(), ())
+        }).unwrap();
         recv!(s, time 0, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 6),
@@ -3457,7 +3455,10 @@ mod test {
             ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
         })));
-        assert_eq!(s.recv(10), Ok(&b""[..]));
+        s.recv(|buffer| {
+            assert_eq!(buffer, b"");
+            (buffer.len(), ())
+        }).unwrap();
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1,
             ack_number: Some(LOCAL_SEQ + 1),
@@ -3469,11 +3470,14 @@ mod test {
             window_len: 58,
             ..RECV_TEMPL
         })));
-        assert_eq!(s.recv(10), Ok(&b"abcdef"[..]));
+        s.recv(|buffer| {
+            assert_eq!(buffer, b"abcdef");
+            (buffer.len(), ())
+        }).unwrap();
     }
 
     #[test]
-    fn test_buffer_wraparound() {
+    fn test_buffer_wraparound_rx() {
         let mut s = socket_established();
         s.rx_buffer = SocketBuffer::new(vec![0; 6]);
         s.assembler = Assembler::new(s.rx_buffer.capacity());
@@ -3483,7 +3487,10 @@ mod test {
             payload:    &b"abc"[..],
             ..SEND_TEMPL
         });
-        assert_eq!(s.recv(3), Ok(&b"abc"[..]));
+        s.recv(|buffer| {
+            assert_eq!(buffer, b"abc");
+            (buffer.len(), ())
+        }).unwrap();
         send!(s, TcpRepr {
             seq_number: REMOTE_SEQ + 1 + 3,
             ack_number: Some(LOCAL_SEQ + 1),
@@ -3495,6 +3502,38 @@ mod test {
         assert_eq!(data, &b"defghi"[..]);
     }
 
+    #[test]
+    fn test_buffer_wraparound_tx() {
+        let mut s = socket_established();
+        s.tx_buffer = SocketBuffer::new(vec![0; 6]);
+        assert_eq!(s.send_slice(b"abc"), Ok(3));
+        recv!(s, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"abc"[..],
+            ..RECV_TEMPL
+        }));
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 3),
+            ..SEND_TEMPL
+        });
+        assert_eq!(s.send_slice(b"defghi"), Ok(6));
+        recv!(s, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 3,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"def"[..],
+            ..RECV_TEMPL
+        }));
+        // "defghi" not contiguous in tx buffer
+        recv!(s, Ok(TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 3 + 3,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload:    &b"ghi"[..],
+            ..RECV_TEMPL
+        }));
+    }
+
     // =========================================================================================//
     // Tests for packet filtering.
     // =========================================================================================//

+ 2 - 2
src/socket/udp.rs

@@ -195,8 +195,8 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         Ok((&packet_buf.as_ref(), packet_buf.endpoint))
     }
 
-    /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
-    /// as copy the payload into the given slice.
+    /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
+    /// and return the amount of octets copied as well as the endpoint.
     ///
     /// See also [recv](#method.recv).
     pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {