Browse Source

Get rid of explicit backlog.

whitequark 8 năm trước cách đây
mục cha
commit
f89690c278
4 tập tin đã thay đổi với 208 bổ sung180 xóa
  1. 10 16
      examples/smoltcpserver.rs
  2. 15 21
      src/socket/mod.rs
  3. 182 142
      src/socket/tcp.rs
  4. 1 1
      src/socket/udp.rs

+ 10 - 16
examples/smoltcpserver.rs

@@ -1,4 +1,4 @@
-#![feature(associated_consts)]
+#![feature(associated_consts, type_ascription)]
 extern crate smoltcp;
 
 use std::env;
@@ -6,8 +6,9 @@ use smoltcp::Error;
 use smoltcp::phy::{Tracer, TapInterface};
 use smoltcp::wire::{EthernetFrame, EthernetAddress, IpAddress, IpEndpoint};
 use smoltcp::iface::{SliceArpCache, EthernetInterface};
-use smoltcp::socket::{UdpSocket, AsSocket, UdpSocketBuffer, UdpPacketBuffer};
-use smoltcp::socket::{TcpListener, TcpStreamBuffer};
+use smoltcp::socket::AsSocket;
+use smoltcp::socket::{UdpSocket, UdpSocketBuffer, UdpPacketBuffer};
+use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
 
 fn main() {
     let ifname = env::args().nth(1).unwrap();
@@ -22,12 +23,14 @@ fn main() {
     let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 2048])]);
     let udp_socket = UdpSocket::new(endpoint, udp_rx_buffer, udp_tx_buffer);
 
-    let tcp_backlog = vec![None];
-    let tcp_listener = TcpListener::new(endpoint, tcp_backlog);
+    let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 8192]);
+    let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 8192]);
+    let mut tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
+    (tcp_socket.as_socket() : &mut TcpSocket).listen(endpoint);
 
     let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
     let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)];
-    let sockets = vec![udp_socket, tcp_listener];
+    let sockets = vec![udp_socket, tcp_socket];
     let mut iface = EthernetInterface::new(device, arp_cache,
         hardware_addr, protocol_addrs, sockets);
 
@@ -57,15 +60,6 @@ fn main() {
             }
         }
 
-        if let Some(incoming) = {
-            let tcp_listener: &mut TcpListener = iface.sockets()[1].as_socket();
-            tcp_listener.accept()
-        } {
-            println!("client from {}", incoming.remote_end());
-
-            let tcp_rx_buffer = TcpStreamBuffer::new(vec![0; 8192]);
-            let tcp_tx_buffer = TcpStreamBuffer::new(vec![0; 4096]);
-            iface.sockets().push(incoming.into_stream(tcp_rx_buffer, tcp_tx_buffer));
-        }
+        let _tcp_socket: &mut TcpSocket = iface.sockets()[1].as_socket();
     }
 }

+ 15 - 21
src/socket/mod.rs

@@ -18,12 +18,11 @@ mod tcp;
 
 pub use self::udp::PacketBuffer as UdpPacketBuffer;
 pub use self::udp::SocketBuffer as UdpSocketBuffer;
-pub use self::udp::UdpSocket as UdpSocket;
+pub use self::udp::UdpSocket;
 
-pub use self::tcp::StreamBuffer as TcpStreamBuffer;
-pub use self::tcp::Stream as TcpStream;
-pub use self::tcp::Incoming as TcpIncoming;
-pub use self::tcp::Listener as TcpListener;
+pub use self::tcp::SocketBuffer as TcpSocketBuffer;
+pub use self::tcp::State as TcpState;
+pub use self::tcp::TcpSocket;
 
 /// A packet representation.
 ///
