瀏覽代碼

Rework responses to TCP packets and factor in RST replies to TcpSocket.

whitequark 7 年之前
父節點
當前提交
bc2a894c00
共有 4 個文件被更改,包括 142 次插入99 次删除
  1. 65 85
      src/iface/ethernet.rs
  2. 46 8
      src/socket/tcp.rs
  3. 17 0
      src/wire/ip.rs
  4. 14 6
      src/wire/tcp.rs

+ 65 - 85
src/iface/ethernet.rs

@@ -27,7 +27,7 @@ enum Response<'a> {
     Nop,
     Arp(ArpRepr),
     Icmpv4(Ipv4Repr, Icmpv4Repr<'a>),
-    Tcpv4(Ipv4Repr, TcpRepr<'a>)
+    Tcp(IpRepr, TcpRepr<'a>)
 }
 
 impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
@@ -220,10 +220,10 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         match ipv4_repr.protocol {
             IpProtocol::Icmp =>
                 Self::process_icmpv4(ipv4_repr, ipv4_packet.payload()),
-            IpProtocol::Tcp =>
-                Self::process_tcpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()),
             IpProtocol::Udp =>
                 Self::process_udpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()),
+            IpProtocol::Tcp =>
+                Self::process_tcp(sockets, timestamp, ipv4_repr.into(), ipv4_packet.payload()),
             _ if handled_by_raw_socket =>
                 Ok(Response::Nop),
             _ => {
@@ -307,11 +307,9 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         Ok(Response::Icmpv4(ipv4_reply_repr, icmp_reply_repr))
     }
 
-    fn process_tcpv4<'frame>(sockets: &mut SocketSet, timestamp: u64,
-                             ipv4_repr: Ipv4Repr, ip_payload: &'frame [u8]) ->
-                            Result<Response<'frame>> {
-        let ip_repr = IpRepr::Ipv4(ipv4_repr);
-
+    fn process_tcp<'frame>(sockets: &mut SocketSet, timestamp: u64,
+                           ip_repr: IpRepr, ip_payload: &'frame [u8]) ->
+                          Result<Response<'frame>> {
         for tcp_socket in sockets.iter_mut().filter_map(
                 <Socket as AsSocket<TcpSocket>>::try_as_socket) {
             match tcp_socket.process(timestamp, &ip_repr, ip_payload) {
@@ -327,99 +325,81 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
         // The packet wasn't handled by a socket, send a TCP RST packet.
         let tcp_packet = TcpPacket::new_checked(ip_payload)?;
-        if tcp_packet.rst() {
-            // Don't reply to a TCP RST packet with another TCP RST packet.
-            return Ok(Response::Nop)
+        let tcp_repr = TcpRepr::parse(&tcp_packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
+        if tcp_repr.control == TcpControl::Rst {
+            // Never reply to a TCP RST packet with another TCP RST packet.
+            Ok(Response::Nop)
+        } else {
+            let (ip_reply_repr, tcp_reply_repr) = TcpSocket::rst_reply(&ip_repr, &tcp_repr);
+            Ok(Response::Tcp(ip_reply_repr, tcp_reply_repr))
         }
-        let tcp_reply_repr = TcpRepr {
-            src_port:     tcp_packet.dst_port(),
-            dst_port:     tcp_packet.src_port(),
-            control:      TcpControl::Rst,
-            push:         false,
-            seq_number:   tcp_packet.ack_number(),
-            ack_number:   Some(tcp_packet.seq_number() +
-                               tcp_packet.segment_len()),
-            window_len:   0,
-            max_seg_size: None,
-            payload:      &[]
-        };
-        let ipv4_reply_repr = Ipv4Repr {
-            src_addr:    ipv4_repr.dst_addr,
-            dst_addr:    ipv4_repr.src_addr,
-            protocol:    IpProtocol::Tcp,
-            payload_len: tcp_reply_repr.buffer_len()
-        };
-        Ok(Response::Tcpv4(ipv4_reply_repr, tcp_reply_repr))
     }
 
     fn send_response(&mut self, timestamp: u64, response: Response) -> Result<()> {
-        macro_rules! ip_response {
-            ($tx_buffer:ident, $frame:ident, $ip_repr:ident) => ({
-                let dst_hardware_addr =
-                    match self.arp_cache.lookup(&$ip_repr.dst_addr.into()) {
-                        None => return Err(Error::Unaddressable),
-                        Some(hardware_addr) => hardware_addr
-                    };
-
-                let tx_len = EthernetFrame::<&[u8]>::buffer_len($ip_repr.buffer_len() +
-                                                                $ip_repr.payload_len);
-                $tx_buffer = self.device.transmit(timestamp, tx_len)?;
-                debug_assert!($tx_buffer.as_ref().len() == tx_len);
+        macro_rules! emit_packet {
+            (Ethernet, $buffer_len:expr, |$frame:ident| $code:stmt) => ({
+                let tx_len = EthernetFrame::<&[u8]>::buffer_len($buffer_len);
+                let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
+                debug_assert!(tx_buffer.as_ref().len() == tx_len);
 
-                $frame = EthernetFrame::new(&mut $tx_buffer);
+                let mut $frame = EthernetFrame::new(&mut tx_buffer);
                 $frame.set_src_addr(self.hardware_addr);
-                $frame.set_dst_addr(dst_hardware_addr);
-                $frame.set_ethertype(EthernetProtocol::Ipv4);
 
-                let mut ip_packet = Ipv4Packet::new($frame.payload_mut());
-                $ip_repr.emit(&mut ip_packet);
-                ip_packet
+                $code
+
+                Ok(())
+            });
+
+            (Ip, $ip_repr:expr, |$payload:ident| $code:stmt) => ({
+                let ip_repr = $ip_repr.lower(&self.protocol_addrs)?;
+                match self.arp_cache.lookup(&ip_repr.dst_addr()) {
+                    None => Err(Error::Unaddressable),
+                    Some(dst_hardware_addr) => {
+                        emit_packet!(Ethernet, ip_repr.total_len(), |frame| {
+                            frame.set_dst_addr(dst_hardware_addr);
+                            match ip_repr {
+                                IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4),
+                                _ => unreachable!()
+                            }
+
+                            ip_repr.emit(frame.payload_mut());
+
+                            let $payload = &mut frame.payload_mut()[ip_repr.buffer_len()..];
+                            $code
+                        })
+                    }
+                }
             })
         }
 
         match response {
-            Response::Arp(repr) => {
-                let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len());
-                let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
-                debug_assert!(tx_buffer.as_ref().len() == tx_len);
-
-                let mut frame = EthernetFrame::new(&mut tx_buffer);
-                frame.set_src_addr(self.hardware_addr);
-                frame.set_dst_addr(match repr {
-                    ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr,
-                    _ => unreachable!()
-                });
-                frame.set_ethertype(EthernetProtocol::Arp);
+            Response::Arp(arp_repr) => {
+                let dst_hardware_addr =
+                    match arp_repr {
+                        ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr,
+                        _ => unreachable!()
+                    };
 
-                let mut packet = ArpPacket::new(frame.payload_mut());
-                repr.emit(&mut packet);
+                emit_packet!(Ethernet, arp_repr.buffer_len(), |frame| {
+                    frame.set_dst_addr(dst_hardware_addr);
+                    frame.set_ethertype(EthernetProtocol::Arp);
 
-                Ok(())
+                    let mut packet = ArpPacket::new(frame.payload_mut());
+                    arp_repr.emit(&mut packet);
+                })
             },
-
-            Response::Icmpv4(ip_repr, icmp_repr) => {
-                let mut tx_buffer;
-                let mut frame;
-                let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr);
-                let mut icmp_packet = Icmpv4Packet::new(ip_packet.payload_mut());
-                icmp_repr.emit(&mut icmp_packet);
-                Ok(())
+            Response::Icmpv4(ipv4_repr, icmpv4_repr) => {
+                emit_packet!(Ip, IpRepr::Ipv4(ipv4_repr), |payload| {
+                    icmpv4_repr.emit(&mut Icmpv4Packet::new(payload));
+                })
             }
-
-            Response::Tcpv4(ip_repr, tcp_repr) => {
-                let mut tx_buffer;
-                let mut frame;
-                let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr);
-                let mut tcp_packet = TcpPacket::new(ip_packet.payload_mut());
-                tcp_repr.emit(&mut tcp_packet,
-                              &IpAddress::Ipv4(ip_repr.src_addr),
-                              &IpAddress::Ipv4(ip_repr.dst_addr));
-                Ok(())
-            }
-
-            Response::Nop => {
-                Ok(())
+            Response::Tcp(ip_repr, tcp_repr) => {
+                emit_packet!(Ip, ip_repr, |payload| {
+                    tcp_repr.emit(&mut TcpPacket::new(payload),
+                                  &ip_repr.src_addr(), &ip_repr.dst_addr());
+                })
             }
+            Response::Nop => Ok(())
         }
     }
 

+ 46 - 8
src/socket/tcp.rs

@@ -285,10 +285,10 @@ impl<'a> TcpSocket<'a> {
             listen_address:  IpAddress::default(),
             local_endpoint:  IpEndpoint::default(),
             remote_endpoint: IpEndpoint::default(),
-            local_seq_no:    TcpSeqNumber(0),
-            remote_seq_no:   TcpSeqNumber(0),
-            remote_last_seq: TcpSeqNumber(0),
-            remote_last_ack: TcpSeqNumber(0),
+            local_seq_no:    TcpSeqNumber::default(),
+            remote_seq_no:   TcpSeqNumber::default(),
+            remote_last_seq: TcpSeqNumber::default(),
+            remote_last_ack: TcpSeqNumber::default(),
             remote_win_len:  0,
             remote_mss:      DEFAULT_MSS,
             retransmit:      Retransmit::new(),
@@ -335,10 +335,10 @@ impl<'a> TcpSocket<'a> {
         self.listen_address  = IpAddress::default();
         self.local_endpoint  = IpEndpoint::default();
         self.remote_endpoint = IpEndpoint::default();
-        self.local_seq_no    = TcpSeqNumber(0);
-        self.remote_seq_no   = TcpSeqNumber(0);
-        self.remote_last_seq = TcpSeqNumber(0);
-        self.remote_last_ack = TcpSeqNumber(0);
+        self.local_seq_no    = TcpSeqNumber::default();
+        self.remote_seq_no   = TcpSeqNumber::default();
+        self.remote_last_seq = TcpSeqNumber::default();
+        self.remote_last_ack = TcpSeqNumber::default();
         self.remote_win_len  = 0;
         self.remote_mss      = DEFAULT_MSS;
         self.retransmit.reset();
@@ -681,6 +681,44 @@ impl<'a> TcpSocket<'a> {
         self.state = state
     }
 
+    pub(crate) fn reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) {
+        let tcp_reply_repr = TcpRepr {
+            src_port:     tcp_repr.dst_port,
+            dst_port:     tcp_repr.src_port,
+            control:      TcpControl::None,
+            push:         false,
+            seq_number:   TcpSeqNumber(0),
+            ack_number:   None,
+            window_len:   0,
+            max_seg_size: None,
+            payload:      &[]
+        };
+        let ip_reply_repr = IpRepr::Unspecified {
+            src_addr:    ip_repr.dst_addr(),
+            dst_addr:    ip_repr.src_addr(),
+            protocol:    IpProtocol::Tcp,
+            payload_len: tcp_reply_repr.buffer_len()
+        };
+        (ip_reply_repr, tcp_reply_repr)
+    }
+
+    pub(crate) fn rst_reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) {
+        debug_assert!(tcp_repr.control != TcpControl::Rst);
+
+        let (ip_reply_repr, mut tcp_reply_repr) = Self::reply(ip_repr, tcp_repr);
+
+        // See https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for explanation
+        // of why we sometimes send an RST and sometimes an RST|ACK
+        tcp_reply_repr.control = TcpControl::Rst;
+        tcp_reply_repr.seq_number = tcp_repr.ack_number.unwrap_or_default();
+        if tcp_repr.control == TcpControl::Syn {
+            tcp_reply_repr.ack_number = Some(tcp_repr.seq_number +
+                                             tcp_repr.segment_len());
+        }
+
+        (ip_reply_repr, tcp_reply_repr)
+    }
+
     pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr,
                           payload: &[u8]) -> Result<()> {
         debug_assert!(ip_repr.protocol() == IpProtocol::Tcp);

+ 17 - 0
src/wire/ip.rs

@@ -177,6 +177,12 @@ pub enum IpRepr {
     __Nonexhaustive
 }
 
+impl From<Ipv4Repr> for IpRepr {
+    fn from(repr: Ipv4Repr) -> IpRepr {
+        IpRepr::Ipv4(repr)
+    }
+}
+
 impl IpRepr {
     /// Return the protocol version.
     pub fn version(&self) -> Version {
@@ -323,6 +329,17 @@ impl IpRepr {
                 unreachable!()
         }
     }
+
+    /// Return the total length of a packet that will be emitted from this
+    /// high-level representation.
+    ///
+    /// This is the same as `repr.buffer_len() + repr.payload_len()`.
+    ///
+    /// # Panics
+    /// This function panics if invoked on an unspecified representation.
+    pub fn total_len(&self) -> usize {
+        self.buffer_len() + self.payload_len()
+    }
 }
 
 pub mod checksum {

+ 14 - 6
src/wire/tcp.rs

@@ -9,7 +9,7 @@ use super::ip::checksum;
 ///
 /// A sequence number is a monotonically advancing integer modulo 2<sup>32</sup>.
 /// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow.
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
 pub struct SeqNumber(pub i32);
 
 impl fmt::Display for SeqNumber {
@@ -275,7 +275,6 @@ impl<T: AsRef<[u8]>> Packet<T> {
     }
 
     /// Return the length of the segment, in terms of sequence space.
-    #[inline]
     pub fn segment_len(&self) -> usize {
         let data = self.buffer.as_ref();
         let mut length = data.len() - self.header_len() as usize;
@@ -695,10 +694,9 @@ impl<'a> Repr<'a> {
     }
 
     /// Emit a high-level representation into a Transmission Control Protocol packet.
-    pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
-                           src_addr: &IpAddress,
-                           dst_addr: &IpAddress)
-            where T: AsRef<[u8]> + AsMut<[u8]> {
+    pub fn emit<T>(&self, packet: &mut Packet<&mut T>,
+                   src_addr: &IpAddress, dst_addr: &IpAddress)
+            where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized {
         packet.set_src_port(self.src_port);
         packet.set_dst_port(self.dst_port);
         packet.set_seq_number(self.seq_number);
@@ -727,6 +725,16 @@ impl<'a> Repr<'a> {
         packet.payload_mut().copy_from_slice(self.payload);
         packet.fill_checksum(src_addr, dst_addr)
     }
+
+    /// Return the length of the segment, in terms of sequence space.
+    pub fn segment_len(&self) -> usize {
+        let mut length = self.payload.len();
+        match self.control {
+            Control::Syn | Control::Fin => length += 1,
+            _ => ()
+        }
+        length
+    }
 }
 
 impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {