浏览代码

Implement conversion of incoming TCP connections into TCP streams.

whitequark 8 年之前
父节点
当前提交
a454a89b9e
共有 5 个文件被更改,包括 136 次插入29 次删除
  1. 10 6
      examples/smoltcpserver.rs
  2. 4 4
      src/iface/ethernet.rs
  3. 15 9
      src/socket/mod.rs
  4. 106 9
      src/socket/tcp.rs
  5. 1 1
      src/socket/udp.rs

+ 10 - 6
examples/smoltcpserver.rs

@@ -7,7 +7,7 @@ use smoltcp::phy::{Tracer, TapInterface};
 use smoltcp::wire::{EthernetFrame, EthernetAddress, IpAddress, IpEndpoint};
 use smoltcp::iface::{SliceArpCache, EthernetInterface};
 use smoltcp::socket::{UdpSocket, AsSocket, UdpSocketBuffer, UdpPacketBuffer};
-use smoltcp::socket::{TcpListener};
+use smoltcp::socket::{TcpListener, TcpStreamBuffer};
 
 fn main() {
     let ifname = env::args().nth(1).unwrap();
@@ -27,7 +27,7 @@ fn main() {
 
     let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
     let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)];
-    let sockets = [udp_socket, tcp_listener];
+    let sockets = vec![udp_socket, tcp_listener];
     let mut iface = EthernetInterface::new(device, arp_cache,
         hardware_addr, protocol_addrs, sockets);
 
@@ -57,11 +57,15 @@ fn main() {
             }
         }
 
-        {
+        if let Some(incoming) = {
             let tcp_listener: &mut TcpListener = iface.sockets()[1].as_socket();
-            if let Some(stream) = tcp_listener.accept() {
-                println!("client from {}", stream.remote_end())
-            }
+            tcp_listener.accept()
+        } {
+            println!("client from {}", incoming.remote_end());
+
+            let tcp_rx_buffer = TcpStreamBuffer::new(vec![0; 8192]);
+            let tcp_tx_buffer = TcpStreamBuffer::new(vec![0; 4096]);
+            iface.sockets().push(incoming.into_stream(tcp_rx_buffer, tcp_tx_buffer));
         }
     }
 }

+ 4 - 4
src/iface/ethernet.rs

@@ -95,8 +95,8 @@ impl<'a, 'b: 'a,
     ///
     /// # Panics
     /// This function panics if any of the addresses is not unicast.
-    pub fn update_protocol_addrs<F: FnOnce(&mut [IpAddress])>(&mut self, f: F) {
-        f(self.protocol_addrs.borrow_mut());
+    pub fn update_protocol_addrs<F: FnOnce(&mut ProtocolAddrsT)>(&mut self, f: F) {
+        f(&mut self.protocol_addrs);
         Self::check_protocol_addrs(self.protocol_addrs.borrow())
     }
 
@@ -107,8 +107,8 @@ impl<'a, 'b: 'a,
     }
 
     /// Get the set of sockets owned by the interface.