@@ -51,9 +50,8 @@ pub trait PacketRepr {
 /// not yet known and the packet storage has to be allocated; but the `&PacketRepr` is sufficient
 /// since the lower layers treat the packet as an opaque octet sequence.
 pub enum Socket<'a, 'b: 'a> {
-    UdpSocket(UdpSocket<'a, 'b>),
-    TcpStream(TcpStream<'a>),
-    TcpListener(TcpListener<'a>),
+    Udp(UdpSocket<'a, 'b>),
+    Tcp(TcpSocket<'a>),
     #[doc(hidden)]
     __Nonexhaustive
 }
@@ -70,11 +68,9 @@ impl<'a, 'b> Socket<'a, 'b> {
                    protocol: IpProtocol, payload: &[u8])
             -> Result<(), Error> {
         match self {
-            &mut Socket::UdpSocket(ref mut socket) =>
+            &mut Socket::Udp(ref mut socket) =>
                 socket.collect(src_addr, dst_addr, protocol, payload),
-            &mut Socket::TcpStream(ref mut socket) =>
-                socket.collect(src_addr, dst_addr, protocol, payload),
-            &mut Socket::TcpListener(ref mut socket) =>
+            &mut Socket::Tcp(ref mut socket) =>
                 socket.collect(src_addr, dst_addr, protocol, payload),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
@@ -91,12 +87,10 @@ impl<'a, 'b> Socket<'a, 'b> {
                                              IpProtocol, &PacketRepr) -> Result<(), Error>)
             -> Result<(), Error> {
         match self {
-            &mut Socket::UdpSocket(ref mut socket) =>
+            &mut Socket::Udp(ref mut socket) =>
                 socket.dispatch(f),
-            &mut Socket::TcpStream(ref mut socket) =>
+            &mut Socket::Tcp(ref mut socket) =>
                 socket.dispatch(f),
-            &mut Socket::TcpListener(_) =>
-                Err(Error::Exhausted),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
     }
@@ -113,17 +107,17 @@ pub trait AsSocket<T> {
 impl<'a, 'b> AsSocket<UdpSocket<'a, 'b>> for Socket<'a, 'b> {
     fn as_socket(&mut self) -> &mut UdpSocket<'a, 'b> {
         match self {
-            &mut Socket::UdpSocket(ref mut socket) => socket,
+            &mut Socket::Udp(ref mut socket) => socket,
             _ => panic!(".as_socket::<UdpSocket> called on wrong socket type")
         }
     }
 }
 
-impl<'a, 'b> AsSocket<TcpListener<'a>> for Socket<'a, 'b> {
-    fn as_socket(&mut self) -> &mut TcpListener<'a> {
+impl<'a, 'b> AsSocket<TcpSocket<'a>> for Socket<'a, 'b> {
+    fn as_socket(&mut self) -> &mut TcpSocket<'a> {
         match self {
-            &mut Socket::TcpListener(ref mut socket) => socket,
-            _ => panic!(".as_socket::<TcpListener> called on wrong socket type")
+            &mut Socket::Tcp(ref mut socket) => socket,
+            _ => panic!(".as_socket::<TcpSocket> called on wrong socket type")
         }
     }
 }

+ 182 - 142
src/socket/tcp.rs

@@ -1,3 +1,5 @@
+use core::fmt;
+
 use Error;
 use Managed;
 use wire::{IpProtocol, IpAddress, IpEndpoint};
@@ -6,23 +8,33 @@ use socket::{Socket, PacketRepr};
 
 /// A TCP stream ring buffer.
 #[derive(Debug)]
-pub struct StreamBuffer<'a> {
+pub struct SocketBuffer<'a> {
     storage: Managed<'a, [u8]>,
     read_at: usize,
     length:  usize
 }
 
-impl<'a> StreamBuffer<'a> {
+impl<'a> SocketBuffer<'a> {
     /// Create a packet buffer with the given storage.
-    pub fn new<T>(storage: T) -> StreamBuffer<'a>
+    pub fn new<T>(storage: T) -> SocketBuffer<'a>
             where T: Into<Managed<'a, [u8]>> {
-        StreamBuffer {
+        SocketBuffer {
             storage: storage.into(),
             read_at: 0,
             length:  0
         }
     }
 
+    /// Return the amount of octets enqueued in the buffer.
+    pub fn len(&self) -> usize {
+        self.length
+    }
+
+    /// Return the maximum amount of octets that can be enqueued in the buffer.
+    pub fn capacity(&self) -> usize {
+        self.storage.len()
+    }
+
     /// Enqueue a slice of octets up to the given size into the buffer, and return a pointer
     /// to the slice.
     ///
@@ -60,151 +72,128 @@ impl<'a> StreamBuffer<'a> {
     }
 }
 
-impl<'a> Into<StreamBuffer<'a>> for Managed<'a, [u8]> {
-    fn into(self) -> StreamBuffer<'a> {
-        StreamBuffer::new(self)
+impl<'a> Into<SocketBuffer<'a>> for Managed<'a, [u8]> {
+    fn into(self) -> SocketBuffer<'a> {
+        SocketBuffer::new(self)
     }
 }
 
-/// A Transmission Control Protocol data stream.
-#[derive(Debug)]
-pub struct Stream<'a> {
-    local_end:  IpEndpoint,
-    remote_end: IpEndpoint,
-    local_seq:  i32,
-    remote_seq: i32,
-    rx_buffer:  StreamBuffer<'a>,
-    tx_buffer:  StreamBuffer<'a>
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum State {
+    Closed,
+    Listen,
+    SynSent,
+    SynReceived,
+    Established,
+    FinWait1,
+    FinWait2,
+    CloseWait,
+    Closing,
+    LastAck,
+    TimeWait
 }
 
-impl<'a> Stream<'a> {
-    /// Return the local endpoint.
-    #[inline(always)]
-    pub fn local_end(&self) -> IpEndpoint {
-        self.local_end
-    }
-
-    /// Return the remote endpoint.
-    #[inline(always)]
-    pub fn remote_end(&self) -> IpEndpoint {
-        self.remote_end
-    }
-
-    /// See [Socket::collect](enum.Socket.html#method.collect).
-    pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress,
-                   protocol: IpProtocol, payload: &[u8])
-            -> Result<(), Error> {
-        if protocol != IpProtocol::Tcp { return Err(Error::Rejected) }
-
-        let packet = try!(TcpPacket::new(payload));
-        let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr));
-
-        if self.local_end  != IpEndpoint::new(*dst_addr, repr.dst_port) {
-            return Err(Error::Rejected)
+impl fmt::Display for State {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            &State::Closed      => write!(f, "CLOSED"),
+            &State::Listen      => write!(f, "LISTEN"),
+            &State::SynSent     => write!(f, "SYN_SENT"),
+            &State::SynReceived => write!(f, "SYN_RECEIVED"),
+            &State::Established => write!(f, "ESTABLISHED"),
+            &State::FinWait1    => write!(f, "FIN_WAIT_1"),
+            &State::FinWait2    => write!(f, "FIN_WAIT_2"),
+            &State::CloseWait   => write!(f, "CLOSE_WAIT"),
+            &State::Closing     => write!(f, "CLOSING"),
+            &State::LastAck     => write!(f, "LAST_ACK"),
+            &State::TimeWait    => write!(f, "TIME_WAIT")
         }
-        if self.remote_end != IpEndpoint::new(*src_addr, repr.src_port) {
-            return Err(Error::Rejected)
-        }
-
-        // FIXME: process
-        Ok(())
     }
+}
 
