ソースを参照

Implement TCP server sockets.

whitequark 8 年 前
コミット
ab61890b09
6 ファイル変更149 行追加29 行削除
  1. 30 16
      examples/smoltcpserver.rs
  2. 2 2
      src/managed.rs
  3. 16 0
      src/socket/mod.rs
  4. 90 0
      src/socket/tcp.rs
  5. 1 1
      src/socket/udp.rs
  6. 10 10
      src/wire/ip.rs

+ 30 - 16
examples/smoltcpserver.rs

@@ -7,6 +7,7 @@ use smoltcp::phy::{Tracer, TapInterface};
 use smoltcp::wire::{EthernetFrame, EthernetAddress, IpAddress, IpEndpoint};
 use smoltcp::iface::{SliceArpCache, EthernetInterface};
 use smoltcp::socket::{UdpSocket, AsSocket, UdpSocketBuffer, UdpPacketBuffer};
+use smoltcp::socket::{TcpListener};
 
 fn main() {
     let ifname = env::args().nth(1).unwrap();
@@ -15,14 +16,18 @@ fn main() {
     let device = Tracer::<_, EthernetFrame<&[u8]>>::new(device);
     let arp_cache = SliceArpCache::new(vec![Default::default(); 8]);
 
+    let endpoint = IpEndpoint::new(IpAddress::default(), 6969);
+
     let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 2048])]);
     let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 2048])]);
-    let endpoint = IpEndpoint::new(IpAddress::default(), 6969);
     let udp_socket = UdpSocket::new(endpoint, udp_rx_buffer, udp_tx_buffer);
 
+    let tcp_backlog = vec![None];
+    let tcp_listener = TcpListener::new(endpoint, tcp_backlog);
+
     let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
     let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)];
-    let sockets = [udp_socket];
+    let sockets = [udp_socket, tcp_listener];
     let mut iface = EthernetInterface::new(device, arp_cache,
         hardware_addr, protocol_addrs, sockets);
 
@@ -32,22 +37,31 @@ fn main() {
             Err(e) => println!("error {}", e)
         }
 
-        let udp_socket = iface.sockets()[0].as_socket();
-        let client = match udp_socket.recv() {
-            Ok((endpoint, data)) => {
-                println!("data {:?} from {}", data, endpoint);
-                Some(endpoint)
+        {
+            let udp_socket: &mut UdpSocket = iface.sockets()[0].as_socket();
+            let udp_client = match udp_socket.recv() {
+                Ok((endpoint, data)) => {
+                    println!("data {:?} from {}", data, endpoint);
+                    Some(endpoint)
+                }
+                Err(Error::Exhausted) => {
+                    None
+                }
+                Err(e) => {
+                    println!("error {}", e);
+                    None
+                }
+            };
+            if let Some(endpoint) = udp_client {
+                udp_socket.send_slice(endpoint, "hihihi".as_bytes()).unwrap()
             }
-            Err(Error::Exhausted) => {
-                None
-            }
-            Err(e) => {
-                println!("error {}", e);
-                None
+        }
+
+        {
+            let tcp_listener: &mut TcpListener = iface.sockets()[1].as_socket();
+            if let Some(stream) = tcp_listener.accept() {
+                println!("client from {}", stream.remote_end())
             }
-        };
-        if let Some(endpoint) = client {
-            udp_socket.send_slice(endpoint, "hihihi".as_bytes()).unwrap()
         }
     }
 }

+ 2 - 2
src/managed.rs

@@ -35,8 +35,8 @@ impl<'a, T: 'a + fmt::Debug + ?Sized> fmt::Debug for Managed<'a, T> {
     }
 }
 
