Browse Source

fix(net): Fix TCP Unresponsiveness and Inability to Close Connections (#791)

* fix(net): Improve stability. 为RawSocket与UdpSocket实现close时调用close方法,符合smoltcp的行为。为SocketInode实现drop,保证程序任何情况下退出时都能正确close对应socket, 释放被占用的端口。

* fix(net): Correct socket close behavior.
Samuel Dai 10 months ago
parent
commit
37cef00bb4
4 changed files with 143 additions and 134 deletions
  1. 4 3
      kernel/src/net/net_core.rs
  2. 98 102
      kernel/src/net/socket/inet.rs
  3. 39 29
      kernel/src/net/socket/mod.rs
  4. 2 0
      user/apps/http_server/main.c

+ 4 - 3
kernel/src/net/net_core.rs

@@ -217,6 +217,9 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
                 if inner_socket.state() == smoltcp::socket::tcp::State::Established {
                     events |= TcpSocket::CAN_CONNECT;
                 }
+                if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait {
+                    events |= EPollEventType::EPOLLHUP.bits() as u64;
+                }
                 handle_guard
                     .get(&global_handle)
                     .unwrap()
@@ -226,13 +229,11 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
             smoltcp::socket::Socket::Dhcpv4(_) => {}
             smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"),
         }
-        drop(handle_guard);
-        let mut handle_guard = HANDLE_MAP.write_irqsave();
-        let handle_item = handle_guard.get_mut(&global_handle).unwrap();
         EventPoll::wakeup_epoll(
             &handle_item.epitems,
             EPollEventType::from_bits_truncate(events as u32),
         )?;
+        drop(handle_guard);
         // crate::kdebug!(
         //     "{} send_event {:?}",
         //     handle,

+ 98 - 102
kernel/src/net/socket/inet.rs