-    /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
-    pub fn dispatch(&mut self, _f: &mut FnMut(&IpAddress, &IpAddress,
-                                              IpProtocol, &PacketRepr) -> Result<(), Error>)
-            -> Result<(), Error> {
-        // FIXME: process
-        // f(&self.local_end.addr,
-        //   &self.remote_end.addr,
-        //   IpProtocol::Tcp,
-        //   &TcpRepr {
-        //     src_port: self.local_end.port,
-        //     dst_port: self.remote_end.port,
-        //     payload:  &packet_buf.as_ref()[..]
-        //   })
-
-        Ok(())
-    }
+#[derive(Debug)]
+struct Retransmit {
+    sent: bool // FIXME
 }
 
-impl<'a> PacketRepr for TcpRepr<'a> {
-    fn buffer_len(&self) -> usize {
-        self.buffer_len()
+impl Retransmit {
+    fn new() -> Retransmit {
+        Retransmit { sent: false }
     }
 
-    fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]) {
-        let mut packet = TcpPacket::new(payload).expect("undersized payload");
-        self.emit(&mut packet, src_addr, dst_addr)
+    fn reset(&mut self) {
+        self.sent = false
+    }
+
+    fn check(&mut self) -> bool {
+        let result = !self.sent;
+        self.sent = true;
+        result
     }
 }
 
-/// A description of incoming TCP connection.
+/// A Transmission Control Protocol data stream.
 #[derive(Debug)]
-pub struct Incoming {
-    local_end:  IpEndpoint,
-    remote_end: IpEndpoint,
-    local_seq:  i32,
-    remote_seq: i32
+pub struct TcpSocket<'a> {
+    state:         State,
+    local_end:     IpEndpoint,
+    remote_end:    IpEndpoint,
+    local_seq_no:  i32,
+    remote_seq_no: i32,
+    retransmit:    Retransmit,
+    rx_buffer:     SocketBuffer<'a>,
+    tx_buffer:     SocketBuffer<'a>
 }
 
-impl Incoming {
+impl<'a> TcpSocket<'a> {
+    /// Create a socket using the given buffers.
+    pub fn new<T>(rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
+            where T: Into<SocketBuffer<'a>> {
+        let rx_buffer = rx_buffer.into();
+        if rx_buffer.capacity() > <u16>::max_value() as usize {
+            panic!("buffers larger than {} require window scaling, which is not implemented",
+                   <u16>::max_value())
+        }
+
+        Socket::Tcp(TcpSocket {
+            state:         State::Closed,
+            local_end:     IpEndpoint::default(),
+            remote_end:    IpEndpoint::default(),
+            local_seq_no:  0,
+            remote_seq_no: 0,
+            retransmit:    Retransmit::new(),
+            tx_buffer:     tx_buffer.into(),
+            rx_buffer:     rx_buffer.into()
+        })
+    }
+
+    /// Return the connection state.
+    #[inline(always)]
+    pub fn state(&self) -> State {
+        self.state
+    }
+
     /// Return the local endpoint.
     #[inline(always)]
-    pub fn local_end(&self) -> IpEndpoint {
+    pub fn local_endpoint(&self) -> IpEndpoint {
         self.local_end
     }
 
     /// Return the remote endpoint.
     #[inline(always)]
-    pub fn remote_end(&self) -> IpEndpoint {
+    pub fn remote_endpoint(&self) -> IpEndpoint {
         self.remote_end
     }
 
-    /// Convert into a data stream using the given buffers.
-    pub fn into_stream<'a, T>(self, rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
-            where T: Into<StreamBuffer<'a>> {
-        Socket::TcpStream(Stream {
-            rx_buffer:  rx_buffer.into(),
-            tx_buffer:  tx_buffer.into(),
-            local_end:  self.local_end,
-            remote_end: self.remote_end,
-            local_seq:  self.local_seq,
-            remote_seq: self.remote_seq
-        })
-    }
-}
-
-/// A Transmission Control Protocol server socket.
-#[derive(Debug)]
-pub struct Listener<'a> {
-    endpoint:   IpEndpoint,
-    backlog:    Managed<'a, [Option<Incoming>]>,
-    accept_at:  usize,
-    length:     usize
-}
-
-impl<'a> Listener<'a> {
-    /// Create a server socket with the given backlog.
-    pub fn new<T>(endpoint: IpEndpoint, backlog: T) -> Socket<'a, 'static>
-            where T: Into<Managed<'a, [Option<Incoming>]>> {
-        Socket::TcpListener(Listener {
-            endpoint:  endpoint,
-            backlog:   backlog.into(),
-            accept_at: 0,
-            length:    0
-        })
-    }
-
-    /// Accept a connection from this server socket,
-    pub fn accept(&mut self) -> Option<Incoming> {
-        if self.length == 0 { return None }
-
-        let accept_at = self.accept_at;
-        self.accept_at = (self.accept_at + 1) % self.backlog.len();
-        self.length -= 1;
-
-        self.backlog[accept_at].take()
+    /// Start listening on the given endpoint.
+    ///
+    /// # Panics
+    /// This function will panic if the socket is not in the CLOSED state.
+    pub fn listen(&mut self, endpoint: IpEndpoint) {
+        assert!(self.state == State::Closed);
+        self.state      = State::Listen;
+        self.local_end  = endpoint;
+        self.remote_end = IpEndpoint::default()
     }
 
     /// See [Socket::collect](enum.Socket.html#method.collect).
@@ -216,30 +205,81 @@ impl<'a> Listener<'a> {
         let packet = try!(TcpPacket::new(payload));
         let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr));
 
-        if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) }
-        if !self.endpoint.addr.is_unspecified() {
-            if self.endpoint.addr != *dst_addr { return Err(Error::Rejected) }
+        if self.local_end.port != repr.dst_port { return Err(Error::Rejected) }
+        if !self.local_end.addr.is_unspecified() &&
+           self.local_end.addr != *dst_addr { return Err(Error::Rejected) }
+
+        if self.remote_end.port != 0 &&
+           self.remote_end.port != repr.src_port { return Err(Error::Rejected) }
+        if !self.remote_end.addr.is_unspecified() &&
+           self.remote_end.addr != *src_addr { return Err(Error::Rejected) }
+
+        match (self.state, repr) {
+            (State::Closed, _) => Err(Error::Rejected),
+
+            (State::Listen, TcpRepr {
+                src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
+            }) => {
+                self.state         = State::SynReceived;
+                self.local_end     = IpEndpoint::new(*dst_addr, dst_port);
+                self.remote_end    = IpEndpoint::new(*src_addr, src_port);
+                self.remote_seq_no = seq_number;
+                // FIXME: use something more secure
+                self.local_seq_no  = !seq_number;
+                // FIXME: queue data from SYN
+                self.retransmit.reset();
+                Ok(())
+            }
+
+            _ => {
+                // This will cause the interface to reply with an RST.
+                Err(Error::Rejected)
+            }
         }
+    }
 