-impl<'a, 'b: 'a, T: 'b + ?Sized> From<&'b mut T> for Managed<'b, T> {
-    fn from(value: &'b mut T) -> Self {
+impl<'a, T: 'a + ?Sized> From<&'a mut T> for Managed<'a, T> {
+    fn from(value: &'a mut T) -> Self {
         Managed::Borrowed(value)
     }
 }

+ 16 - 0
src/socket/mod.rs

@@ -21,6 +21,8 @@ pub use self::udp::SocketBuffer as UdpSocketBuffer;
 pub use self::udp::UdpSocket as UdpSocket;
 
 pub use self::tcp::SocketBuffer as TcpSocketBuffer;
+pub use self::tcp::Incoming as TcpIncoming;
+pub use self::tcp::Listener as TcpListener;
 
 /// A packet representation.
 ///
@@ -49,6 +51,7 @@ pub trait PacketRepr {
 /// since the lower layers treat the packet as an opaque octet sequence.
 pub enum Socket<'a, 'b: 'a> {
     Udp(UdpSocket<'a, 'b>),
+    TcpServer(TcpListener<'a>),
     #[doc(hidden)]
     __Nonexhaustive
 }
@@ -67,6 +70,8 @@ impl<'a, 'b> Socket<'a, 'b> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
                 socket.collect(src_addr, dst_addr, protocol, payload),
+            &mut Socket::TcpServer(ref mut socket) =>
+                socket.collect(src_addr, dst_addr, protocol, payload),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
     }
@@ -84,6 +89,8 @@ impl<'a, 'b> Socket<'a, 'b> {
         match self {
             &mut Socket::Udp(ref mut socket) =>
                 socket.dispatch(f),
+            &mut Socket::TcpServer(_) =>
+                Err(Error::Exhausted),
             &mut Socket::__Nonexhaustive => unreachable!()
         }
     }
@@ -105,3 +112,12 @@ impl<'a, 'b> AsSocket<UdpSocket<'a, 'b>> for Socket<'a, 'b> {
         }
     }
 }
+
+impl<'a, 'b> AsSocket<TcpListener<'a>> for Socket<'a, 'b> {
+    fn as_socket(&mut self) -> &mut TcpListener<'a> {
+        match self {
+            &mut Socket::TcpServer(ref mut socket) => socket,
+            _ => panic!(".as_socket::<TcpListener> called on wrong socket type")
+        }
+    }
+}

+ 90 - 0
src/socket/tcp.rs

@@ -1,4 +1,8 @@
+use Error;
 use Managed;
+use wire::{IpProtocol, IpAddress, IpEndpoint};
+use wire::{TcpPacket, TcpRepr, TcpControl};
+use socket::{Socket};
 
 /// A TCP stream ring buffer.
 #[derive(Debug)]
@@ -56,6 +60,92 @@ impl<'a> SocketBuffer<'a> {
     }
 }
 