@@ -1,12 +1,11 @@
 use alloc::{boxed::Box, sync::Arc, vec::Vec};
 use smoltcp::{
-    socket::{raw, tcp, udp, AnySocket},
+    socket::{raw, tcp, udp},
     wire,
 };
 use system_error::SystemError;
 
 use crate::{
-    arch::rand::rand,
     driver::net::NetDevice,
     kerror, kwarn,
     libs::rwlock::RwLock,
@@ -88,7 +87,11 @@ impl RawSocket {
 impl Socket for RawSocket {
     fn close(&mut self) {
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-        socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
+        if let smoltcp::socket::Socket::Udp(mut sock) =
+            socket_set_guard.remove(self.handle.smoltcp_handle().unwrap())
+        {
+            sock.close();
+        }
         drop(socket_set_guard);
         poll_ifaces();
     }
@@ -289,7 +292,7 @@ impl UdpSocket {
                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
             }
             // 检测端口是否已被占用
-            PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?;
+            PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?;
 
             let bind_res = if ip.addr.is_unspecified() {
                 socket.bind(ip.port)
@@ -310,7 +313,11 @@ impl UdpSocket {
 impl Socket for UdpSocket {
     fn close(&mut self) {
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-        socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
+        if let smoltcp::socket::Socket::Udp(mut sock) =
+            socket_set_guard.remove(self.handle.smoltcp_handle().unwrap())
+        {
+            sock.close();
+        }
         drop(socket_set_guard);
         poll_ifaces();
     }
@@ -559,11 +566,20 @@ impl TcpSocket {
 impl Socket for TcpSocket {
     fn close(&mut self) {
         for handle in self.handles.iter() {
-            let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-            socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
-            drop(socket_set_guard);
+            {
+                let mut socket_set_guard = SOCKET_SET.lock_irqsave();
+                let smoltcp_handle = handle.smoltcp_handle().unwrap();
+                socket_set_guard
+                    .get_mut::<smoltcp::socket::tcp::Socket>(smoltcp_handle)
+                    .close();
+                drop(socket_set_guard);
+            }
+            poll_ifaces();
+            SOCKET_SET
+                .lock_irqsave()
+                .remove(handle.smoltcp_handle().unwrap());
+            // kdebug!("[Socket] [TCP] Close: {:?}", handle);
         }
-        poll_ifaces();
     }
 
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
@@ -627,7 +643,7 @@ impl Socket for TcpSocket {
             drop(socket_set_guard);
             SocketHandleItem::sleep(
                 self.socket_handle(),
-                EPollEventType::EPOLLIN.bits() as u64,
+                (EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64,
                 HANDLE_MAP.read_irqsave(),
             );
         }
@@ -697,7 +713,7 @@ impl Socket for TcpSocket {
         if let Endpoint::Ip(Some(ip)) = endpoint {
             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
             // 检测端口是否被占用
-            PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.clone())?;
+            PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?;
 
             // kdebug!("temp_port: {}", temp_port);
             let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone();
@@ -750,7 +766,7 @@ impl Socket for TcpSocket {
 
     /// @brief tcp socket 监听 local_endpoint 端口
     ///
-    /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
+    /// @param backlog 未处理的连接队列的最大长度
     fn listen(&mut self, backlog: usize) -> Result<(), SystemError> {
         if self.is_listening {
             return Ok(());
@@ -763,12 +779,14 @@ impl Socket for TcpSocket {
         let backlog = handlen.max(backlog);
 
         // 添加剩余需要构建的socket
-        // kdebug!("tcp socket:before listen, socket'len={}",self.handle.len());
+        // kdebug!("tcp socket:before listen, socket'len={}", self.handle_list.len());
         let mut handle_guard = HANDLE_MAP.write_irqsave();
+        let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue);
+
         self.handles.extend((handlen..backlog).map(|_| {
             let socket = Self::create_new_socket();
             let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
-            let handle_item = SocketHandleItem::new();
+            let handle_item = SocketHandleItem::new(Some(wait_queue.clone()));
             handle_guard.insert(handle, handle_item);
             handle
         }));
@@ -797,7 +815,7 @@ impl Socket for TcpSocket {
             }
 
             // 检测端口是否已被占用
-            PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?;
+            PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?;
             // kdebug!("tcp socket:bind, socket'len={}",self.handle.len());
 
             self.local_endpoint = Some(ip);
@@ -818,100 +836,78 @@ impl Socket for TcpSocket {
     }
 
     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
+        if !self.is_listening {
+            return Err(SystemError::EINVAL);
+        }
         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
         loop {
             // kdebug!("tcp accept: poll_ifaces()");
             poll_ifaces();
-            // kdebug!("tcp socket:accept, socket'len={}",self.handle.len());
-
-            let mut sockets = SOCKET_SET.lock_irqsave();
-
-            // 随机获取访问的socket的handle
-            let index: usize = rand() % self.handles.len();
-            let handle = self.handles.get(index).unwrap();
-
-            let socket = sockets
-                .iter_mut()
-                .find(|y| {
-                    tcp::Socket::downcast(y.1)
-                        .map(|y| y.is_active())
-                        .unwrap_or(false)
-                })
-                .map(|y| tcp::Socket::downcast_mut(y.1).unwrap());
-            if let Some(socket) = socket {
-                if socket.is_active() {
-                    // kdebug!("tcp accept: socket.is_active()");
-                    let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
-
-                    let new_socket = {
-                        // The new TCP socket used for sending and receiving data.
-                        let mut tcp_socket = Self::create_new_socket();
-                        self.do_listen(&mut tcp_socket, endpoint)
-                            .expect("do_listen failed");
-
-                        // tcp_socket.listen(endpoint).unwrap();
-
-                        // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
-                        // 因此需要再为当前的socket分配一个新的handle
-                        let new_handle =
-                            GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket));
-                        let old_handle = ::core::mem::replace(
-                            &mut *self.handles.get_mut(index).unwrap(),
-                            new_handle,
-                        );
-
-                        let metadata = SocketMetadata::new(
-                            SocketType::Tcp,
-                            Self::DEFAULT_TX_BUF_SIZE,
-                            Self::DEFAULT_RX_BUF_SIZE,
-                            Self::DEFAULT_METADATA_BUF_SIZE,
-                            self.metadata.options,
-                        );
-
-                        let new_socket = Box::new(TcpSocket {
-                            handles: vec![old_handle],
-                            local_endpoint: self.local_endpoint,
-                            is_listening: false,
-                            metadata,
-                        });
-                        // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len());
-
-                        // 更新端口与 socket 的绑定
-                        if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
-                            PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
-                            PORT_MANAGER.bind_port(
-                                self.metadata.socket_type,
-                                ip.port,
-                                *new_socket.clone(),
-                            )?;
-                        }
-
-                        // 更新handle表
-                        let mut handle_guard = HANDLE_MAP.write_irqsave();
-                        // 先删除原来的
-
-                        let item = handle_guard.remove(&old_handle).unwrap();
-
-                        // 按照smoltcp行为,将新的handle绑定到原来的item
-                        handle_guard.insert(new_handle, item);
-                        let new_item = SocketHandleItem::new();
-
-                        // 插入新的item
-                        handle_guard.insert(old_handle, new_item);
-
-                        new_socket
-                    };
-                    // kdebug!("tcp accept: new socket: {:?}", new_socket);
-                    drop(sockets);
-                    poll_ifaces();
-
-                    return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
+            // kdebug!("tcp socket:accept, socket'len={}", self.handle_list.len());
+
+            let mut sockset = SOCKET_SET.lock_irqsave();
+            // Get the corresponding activated handler
+            let global_handle_index = self.handles.iter().position(|handle| {
+                let con_smol_sock = sockset.get::<tcp::Socket>(handle.smoltcp_handle().unwrap());
+                con_smol_sock.is_active()
+            });
+
+            if let Some(handle_index) = global_handle_index {
+                let con_smol_sock = sockset
+                    .get::<tcp::Socket>(self.handles[handle_index].smoltcp_handle().unwrap());
+
+                // kdebug!("[Socket] [TCP] Accept: {:?}", handle);
+                // handle is connected socket's handle
+                let remote_ep = con_smol_sock
+                    .remote_endpoint()
+                    .ok_or(SystemError::ENOTCONN)?;
+
+                let mut tcp_socket = Self::create_new_socket();
+                self.do_listen(&mut tcp_socket, endpoint)?;
+
+                let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket));
+
+                // let handle in TcpSock be the new empty handle, and return the old connected handle
+                let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle);
+
+                let metadata = SocketMetadata::new(
+                    SocketType::Tcp,
+                    Self::DEFAULT_TX_BUF_SIZE,
+                    Self::DEFAULT_RX_BUF_SIZE,
+                    Self::DEFAULT_METADATA_BUF_SIZE,
+                    self.metadata.options,
+                );
+
+                let sock_ret = Box::new(TcpSocket {
+                    handles: vec![old_handle],
+                    local_endpoint: self.local_endpoint,
+                    is_listening: false,
+                    metadata,
+                });
+
+                {
+                    let mut handle_guard = HANDLE_MAP.write_irqsave();
+                    // 先删除原来的
+                    let item = handle_guard.remove(&old_handle).unwrap();
+
+                    // 按照smoltcp行为,将新的handle绑定到原来的item
+                    let new_item = SocketHandleItem::new(None);
+                    handle_guard.insert(old_handle, new_item);
+                    // 插入新的item
+                    handle_guard.insert(new_handle, item);
+                    drop(handle_guard);
                 }
+                return Ok((sock_ret, Endpoint::Ip(Some(remote_ep))));
             }
-            // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
 
-            drop(sockets);
-            SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
+            drop(sockset);
+
+            // kdebug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap());
+            SocketHandleItem::sleep(
+                self.socket_handle(), // NOTICE
+                Self::CAN_ACCPET,
+                HANDLE_MAP.read_irqsave(),
+            );
             // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
         }
     }

+ 39 - 29
kernel/src/net/socket/mod.rs

@@ -25,6 +25,7 @@ use crate::{
         spinlock::{SpinLock, SpinLockGuard},
         wait_queue::EventWaitQueue,
     },
+    process::{Pid, ProcessManager},
     sched::{schedule, SchedMode},
 };
 
@@ -85,7 +86,7 @@ pub(super) fn new_socket(
         }
     };
 
-    let handle_item = SocketHandleItem::new();
+    let handle_item = SocketHandleItem::new(None);
     HANDLE_MAP
         .write_irqsave()
         .insert(socket.socket_handle(), handle_item);
@@ -303,19 +304,8 @@ impl SocketInode {
     pub unsafe fn inner_no_preempt(&self) -> SpinLockGuard<Box<dyn Socket>> {
         self.0.lock_no_preempt()
     }
-}
 
-impl IndexNode for SocketInode {
-    fn open(
-        &self,
-        _data: SpinLockGuard<FilePrivateData>,
-        _mode: &FileMode,
-    ) -> Result<(), SystemError> {
-        self.1.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
-        Ok(())
-    }
-
-    fn close(&self, _data: SpinLockGuard<FilePrivateData>) -> Result<(), SystemError> {
+    fn do_close(&self) -> Result<(), SystemError> {
         let prev_ref_count = self.1.fetch_sub(1, core::sync::atomic::Ordering::SeqCst);
         if prev_ref_count == 1 {
             // 最后一次关闭,需要释放
@@ -326,7 +316,7 @@ impl IndexNode for SocketInode {
             }
 
             if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() {
-                PORT_MANAGER.unbind_port(socket.metadata().socket_type, ip.port)?;
+                PORT_MANAGER.unbind_port(socket.metadata().socket_type, ip.port);
             }
 
             socket.clear_epoll()?;
@@ -340,6 +330,29 @@ impl IndexNode for SocketInode {
 
         Ok(())
     }
+}
+
+impl Drop for SocketInode {
+    fn drop(&mut self) {
+        for _ in 0..self.1.load(core::sync::atomic::Ordering::SeqCst) {
+            let _ = self.do_close();
+        }
+    }
+}
+
+impl IndexNode for SocketInode {
+    fn open(
+        &self,
+        _data: SpinLockGuard<FilePrivateData>,
+        _mode: &FileMode,
+    ) -> Result<(), SystemError> {
+        self.1.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
+        Ok(())
+    }
+
+    fn close(&self, _data: SpinLockGuard<FilePrivateData>) -> Result<(), SystemError> {
+        self.do_close()
+    }
 
     fn read_at(
         &self,
@@ -400,16 +413,16 @@ pub struct SocketHandleItem {
     /// shutdown状态
     pub shutdown_type: RwLock<ShutdownType>,
     /// socket的waitqueue
-    pub wait_queue: EventWaitQueue,
+    pub wait_queue: Arc<EventWaitQueue>,
     /// epitems,考虑写在这是否是最优解?
     pub epitems: SpinLock<LinkedList<Arc<EPollItem>>>,
 }
 
 impl SocketHandleItem {
-    pub fn new() -> Self {
+    pub fn new(wait_queue: Option<Arc<EventWaitQueue>>) -> Self {
         Self {
             shutdown_type: RwLock::new(ShutdownType::empty()),
-            wait_queue: EventWaitQueue::new(),
+            wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())),
             epitems: SpinLock::new(LinkedList::new()),
         }
     }
@@ -463,9 +476,9 @@ impl SocketHandleItem {
 /// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。
 pub struct PortManager {
     // TCP 端口记录表
-    tcp_port_table: SpinLock<HashMap<u16, Arc<dyn Socket>>>,
+    tcp_port_table: SpinLock<HashMap<u16, Pid>>,
     // UDP 端口记录表
-    udp_port_table: SpinLock<HashMap<u16, Arc<dyn Socket>>>,
+    udp_port_table: SpinLock<HashMap<u16, Pid>>,
 }
 
 impl PortManager {
@@ -517,12 +530,7 @@ impl PortManager {
     /// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
     ///
     /// TODO: 增加支持端口复用的逻辑
-    pub fn bind_port(
-        &self,
-        socket_type: SocketType,
-        port: u16,
-        socket: impl Socket,
-    ) -> Result<(), SystemError> {
+    pub fn bind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> {
         if port > 0 {
             let mut listen_table_guard = match socket_type {
                 SocketType::Udp => self.udp_port_table.lock(),
@@ -531,7 +539,7 @@ impl PortManager {
             };
             match listen_table_guard.get(&port) {
                 Some(_) => return Err(SystemError::EADDRINUSE),
-                None => listen_table_guard.insert(port, Arc::new(socket)),
+                None => listen_table_guard.insert(port, ProcessManager::current_pid()),
             };
             drop(listen_table_guard);
         }
@@ -539,15 +547,17 @@ impl PortManager {
     }
 
     /// @brief 在对应的端口记录表中将端口和 socket 解绑
-    pub fn unbind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> {
+    /// should call this function when socket is closed or aborted
+    pub fn unbind_port(&self, socket_type: SocketType, port: u16) {
         let mut listen_table_guard = match socket_type {
             SocketType::Udp => self.udp_port_table.lock(),
             SocketType::Tcp => self.tcp_port_table.lock(),
-            _ => return Ok(()),
+            _ => {
+                return;
+            }
         };
         listen_table_guard.remove(&port);
         drop(listen_table_guard);
-        return Ok(());
     }
 }
 

+ 2 - 0
user/apps/http_server/main.c

@@ -233,6 +233,8 @@ int main(int argc, char const *argv[])
         // 关闭客户端连接
         close(new_socket);
     }
+    // 关闭tcp socket
+    close(server_fd);
 
     return 0;
 }