浏览代码

tcp: add Nagle's Algorithm.

Dario Nieuwenhuis 3 年之前
父节点
当前提交
b82b7300aa
共有 5 个文件被更改,包括 140 次插入和 27 次删除
  1. 127 17
      src/socket/tcp.rs
  2. 3 0
      src/wire/ipv4.rs
  3. 3 0
      src/wire/ipv6.rs
  4. 5 2
      src/wire/mod.rs
  5. 2 8
      src/wire/tcp.rs

+ 127 - 17
src/socket/tcp.rs

@@ -12,7 +12,7 @@ use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt, Context};
 use crate::storage::{Assembler, RingBuffer};
 #[cfg(feature = "async")]
 use crate::socket::WakerRegistration;
-use crate::wire::{IpProtocol, IpRepr, IpAddress, IpEndpoint, TcpSeqNumber, TcpRepr, TcpControl};
+use crate::wire::{IpProtocol, IpRepr, IpAddress, IpEndpoint, TcpSeqNumber, TcpRepr, TcpControl, TCP_HEADER_LEN};
 
 /// A TCP socket ring buffer.
 pub type SocketBuffer<'a> = RingBuffer<'a, u8>;
@@ -354,6 +354,9 @@ pub struct TcpSocket<'a> {
     /// ACK or window updates (ie, no data) won't be sent until expiry.
     ack_delay_until: Option<Instant>,
 
+    /// Nagle's Algorithm enabled.
+    nagle: bool,
+
     #[cfg(feature = "async")]
     rx_waker: WakerRegistration,
     #[cfg(feature = "async")]