+/// A description of incoming TCP connection.
+#[derive(Debug)]
+pub struct Incoming {
+    local_end:  IpEndpoint,
+    remote_end: IpEndpoint,
+    seq_number: u32
+}
+
+impl Incoming {
+    /// Return the local endpoint.
+    pub fn local_end(&self) -> IpEndpoint {
+        self.local_end
+    }
+
+    /// Return the remote endpoint.
+    pub fn remote_end(&self) -> IpEndpoint {
+        self.remote_end
+    }
+}
+
+/// A Transmission Control Protocol server socket.
+#[derive(Debug)]
+pub struct Listener<'a> {
+    endpoint:   IpEndpoint,
+    backlog:    Managed<'a, [Option<Incoming>]>,
+    accept_at:  usize,
+    length:     usize
+}
+
+impl<'a> Listener<'a> {
+    /// Create a server socket with the given backlog.
+    pub fn new<T>(endpoint: IpEndpoint, backlog: T) -> Socket<'a, 'static>
+            where T: Into<Managed<'a, [Option<Incoming>]>> {
+        Socket::TcpServer(Listener {
+            endpoint:  endpoint,
+            backlog:   backlog.into(),
+            accept_at: 0,
+            length:    0
+        })
+    }
+
+    /// Accept a connection from this server socket,
+    pub fn accept(&mut self) -> Option<Incoming> {
+        if self.length == 0 { return None }
+
+        let accept_at = self.accept_at;
+        self.accept_at = (self.accept_at + 1) % self.backlog.len();
+        self.length -= 1;
+
+        self.backlog[accept_at].take()
+    }
+
+    /// See [Socket::collect](enum.Socket.html#method.collect).
+    pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress,
+                   protocol: IpProtocol, payload: &[u8])
+            -> Result<(), Error> {
+        if protocol != IpProtocol::Tcp { return Err(Error::Rejected) }
+
+        let packet = try!(TcpPacket::new(payload));
+        let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr));
+
+        if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) }
+        if !self.endpoint.addr.is_unspecified() {
+            if self.endpoint.addr != *dst_addr { return Err(Error::Rejected) }
+        }
+
+        match (repr.control, repr.ack_number) {
+            (TcpControl::Syn, None) => {
+                if self.length == self.backlog.len() { return Err(Error::Exhausted) }
+
+                let inject_at = (self.accept_at + self.length) % self.backlog.len();
+                self.length += 1;
+
+                assert!(self.backlog[inject_at].is_none());
+                self.backlog[inject_at] = Some(Incoming {
+                    local_end:  IpEndpoint::new(*dst_addr, repr.dst_port),
+                    remote_end: IpEndpoint::new(*src_addr, repr.src_port),
+                    seq_number: repr.seq_number
+                });
+                Ok(())
+            }
+            _ => Err(Error::Rejected)
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;

+ 1 - 1
src/socket/udp.rs

@@ -17,7 +17,7 @@ impl<'a> PacketBuffer<'a> {
     pub fn new<T>(payload: T) -> PacketBuffer<'a>
             where T: Into<Managed<'a, [u8]>> {
         PacketBuffer {
-            endpoint: IpEndpoint::INVALID,
+            endpoint: IpEndpoint::UNSPECIFIED,
             size:     0,
             payload:  payload.into()
         }

+ 10 - 10
src/wire/ip.rs

@@ -25,9 +25,9 @@ impl fmt::Display for Protocol {
 /// An internetworking address.
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
 pub enum Address {
-    /// An invalid address.
+    /// An unspecified address.
     /// May be used as a placeholder for storage where the address is not assigned yet.
-    Invalid,
+    Unspecified,
     /// An IPv4 address.
     Ipv4(Ipv4Address)
 }
@@ -41,23 +41,23 @@ impl Address {
     /// Query whether the address is a valid unicast address.
     pub fn is_unicast(&self) -> bool {
         match self {
-            &Address::Invalid    => false,
-            &Address::Ipv4(addr) => addr.is_unicast()
+            &Address::Unspecified => false,
+            &Address::Ipv4(addr)  => addr.is_unicast()
         }
     }
 
     /// Query whether the address falls into the "unspecified" range.
     pub fn is_unspecified(&self) -> bool {
         match self {
-            &Address::Invalid    => false,
-            &Address::Ipv4(addr) => addr.is_unspecified()
+            &Address::Unspecified => true,
+            &Address::Ipv4(addr)  => addr.is_unspecified()
         }
     }
 }
 
 impl Default for Address {
     fn default() -> Address {
-        Address::Invalid
+        Address::Unspecified
     }
 }
 
@@ -70,8 +70,8 @@ impl From<Ipv4Address> for Address {
 impl fmt::Display for Address {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
-            &Address::Invalid    => write!(f, "(invalid)"),
-            &Address::Ipv4(addr) => write!(f, "{}", addr)
+            &Address::Unspecified => write!(f, "*"),
+            &Address::Ipv4(addr)  => write!(f, "{}", addr)
         }
     }
 }
@@ -84,7 +84,7 @@ pub struct Endpoint {
 }
 
 impl Endpoint {
-    pub const INVALID: Endpoint = Endpoint { addr: Address::Invalid, port: 0 };
+    pub const UNSPECIFIED: Endpoint = Endpoint { addr: Address::Unspecified, port: 0 };
 
     /// Create an endpoint address from given address and port.
     pub fn new(addr: Address, port: u16) -> Endpoint {