-    pub fn sockets(&mut self) -> &mut [Socket<'a, 'b>] {
-        self.sockets.borrow_mut()
+    pub fn sockets(&mut self) -> &mut SocketsT {
+        &mut self.sockets
     }
 
     /// Receive and process a packet, if available.

+ 15 - 9
src/socket/mod.rs

@@ -20,7 +20,8 @@ pub use self::udp::PacketBuffer as UdpPacketBuffer;
 pub use self::udp::SocketBuffer as UdpSocketBuffer;
 pub use self::udp::UdpSocket as UdpSocket;
 
-pub use self::tcp::SocketBuffer as TcpSocketBuffer;
+pub use self::tcp::StreamBuffer as TcpStreamBuffer;
+pub use self::tcp::Stream as TcpStream;
 pub use self::tcp::Incoming as TcpIncoming;
 pub use self::tcp::Listener as TcpListener;
 
@@ -50,8 +51,9 @@ pub trait PacketRepr {
 /// not yet known and the packet storage has to be allocated; but the `&PacketRepr` is sufficient
 /// since the lower layers treat the packet as an opaque octet sequence.
 pub enum Socket<'a, 'b: 'a> {
-    Udp(UdpSocket<'a, 'b>),
-    TcpServer(TcpListener<'a>),
+    UdpSocket(UdpSocket<'a, 'b>),
+    TcpStream(TcpStream<'a>),
+    TcpListener(TcpListener<'a>),
     #[doc(hidden)]
     __Nonexhaustive
 }
@@ -68,9 +70,11 @@ impl<'a, 'b> Socket<'a, 'b> {
                    protocol: IpProtocol, payload: &[u8])
             -> Result<(), Error> {
         match self {
-            &mut Socket::Udp(ref mut socket) =>
+            &mut Socket::UdpSocket(ref mut socket) =>
                 socket.collect(src_addr, dst_addr, protocol, payload),
-            &mut Socket::TcpServer(ref mut socket) =>
+            &mut Socket::TcpStream(ref mut socket) =>
+                socket.collect(src_addr, dst_addr, protocol, payload),
+            &mut Socket::TcpListener(ref mut socket) =>
                 socket.collect(src_addr, dst_addr, protocol, payload),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
@@ -87,9 +91,11 @@ impl<'a, 'b> Socket<'a, 'b> {
                                              IpProtocol, &PacketRepr) -> Result<(), Error>)
             -> Result<(), Error> {
         match self {
-            &mut Socket::Udp(ref mut socket) =>
+            &mut Socket::UdpSocket(ref mut socket) =>
+                socket.dispatch(f),
+            &mut Socket::TcpStream(ref mut socket) =>
                 socket.dispatch(f),
-            &mut Socket::TcpServer(_) =>
+            &mut Socket::TcpListener(_) =>
                 Err(Error::Exhausted),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
@@ -107,7 +113,7 @@ pub trait AsSocket<T> {
 impl<'a, 'b> AsSocket<UdpSocket<'a, 'b>> for Socket<'a, 'b> {
     fn as_socket(&mut self) -> &mut UdpSocket<'a, 'b> {
         match self {
-            &mut Socket::Udp(ref mut socket) => socket,
+            &mut Socket::UdpSocket(ref mut socket) => socket,
             _ => panic!(".as_socket::<UdpSocket> called on wrong socket type")
         }
     }
@@ -116,7 +122,7 @@ impl<'a, 'b> AsSocket<UdpSocket<'a, 'b>> for Socket<'a, 'b> {
 impl<'a, 'b> AsSocket<TcpListener<'a>> for Socket<'a, 'b> {
     fn as_socket(&mut self) -> &mut TcpListener<'a> {
         match self {
-            &mut Socket::TcpServer(ref mut socket) => socket,
+            &mut Socket::TcpListener(ref mut socket) => socket,
             _ => panic!(".as_socket::<TcpListener> called on wrong socket type")
         }
     }

+ 106 - 9
src/socket/tcp.rs

@@ -2,21 +2,21 @@ use Error;
 use Managed;
 use wire::{IpProtocol, IpAddress, IpEndpoint};
 use wire::{TcpPacket, TcpRepr, TcpControl};
-use socket::{Socket};
+use socket::{Socket, PacketRepr};
 
 /// A TCP stream ring buffer.
 #[derive(Debug)]
-pub struct SocketBuffer<'a> {
+pub struct StreamBuffer<'a> {
     storage: Managed<'a, [u8]>,
     read_at: usize,
     length:  usize
 }
 
-impl<'a> SocketBuffer<'a> {
+impl<'a> StreamBuffer<'a> {
     /// Create a packet buffer with the given storage.
-    pub fn new<T>(storage: T) -> SocketBuffer<'a>
+    pub fn new<T>(storage: T) -> StreamBuffer<'a>
             where T: Into<Managed<'a, [u8]>> {
-        SocketBuffer {
+        StreamBuffer {
             storage: storage.into(),
             read_at: 0,
             length:  0
@@ -60,24 +60,119 @@ impl<'a> SocketBuffer<'a> {
     }
 }
 
+impl<'a> Into<StreamBuffer<'a>> for Managed<'a, [u8]> {
+    fn into(self) -> StreamBuffer<'a> {
+        StreamBuffer::new(self)
+    }
+}
+
+/// A Transmission Control Protocol data stream.
+#[derive(Debug)]
+pub struct Stream<'a> {
+    local_end:  IpEndpoint,
+    remote_end: IpEndpoint,
+    local_seq:  u32,
+    remote_seq: u32,
+    rx_buffer:  StreamBuffer<'a>,
+    tx_buffer:  StreamBuffer<'a>
+}
+
+impl<'a> Stream<'a> {
+    /// Return the local endpoint.
+    #[inline(always)]
+    pub fn local_end(&self) -> IpEndpoint {
+        self.local_end
+    }
+
+    /// Return the remote endpoint.
+    #[inline(always)]
+    pub fn remote_end(&self) -> IpEndpoint {
+        self.remote_end
+    }
+
+    /// See [Socket::collect](enum.Socket.html#method.collect).
+    pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress,
+                   protocol: IpProtocol, payload: &[u8])
+            -> Result<(), Error> {
+        if protocol != IpProtocol::Tcp { return Err(Error::Rejected) }
+
+        let packet = try!(TcpPacket::new(payload));
+        let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr));
+
+        if self.local_end  != IpEndpoint::new(*dst_addr, repr.dst_port) {
+            return Err(Error::Rejected)
+        }
+        if self.remote_end != IpEndpoint::new(*src_addr, repr.src_port) {
+            return Err(Error::Rejected)
+        }
+
+        // FIXME: process
+        Ok(())
+    }
+
+    /// See [Socket::dispatch](enum.Socket.html#method.dispatch).
+    pub fn dispatch(&mut self, _f: &mut FnMut(&IpAddress, &IpAddress,
+                                              IpProtocol, &PacketRepr) -> Result<(), Error>)
+            -> Result<(), Error> {
+        // FIXME: process
+        // f(&self.local_end.addr,
+        //   &self.remote_end.addr,
+        //   IpProtocol::Tcp,
+        //   &TcpRepr {
+        //     src_port: self.local_end.port,
+        //     dst_port: self.remote_end.port,
+        //     payload:  &packet_buf.as_ref()[..]
+        //   })
+
+        Ok(())
+    }
+}
+
+impl<'a> PacketRepr for TcpRepr<'a> {
+    fn buffer_len(&self) -> usize {
+        self.buffer_len()
+    }
+
+    fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]) {
+        let mut packet = TcpPacket::new(payload).expect("undersized payload");
+        self.emit(&mut packet, src_addr, dst_addr)
+    }
+}
+
 /// A description of incoming TCP connection.
 #[derive(Debug)]
 pub struct Incoming {
     local_end:  IpEndpoint,
     remote_end: IpEndpoint,
-    seq_number: u32
+    local_seq:  u32,
+    remote_seq: u32
 }
 
 impl Incoming {
     /// Return the local endpoint.
+    #[inline(always)]
     pub fn local_end(&self) -> IpEndpoint {
         self.local_end
     }
 
     /// Return the remote endpoint.
+    #[inline(always)]
     pub fn remote_end(&self) -> IpEndpoint {
         self.remote_end
     }
+
+    /// Convert into a data stream using the given buffers.
+    pub fn into_stream<'a, T>(self, rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
+            where T: Into<StreamBuffer<'a>> {
+        Socket::TcpStream(Stream {
+            rx_buffer:  rx_buffer.into(),
+            tx_buffer:  tx_buffer.into(),
+            local_end:  self.local_end,
+            remote_end: self.remote_end,
+            local_seq:  self.local_seq,
+            remote_seq: self.remote_seq
+        })
+    }
 }
 
 /// A Transmission Control Protocol server socket.
@@ -93,7 +188,7 @@ impl<'a> Listener<'a> {
     /// Create a server socket with the given backlog.
     pub fn new<T>(endpoint: IpEndpoint, backlog: T) -> Socket<'a, 'static>
             where T: Into<Managed<'a, [Option<Incoming>]>> {
-        Socket::TcpServer(Listener {
+        Socket::TcpListener(Listener {
             endpoint:  endpoint,
             backlog:   backlog.into(),
             accept_at: 0,
@@ -137,7 +232,9 @@ impl<'a> Listener<'a> {
                 self.backlog[inject_at] = Some(Incoming {
                     local_end:  IpEndpoint::new(*dst_addr, repr.dst_port),
                     remote_end: IpEndpoint::new(*src_addr, repr.src_port),
-                    seq_number: repr.seq_number
+                    // FIXME: choose something more secure?
+                    local_seq:  !repr.seq_number,
+                    remote_seq: repr.seq_number
                 });
                 Ok(())
             }
@@ -152,7 +249,7 @@ mod test {
 
     #[test]
     fn test_buffer() {
-        let mut buffer = SocketBuffer::new(vec![0; 8]);       // ........
+        let mut buffer = StreamBuffer::new(vec![0; 8]); // ........
         buffer.enqueue(6).copy_from_slice(b"foobar");   // foobar..
         assert_eq!(buffer.dequeue(3), b"foo");          // ...bar..
         buffer.enqueue(6).copy_from_slice(b"ba");       // ...barba

+ 1 - 1
src/socket/udp.rs

@@ -115,7 +115,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     pub fn new(endpoint: IpEndpoint,
                rx_buffer: SocketBuffer<'a, 'b>, tx_buffer: SocketBuffer<'a, 'b>)
             -> Socket<'a, 'b> {
-        Socket::Udp(UdpSocket {
+        Socket::UdpSocket(UdpSocket {
             endpoint:  endpoint,
             rx_buffer: rx_buffer,
             tx_buffer: tx_buffer