@@ -412,6 +415,7 @@ impl<'a> TcpSocket<'a> {
             local_rx_dup_acks: 0,
             ack_delay:       Some(ACK_DELAY_DEFAULT),
             ack_delay_until: None,
+            nagle: true,
 
             #[cfg(feature = "async")]
             rx_waker: WakerRegistration::new(),
@@ -475,6 +479,13 @@ impl<'a> TcpSocket<'a> {
         self.ack_delay
     }
 
+    /// Return whether Nagle's Algorithm is enabled.
+    ///
+    /// See also the [set_nagle_enabled](#method.set_nagle_enabled) method.
+    pub fn nagle_enabled(&self) -> bool {
+        self.nagle
+    }
+
     /// Return the current window field value, including scaling according to RFC 1323.
     ///
     /// Used in internal calculations as well as packet generation.
@@ -507,6 +518,22 @@ impl<'a> TcpSocket<'a> {
         self.ack_delay = duration
     }
 
+    /// Enable or disable Nagle's Algorithm.
+    ///
+    /// Also known as "tinygram prevention". It is enabled by default;
+    /// disabling it is equivalent to setting Linux's TCP_NODELAY flag.
+    ///
+    /// When enabled, a segment smaller than the MSS is not sent while there is
+    /// data in flight (sent but not yet acknowledged). In other words, at most
+    /// one segment smaller than the MSS is in flight at any time.
+    ///
+    /// This improves network utilization by avoiding many very small packets,
+    /// at the cost of increased latency in some situations, particularly when the
+    /// remote peer has ACK delay enabled.
+    pub fn set_nagle_enabled(&mut self, enabled: bool) {
+        self.nagle = enabled;
+    }
+
     /// Return the keep-alive interval.
     ///
     /// See also the [set_keep_alive](#method.set_keep_alive) method.
@@ -609,6 +636,7 @@ impl<'a> TcpSocket<'a> {
         self.remote_last_ts  = None;
         self.ack_delay       = Some(ACK_DELAY_DEFAULT);
         self.ack_delay_until = None;
+        self.nagle = true;
 
         #[cfg(feature = "async")]
         {
@@ -1639,12 +1667,38 @@ impl<'a> TcpSocket<'a> {
         }
     }
 
-    fn seq_to_transmit(&self) -> bool {
-        // We can send data if we have data that:
-        // - hasn't been sent before
-        // - fits in the remote window
-        let can_data = self.remote_last_seq
-            < self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len());
+    fn seq_to_transmit(&self, cx: &Context) -> bool {
+        let ip_header_len = match self.local_endpoint.addr {
+            #[cfg(feature = "proto-ipv4")]
+            IpAddress::Ipv4(_) => crate::wire::IPV4_HEADER_LEN,
+            #[cfg(feature = "proto-ipv6")]
+            IpAddress::Ipv6(_) => crate::wire::IPV6_HEADER_LEN,
+            IpAddress::Unspecified => unreachable!(),
+        };
+
+        // Max segment size we're able to send due to MTU limitations.
+        let local_mss = cx.caps.ip_mtu() - ip_header_len - TCP_HEADER_LEN;
+
+        // The effective max segment size, taking into account our and remote's limits.
+        let effective_mss = local_mss.min(self.remote_mss);
+
+        // Have we sent data that hasn't been ACKed yet?
+        let data_in_flight = self.remote_last_seq != self.local_seq_no;
+
+        // max sequence number we can send.
+        let max_send_seq = self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len());
+
+        // Max amount of octets we can send.
+        let max_send = if max_send_seq >= self.remote_last_seq {
+            max_send_seq - self.remote_last_seq
+        } else {
+            0
+        };
+
+        // Can we send at least 1 octet?
+        let mut can_send = max_send != 0;
+        // Can we send at least 1 full segment?
+        let can_send_full = max_send >= effective_mss;
 
         // Do we have to send a FIN?
         let want_fin = match self.state {
@@ -1654,6 +1708,10 @@ impl<'a> TcpSocket<'a> {
             _ => false,
         };
 
+        if self.nagle && data_in_flight && !can_send_full {
+            can_send = false;
+        }
+
         // Can we actually send the FIN? We can send it if:
         // 1. We have unsent data that fits in the remote window.
         // 2. We have no unsent data.
@@ -1661,7 +1719,7 @@ impl<'a> TcpSocket<'a> {
         let can_fin =
             want_fin && self.remote_last_seq == self.local_seq_no + self.tx_buffer.len();
 
-        can_data || can_fin
+        can_send || can_fin
     }
 
     fn delayed_ack_expired(&self, timestamp: Instant) -> bool {
@@ -1708,7 +1766,7 @@ impl<'a> TcpSocket<'a> {
             net_debug!("{}:{}:{}: timeout exceeded",
                        self.meta.handle, self.local_endpoint, self.remote_endpoint);
             self.set_state(State::Closed);
-        } else if !self.seq_to_transmit() {
+        } else if !self.seq_to_transmit(cx) {
             if let Some(retransmit_delta) = self.timer.should_retransmit(cx.now) {
                 // If a retransmit timer expired, we should resend data starting at the last ACK.
                 net_debug!("{}:{}:{}: retransmitting at t+{}",
@@ -1720,7 +1778,7 @@ impl<'a> TcpSocket<'a> {
         }
 
         // Decide whether we're sending a packet.
-        if self.seq_to_transmit() {
+        if self.seq_to_transmit(cx) {
             // If we have data to transmit and it fits into partner's window, do it.
             net_trace!("{}:{}:{}: outgoing segment will send data or flags",
                        self.meta.handle, self.local_endpoint, self.remote_endpoint);
@@ -1832,7 +1890,7 @@ impl<'a> TcpSocket<'a> {
                 // 3. MSS we can send, determined by our MTU.
                 let size = win_limit
                     .min(self.remote_mss)
-                    .min(cx.caps.ip_mtu() - ip_repr.buffer_len() - repr.mss_header_len());
+                    .min(cx.caps.ip_mtu() - ip_repr.buffer_len() - TCP_HEADER_LEN);
 
                 let offset = self.remote_last_seq - self.local_seq_no;
                 repr.payload = self.tx_buffer.get_allocated(offset, size);
@@ -1894,9 +1952,7 @@ impl<'a> TcpSocket<'a> {
 
         if repr.control == TcpControl::Syn {
             // Fill the MSS option. See RFC 6691 for an explanation of this calculation.
-            let mut max_segment_size = cx.caps.ip_mtu();
-            max_segment_size -= ip_repr.buffer_len();
-            max_segment_size -= repr.mss_header_len();
+            let max_segment_size = cx.caps.ip_mtu() - ip_repr.buffer_len() - TCP_HEADER_LEN;
             repr.max_seg_size = Some(max_segment_size as u16);
         }
 
@@ -1936,7 +1992,7 @@ impl<'a> TcpSocket<'a> {
             self.rtte.on_send(cx.now, repr.seq_number + repr.segment_len());
         }
 
-        if !self.seq_to_transmit() && repr.segment_len() > 0 {
+        if !self.seq_to_transmit(cx) && repr.segment_len() > 0 {
             // If we've transmitted all data we could (and there was something at all,
             // data or flag, to transmit, not just an ACK), wind up the retransmit timer.
             self.timer.set_for_retransmit(cx.now, self.rtte.retransmission_timeout());
@@ -1952,7 +2008,7 @@ impl<'a> TcpSocket<'a> {
     }
 
     #[allow(clippy::if_same_then_else)]
-    pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
+    pub(crate) fn poll_at(&self, cx: &Context) -> PollAt {
         // The logic here mirrors the beginning of dispatch() closely.
         if !self.remote_endpoint.is_specified() {
             // No one to talk to, nothing to transmit.
@@ -1963,7 +2019,7 @@ impl<'a> TcpSocket<'a> {
         } else if self.state == State::Closed {
             // Socket was aborted, we have an RST packet to transmit.
             PollAt::Now
-        } else if self.seq_to_transmit() {
+        } else if self.seq_to_transmit(cx) {
             // We have a data or flag packet to transmit.
             PollAt::Now
         } else {
@@ -3043,6 +3099,7 @@ mod test {
     #[test]
     fn test_established_send_no_ack_send() {
         let mut s = socket_established();
+        s.set_nagle_enabled(false);
         s.send_slice(b"abcdef").unwrap();
         recv!(s, [TcpRepr {
             seq_number: LOCAL_SEQ + 1,
@@ -5121,6 +5178,8 @@ mod test {
     #[test]
     fn test_buffer_wraparound_tx() {
         let mut s = socket_established();
+        s.set_nagle_enabled(false);
+
         s.tx_buffer = SocketBuffer::new(vec![b'.'; 9]);
         assert_eq!(s.send_slice(b"xxxyyy"), Ok(6));
         assert_eq!(s.tx_buffer.dequeue_many(3), &b"xxx"[..]);
@@ -5409,6 +5468,57 @@ mod test {
         }));
     }
 
+    // =========================================================================================//
+    // Tests for Nagle's Algorithm
+    // =========================================================================================//
+
+    #[test]
+    fn test_nagle() {
+        let mut s = socket_established();
+        s.remote_mss = 6;
+
+        s.send_slice(b"abcdef").unwrap();
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload: &b"abcdef"[..],
+            ..RECV_TEMPL
+        }]);
+
+        // Data is in flight, but a full-MSS segment is still sent immediately.
+        s.send_slice(b"foobar").unwrap();
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload: &b"foobar"[..],
+            ..RECV_TEMPL
+        }]);
+
+        s.send_slice(b"aaabbbccc").unwrap();
+        // Data is in flight: only the full segment is expected, the 3-octet tail is held back.
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 6 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload: &b"aaabbb"[..],
+            ..RECV_TEMPL
+        }]);
+
+        // The peer ACKs everything sent so far, so no data is in flight any more.
+        send!(s, TcpRepr {
+            seq_number: REMOTE_SEQ + 1,
+            ack_number: Some(LOCAL_SEQ + 1 + 6 + 6 + 6),
+            ..SEND_TEMPL
+        });
+
+        // The held-back partial segment is now sent.
+        recv!(s, [TcpRepr {
+            seq_number: LOCAL_SEQ + 1 + 6 + 6 + 6,
+            ack_number: Some(REMOTE_SEQ + 1),
+            payload: &b"ccc"[..],
+            ..RECV_TEMPL
+        }]);
+    }
+
     // =========================================================================================//
     // Tests for packet filtering.
     // =========================================================================================//

+ 3 - 0
src/wire/ipv4.rs

@@ -263,6 +263,9 @@ mod field {
     pub const DST_ADDR: Field = 16..20;
 }
 
+/// Length of an IPv4 header, not including any options.
+pub const HEADER_LEN: usize = field::DST_ADDR.end;
+
 impl<T: AsRef<[u8]>> Packet<T> {
     /// Imbue a raw octet buffer with IPv4 packet structure.
     pub fn new_unchecked(buffer: T) -> Packet<T> {

+ 3 - 0
src/wire/ipv6.rs

@@ -380,6 +380,9 @@ mod field {
     pub const DST_ADDR:    Field = 24..40;
 }
 
+/// Length of an IPv6 header.
+pub const HEADER_LEN: usize = field::DST_ADDR.end;
+
 impl<T: AsRef<[u8]>> Packet<T> {
     /// Create a raw octet buffer with an IPv6 packet structure.
     #[inline]

+ 5 - 2
src/wire/mod.rs

@@ -140,13 +140,15 @@ pub use self::ipv4::{Address as Ipv4Address,
                      Packet as Ipv4Packet,
                      Repr as Ipv4Repr,
                      Cidr as Ipv4Cidr,
-                     MIN_MTU as IPV4_MIN_MTU};
+                         HEADER_LEN as IPV4_HEADER_LEN,
+                         MIN_MTU as IPV4_MIN_MTU};
 
 #[cfg(feature = "proto-ipv6")]
 pub use self::ipv6::{Address as Ipv6Address,
                      Packet as Ipv6Packet,
                      Repr as Ipv6Repr,
                      Cidr as Ipv6Cidr,
+                     HEADER_LEN as IPV6_HEADER_LEN,
                      MIN_MTU as IPV6_MIN_MTU};
 
 #[cfg(feature = "proto-ipv6")]
@@ -218,7 +220,8 @@ pub use self::tcp::{SeqNumber as TcpSeqNumber,
                     Packet as TcpPacket,
                     TcpOption,
                     Repr as TcpRepr,
-                    Control as TcpControl};
+                    Control as TcpControl,
+                    HEADER_LEN as TCP_HEADER_LEN};
 
 #[cfg(feature = "proto-dhcpv4")]
 pub use self::dhcpv4::{Packet as DhcpPacket,

+ 2 - 8
src/wire/tcp.rs

@@ -109,6 +109,8 @@ mod field {
     pub const OPT_SACKRNG:  u8 = 0x05;
 }
 
+/// Length of a TCP header without options; per RFC 6691, the value to use for MSS calculations.
+pub const HEADER_LEN: usize = field::URGENT.end;
+
 impl<T: AsRef<[u8]>> Packet<T> {
     /// Imbue a raw octet buffer with TCP packet structure.
     pub fn new_unchecked(buffer: T) -> Packet<T> {
@@ -857,14 +859,6 @@ impl<'a> Repr<'a> {
         length
     }
 
-    /// Return the length of the header for the TCP protocol.
-    ///
-    /// Per RFC 6691, this should be used for MSS calculations. It may be smaller than the buffer
-    /// space required to accomodate this packet's data.
-    pub fn mss_header_len(&self) -> usize {
-        field::URGENT.end
-    }
-
     /// Return the length of a packet that will be emitted from this high-level representation.
     pub fn buffer_len(&self) -> usize {
         self.header_len() + self.payload.len()