-        match (repr.control, repr.ack_number) {
-            (TcpControl::Syn, None) => {
-                if self.length == self.backlog.len() { return Err(Error::Exhausted) }
-
-                let inject_at = (self.accept_at + self.length) % self.backlog.len();
-                self.length += 1;
-
-                assert!(self.backlog[inject_at].is_none());
-                self.backlog[inject_at] = Some(Incoming {
-                    local_end:  IpEndpoint::new(*dst_addr, repr.dst_port),
-                    remote_end: IpEndpoint::new(*src_addr, repr.src_port),
-                    // FIXME: choose something more secure?
-                    local_seq:  !repr.seq_number,
-                    remote_seq: repr.seq_number
-                });
-                Ok(())
+    /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
+    pub fn dispatch(&mut self, f: &mut FnMut(&IpAddress, &IpAddress,
+                                             IpProtocol, &PacketRepr) -> Result<(), Error>)
+            -> Result<(), Error> {
+        let mut repr = TcpRepr {
+            src_port:   self.local_end.port,
+            dst_port:   self.remote_end.port,
+            control:    TcpControl::None,
+            seq_number: 0,
+            ack_number: None,
+            window_len: (self.rx_buffer.capacity() - self.rx_buffer.len()) as u16,
+            payload:    &[]
+        };
+
+        // FIXME: process
+
+        match self.state {
+            State::Closed |
+            State::Listen => {
+                return Err(Error::Exhausted)
             }
-            _ => Err(Error::Rejected)
+            State::SynReceived => {
+                if !self.retransmit.check() { return Err(Error::Exhausted) }
+                repr.control    = TcpControl::Syn;
+                repr.seq_number = self.local_seq_no;
+                repr.ack_number = Some(self.remote_seq_no + 1);
+            }
+            _ => unreachable!()
         }
+
+        f(&self.local_end.addr, &self.remote_end.addr, IpProtocol::Tcp, &repr)
+    }
+}
+
+impl<'a> PacketRepr for TcpRepr<'a> {
+    fn buffer_len(&self) -> usize {
+        self.buffer_len()
+    }
+
+    fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]) {
+        let mut packet = TcpPacket::new(payload).expect("undersized payload");
+        self.emit(&mut packet, src_addr, dst_addr)
     }
 }
 
@@ -249,7 +289,7 @@ mod test {
 
     #[test]
     fn test_buffer() {
-        let mut buffer = StreamBuffer::new(vec![0; 8]); // ........
+        let mut buffer = SocketBuffer::new(vec![0; 8]); // ........
         buffer.enqueue(6).copy_from_slice(b"foobar");   // foobar..
         assert_eq!(buffer.dequeue(3), b"foo");          // ...bar..
         buffer.enqueue(6).copy_from_slice(b"ba");       // ...barba

+ 1 - 1
src/socket/udp.rs

@@ -115,7 +115,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     pub fn new(endpoint: IpEndpoint,
                rx_buffer: SocketBuffer<'a, 'b>, tx_buffer: SocketBuffer<'a, 'b>)
             -> Socket<'a, 'b> {
-        Socket::UdpSocket(UdpSocket {
+        Socket::Udp(UdpSocket {
             endpoint:  endpoint,
             rx_buffer: rx_buffer,
             tx_buffer: tx_buffer