Просмотр исходного кода

socket统一改用`GlobalSocketHandle`,并且修复fcntl SETFD的错误 (#730)

* socket统一改用`GlobalSocketHandle`,并且修复fcntl SETFD的错误

---------

Co-authored-by: longjin <[email protected]>
GnoCiYeH 11 месяцев назад
Родитель
Сommit
d623e90231

+ 1 - 1
kernel/src/arch/x86_64/syscall/mod.rs

@@ -88,7 +88,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
     mfence();
     let pid = ProcessManager::current_pcb().pid();
     let show = false;
-    // let show = if syscall_num != SYS_SCHED && pid.data() > 3 {
+    // let show = if syscall_num != SYS_SCHED && pid.data() >= 7 {
     //     true
     // } else {
     //     false

+ 32 - 6
kernel/src/filesystem/vfs/syscall.rs

@@ -1024,6 +1024,15 @@ impl Syscall {
         oldfd: i32,
         newfd: i32,
         fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>,
+    ) -> Result<usize, SystemError> {
+        Self::do_dup3(oldfd, newfd, FileMode::empty(), fd_table_guard)
+    }
+
+    fn do_dup3(
+        oldfd: i32,
+        newfd: i32,
+        flags: FileMode,
+        fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>,
     ) -> Result<usize, SystemError> {
         // 确认oldfd, newid是否有效
         if !(FileDescriptorVec::validate_fd(oldfd) && FileDescriptorVec::validate_fd(newfd)) {
@@ -1047,8 +1056,12 @@ impl Syscall {
             .get_file_by_fd(oldfd)
             .ok_or(SystemError::EBADF)?;
         let new_file = old_file.try_clone().ok_or(SystemError::EBADF)?;
-        // dup2默认非cloexec
-        new_file.set_close_on_exec(false);
+
+        if flags.contains(FileMode::O_CLOEXEC) {
+            new_file.set_close_on_exec(true);
+        } else {
+            new_file.set_close_on_exec(false);
+        }
         // 申请文件描述符,并把文件对象存入其中
         let res = fd_table_guard
             .alloc_fd(new_file, Some(newfd))
@@ -1064,8 +1077,9 @@ impl Syscall {
     /// - `cmd`:命令
     /// - `arg`:参数
     pub fn fcntl(fd: i32, cmd: FcntlCommand, arg: i32) -> Result<usize, SystemError> {
+        // kdebug!("fcntl ({cmd:?}) fd: {fd}, arg={arg}");
         match cmd {
-            FcntlCommand::DupFd => {
+            FcntlCommand::DupFd | FcntlCommand::DupFdCloexec => {
                 if arg < 0 || arg as usize >= FileDescriptorVec::PROCESS_MAX_FD {
                     return Err(SystemError::EBADF);
                 }
@@ -1074,7 +1088,16 @@ impl Syscall {
                     let binding = ProcessManager::current_pcb().fd_table();
                     let mut fd_table_guard = binding.write();
                     if fd_table_guard.get_file_by_fd(i as i32).is_none() {
-                        return Self::do_dup2(fd, i as i32, &mut fd_table_guard);
+                        if cmd == FcntlCommand::DupFd {
+                            return Self::do_dup2(fd, i as i32, &mut fd_table_guard);
+                        } else {
+                            return Self::do_dup3(
+                                fd,
+                                i as i32,
+                                FileMode::O_CLOEXEC,
+                                &mut fd_table_guard,
+                            );
+                        }
                     }
                 }
                 return Err(SystemError::EMFILE);
@@ -1083,12 +1106,15 @@ impl Syscall {
                 // Get file descriptor flags.
                 let binding = ProcessManager::current_pcb().fd_table();
                 let fd_table_guard = binding.read();
+
                 if let Some(file) = fd_table_guard.get_file_by_fd(fd) {
                     // drop guard 以避免无法调度的问题
                     drop(fd_table_guard);
 
                     if file.close_on_exec() {
                         return Ok(FD_CLOEXEC as usize);
+                    } else {
+                        return Ok(0);
                     }
                 }
                 return Err(SystemError::EBADF);
@@ -1145,8 +1171,8 @@ impl Syscall {
                 // TODO: unimplemented
                 // 未实现的命令,返回0,不报错。
 
-                // kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd);
-                return Ok(0);
+                kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd);
+                return Err(SystemError::ENOSYS);
             }
         }
     }

+ 6 - 5
kernel/src/net/net_core.rs

@@ -12,7 +12,7 @@ use crate::{
 
 use super::{
     event_poll::{EPollEventType, EventPoll},
-    socket::{inet::TcpSocket, HANDLE_MAP, SOCKET_SET},
+    socket::{handle::GlobalSocketHandle, inet::TcpSocket, HANDLE_MAP, SOCKET_SET},
 };
 
 /// The network poll function, which will be called by timer.
@@ -188,7 +188,8 @@ pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> {
 fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
     for (handle, socket_type) in sockets.iter() {
         let handle_guard = HANDLE_MAP.read_irqsave();
-        let item = handle_guard.get(&handle);
+        let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
+        let item = handle_guard.get(&global_handle);
         if item.is_none() {
             continue;
         }
@@ -203,7 +204,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
         match socket_type {
             smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
                 handle_guard
-                    .get(&handle)
+                    .get(&global_handle)
                     .unwrap()
                     .wait_queue
                     .wakeup_any(events);
@@ -217,7 +218,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
                     events |= TcpSocket::CAN_CONNECT;
                 }
                 handle_guard
-                    .get(&handle)
+                    .get(&global_handle)
                     .unwrap()
                     .wait_queue
                     .wakeup_any(events);
@@ -227,7 +228,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
         }
         drop(handle_guard);
         let mut handle_guard = HANDLE_MAP.write_irqsave();
-        let handle_item = handle_guard.get_mut(&handle).unwrap();
+        let handle_item = handle_guard.get_mut(&global_handle).unwrap();
         EventPoll::wakeup_epoll(
             &handle_item.epitems,
             EPollEventType::from_bits_truncate(events as u32),

+ 39 - 0
kernel/src/net/socket/handle.rs

@@ -0,0 +1,39 @@
+use ida::IdAllocator;
+use smoltcp::iface::SocketHandle;
+
+int_like!(KernelHandle, usize);
+
+/// # socket的句柄管理组件
+/// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。
+/// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。
+#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
+pub enum GlobalSocketHandle {
+    Smoltcp(SocketHandle),
+    Kernel(KernelHandle),
+}
+
+static KERNEL_HANDLE_IDA: IdAllocator = IdAllocator::new(0, usize::MAX);
+
+impl GlobalSocketHandle {
+    pub fn new_smoltcp_handle(handle: SocketHandle) -> Self {
+        return Self::Smoltcp(handle);
+    }
+
+    pub fn new_kernel_handle() -> Self {
+        return Self::Kernel(KernelHandle::new(KERNEL_HANDLE_IDA.alloc().unwrap()));
+    }
+
+    pub fn smoltcp_handle(&self) -> Option<SocketHandle> {
+        if let Self::Smoltcp(sh) = *self {
+            return Some(sh);
+        }
+        None
+    }
+
+    pub fn kernel_handle(&self) -> Option<KernelHandle> {
+        if let Self::Kernel(kh) = *self {
+            return Some(kh);
+        }
+        None
+    }
+}

+ 140 - 98
kernel/src/net/socket/inet.rs

@@ -1,7 +1,6 @@
 use alloc::{boxed::Box, sync::Arc, vec::Vec};
 use smoltcp::{
-    iface::SocketHandle,
-    socket::{raw, tcp, udp},
+    socket::{raw, tcp, udp, AnySocket},
     wire,
 };
 use system_error::SystemError;
@@ -18,8 +17,8 @@ use crate::{
 };
 
 use super::{
-    GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod,
-    SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
+    handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions,
+    SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
 };
 
 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
@@ -27,7 +26,7 @@ use super::{
 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
 #[derive(Debug, Clone)]
 pub struct RawSocket {
-    handle: Arc<GlobalSocketHandle>,
+    handle: GlobalSocketHandle,
     /// 用户发送的数据包是否包含了IP头.
     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
@@ -68,8 +67,7 @@ impl RawSocket {
         );
 
         // 把socket添加到socket集合中,并得到socket的句柄
-        let handle: Arc<GlobalSocketHandle> =
-            GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
+        let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
 
         let metadata = SocketMetadata::new(
             SocketType::Raw,
@@ -88,12 +86,20 @@ impl RawSocket {
 }
 
 impl Socket for RawSocket {
+    fn close(&mut self) {
+        let mut socket_set_guard = SOCKET_SET.lock_irqsave();
+        socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
+        drop(socket_set_guard);
+        poll_ifaces();
+    }
+
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
         poll_ifaces();
         loop {
             // 如何优化这里?
             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-            let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
+            let socket =
+                socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
 
             match socket.recv_slice(buf) {
                 Ok(len) => {
@@ -126,7 +132,8 @@ impl Socket for RawSocket {
         // 如果用户发送的数据包,包含IP头,则直接发送
         if self.header_included {
             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-            let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
+            let socket =
+                socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
             match socket.send_slice(buf) {
                 Ok(_) => {
                     return Ok(buf.len());
@@ -141,7 +148,7 @@ impl Socket for RawSocket {
             if let Some(Endpoint::Ip(Some(endpoint))) = to {
                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
                 let socket: &mut raw::Socket =
-                    socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
+                    socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
 
                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
                 let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
@@ -209,8 +216,8 @@ impl Socket for RawSocket {
         Box::new(self.clone())
     }
 
-    fn socket_handle(&self) -> SocketHandle {
-        self.handle.0
+    fn socket_handle(&self) -> GlobalSocketHandle {
+        self.handle
     }
 
     fn as_any_ref(&self) -> &dyn core::any::Any {
@@ -227,7 +234,7 @@ impl Socket for RawSocket {
 /// https://man7.org/linux/man-pages/man7/udp.7.html
 #[derive(Debug, Clone)]
 pub struct UdpSocket {
-    pub handle: Arc<GlobalSocketHandle>,
+    pub handle: GlobalSocketHandle,
     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
     metadata: SocketMetadata,
 }
@@ -257,8 +264,8 @@ impl UdpSocket {
         let socket = udp::Socket::new(rx_buffer, tx_buffer);
 
         // 把socket添加到socket集合中,并得到socket的句柄
-        let handle: Arc<GlobalSocketHandle> =
-            GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
+        let handle: GlobalSocketHandle =
+            GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
 
         let metadata = SocketMetadata::new(
             SocketType::Udp,
@@ -301,13 +308,21 @@ impl UdpSocket {
 }
 
 impl Socket for UdpSocket {
+    fn close(&mut self) {
+        let mut socket_set_guard = SOCKET_SET.lock_irqsave();
+        socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
+        drop(socket_set_guard);
+        poll_ifaces();
+    }
+
     /// @brief 在read函数执行之前,请先bind到本地的指定端口
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
         loop {
             // kdebug!("Wait22 to Read");
             poll_ifaces();
             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-            let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
+            let socket =
+                socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
 
             // kdebug!("Wait to Read");
 
@@ -344,7 +359,7 @@ impl Socket for UdpSocket {
         // kdebug!("udp write: remote = {:?}", remote_endpoint);
 
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-        let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
+        let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
         // kdebug!("is open()={}", socket.is_open());
         // kdebug!("socket endpoint={:?}", socket.endpoint());
         if socket.can_send() {
@@ -369,14 +384,14 @@ impl Socket for UdpSocket {
 
     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
         let mut sockets = SOCKET_SET.lock_irqsave();
-        let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
+        let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
         // kdebug!("UDP Bind to {:?}", endpoint);
         return self.do_bind(socket, endpoint);
     }
 
     fn poll(&self) -> EPollEventType {
         let sockets = SOCKET_SET.lock_irqsave();
-        let socket = sockets.get::<udp::Socket>(self.handle.0);
+        let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
 
         return SocketPollMethod::udp_poll(
             socket,
@@ -417,7 +432,7 @@ impl Socket for UdpSocket {
 
     fn endpoint(&self) -> Option<Endpoint> {
         let sockets = SOCKET_SET.lock_irqsave();
-        let socket = sockets.get::<udp::Socket>(self.handle.0);
+        let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
         let listen_endpoint = socket.endpoint();
 
         if listen_endpoint.port == 0 {
@@ -440,8 +455,8 @@ impl Socket for UdpSocket {
         return self.remote_endpoint.clone();
     }
 
-    fn socket_handle(&self) -> SocketHandle {
-        self.handle.0
+    fn socket_handle(&self) -> GlobalSocketHandle {
+        self.handle
     }
 
     fn as_any_ref(&self) -> &dyn core::any::Any {
@@ -458,7 +473,7 @@ impl Socket for UdpSocket {
 /// https://man7.org/linux/man-pages/man7/tcp.7.html
 #[derive(Debug, Clone)]
 pub struct TcpSocket {
-    handles: Vec<Arc<GlobalSocketHandle>>,
+    handles: Vec<GlobalSocketHandle>,
     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
     is_listening: bool,
     metadata: SocketMetadata,
@@ -483,7 +498,7 @@ impl TcpSocket {
     /// @return 返回创建的tcp的socket
     pub fn new(options: SocketOptions) -> Self {
         // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄
-        let handles: Vec<Arc<GlobalSocketHandle>> = vec![GlobalSocketHandle::new(
+        let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle(
             SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
         )];
 
@@ -542,6 +557,15 @@ impl TcpSocket {
 }
 
 impl Socket for TcpSocket {
+    fn close(&mut self) {
+        for handle in self.handles.iter() {
+            let mut socket_set_guard = SOCKET_SET.lock_irqsave();
+            socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
+            drop(socket_set_guard);
+        }
+        poll_ifaces();
+    }
+
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
         if HANDLE_MAP
             .read_irqsave()
@@ -558,7 +582,8 @@ impl Socket for TcpSocket {
             poll_ifaces();
             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
 
-            let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
+            let socket = socket_set_guard
+                .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
 
             // 如果socket已经关闭,返回错误
             if !socket.is_active() {
@@ -626,7 +651,8 @@ impl Socket for TcpSocket {
 
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
 
-        let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
+        let socket = socket_set_guard
+            .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
 
         if socket.is_open() {
             if socket.can_send() {
@@ -653,7 +679,8 @@ impl Socket for TcpSocket {
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
         // kdebug!("tcp socket:poll, socket'len={}",self.handle.len());
 
-        let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
+        let socket = socket_set_guard
+            .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
         return SocketPollMethod::tcp_poll(
             socket,
             HANDLE_MAP
@@ -668,7 +695,8 @@ impl Socket for TcpSocket {
         let mut sockets = SOCKET_SET.lock_irqsave();
         // kdebug!("tcp socket:connect, socket'len={}",self.handle.len());
 
-        let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
+        let socket =
+            sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
 
         if let Endpoint::Ip(Some(ip)) = endpoint {
             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
@@ -689,7 +717,9 @@ impl Socket for TcpSocket {
                     loop {
                         poll_ifaces();
                         let mut sockets = SOCKET_SET.lock_irqsave();
-                        let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
+                        let socket = sockets.get_mut::<tcp::Socket>(
+                            self.handles.get(0).unwrap().smoltcp_handle().unwrap(),
+                        );
 
                         match socket.state() {
                             tcp::State::Established => {
@@ -741,9 +771,9 @@ impl Socket for TcpSocket {
         let mut handle_guard = HANDLE_MAP.write_irqsave();
         self.handles.extend((handlen..backlog).map(|_| {
             let socket = Self::create_new_socket();
-            let handle = GlobalSocketHandle::new(sockets.add(socket));
+            let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
             let handle_item = SocketHandleItem::new();
-            handle_guard.insert(handle.0, handle_item);
+            handle_guard.insert(handle, handle_item);
             handle
         }));
         // kdebug!("tcp socket:listen, socket'len={}",self.handle.len());
@@ -753,7 +783,7 @@ impl Socket for TcpSocket {
         for i in 0..backlog {
             let handle = self.handles.get(i).unwrap();
 
-            let socket = sockets.get_mut::<tcp::Socket>(handle.0);
+            let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap());
 
             if !socket.is_listening() {
                 // kdebug!("Tcp Socket is already listening on {local_endpoint}");
@@ -803,79 +833,89 @@ impl Socket for TcpSocket {
             // 随机获取访问的socket的handle
             let index: usize = rand() % self.handles.len();
             let handle = self.handles.get(index).unwrap();
-            let socket = sockets.get_mut::<tcp::Socket>(handle.0);
-
-            if socket.is_active() {
-                // kdebug!("tcp accept: socket.is_active()");
-                let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
-
-                let new_socket = {
-                    // The new TCP socket used for sending and receiving data.
-                    let mut tcp_socket = Self::create_new_socket();
-                    self.do_listen(&mut tcp_socket, endpoint)
-                        .expect("do_listen failed");
-
-                    // tcp_socket.listen(endpoint).unwrap();
-
-                    // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
-                    // 因此需要再为当前的socket分配一个新的handle
-                    let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
-                    let old_handle = ::core::mem::replace(
-                        &mut *self.handles.get_mut(index).unwrap(),
-                        new_handle.clone(),
-                    );
-
-                    let metadata = SocketMetadata::new(
-                        SocketType::Tcp,
-                        Self::DEFAULT_TX_BUF_SIZE,
-                        Self::DEFAULT_RX_BUF_SIZE,
-                        Self::DEFAULT_METADATA_BUF_SIZE,
-                        self.metadata.options,
-                    );
 
-                    let new_socket = Box::new(TcpSocket {
-                        handles: vec![old_handle.clone()],
-                        local_endpoint: self.local_endpoint,
-                        is_listening: false,
-                        metadata,
-                    });
-                    // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len());
-
-                    // 更新端口与 socket 的绑定
-                    if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
-                        PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
-                        PORT_MANAGER.bind_port(
-                            self.metadata.socket_type,
-                            ip.port,
-                            *new_socket.clone(),
-                        )?;
-                    }
+            let socket = sockets
+                .iter_mut()
+                .find(|y| {
+                    tcp::Socket::downcast(y.1)
+                        .map(|y| y.is_active())
+                        .unwrap_or(false)
+                })
+                .map(|y| tcp::Socket::downcast_mut(y.1).unwrap());
+            if let Some(socket) = socket {
+                if socket.is_active() {
+                    // kdebug!("tcp accept: socket.is_active()");
+                    let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
+
+                    let new_socket = {
+                        // The new TCP socket used for sending and receiving data.
+                        let mut tcp_socket = Self::create_new_socket();
+                        self.do_listen(&mut tcp_socket, endpoint)
+                            .expect("do_listen failed");
+
+                        // tcp_socket.listen(endpoint).unwrap();
+
+                        // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
+                        // 因此需要再为当前的socket分配一个新的handle
+                        let new_handle =
+                            GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket));
+                        let old_handle = ::core::mem::replace(
+                            &mut *self.handles.get_mut(index).unwrap(),
+                            new_handle,
+                        );
+
+                        let metadata = SocketMetadata::new(
+                            SocketType::Tcp,
+                            Self::DEFAULT_TX_BUF_SIZE,
+                            Self::DEFAULT_RX_BUF_SIZE,
+                            Self::DEFAULT_METADATA_BUF_SIZE,
+                            self.metadata.options,
+                        );
+
+                        let new_socket = Box::new(TcpSocket {
+                            handles: vec![old_handle],
+                            local_endpoint: self.local_endpoint,
+                            is_listening: false,
+                            metadata,
+                        });
+                        // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len());
+
+                        // 更新端口与 socket 的绑定
+                        if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
+                            PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
+                            PORT_MANAGER.bind_port(
+                                self.metadata.socket_type,
+                                ip.port,
+                                *new_socket.clone(),
+                            )?;
+                        }
 
-                    // 更新handle表
-                    let mut handle_guard = HANDLE_MAP.write_irqsave();
-                    // 先删除原来的
+                        // 更新handle表
+                        let mut handle_guard = HANDLE_MAP.write_irqsave();
+                        // 先删除原来的
 
-                    let item = handle_guard.remove(&old_handle.0).unwrap();
+                        let item = handle_guard.remove(&old_handle).unwrap();
 
-                    // 按照smoltcp行为,将新的handle绑定到原来的item
-                    handle_guard.insert(new_handle.0, item);
-                    let new_item = SocketHandleItem::new();
+                        // 按照smoltcp行为,将新的handle绑定到原来的item
+                        handle_guard.insert(new_handle, item);
+                        let new_item = SocketHandleItem::new();
 
-                    // 插入新的item
-                    handle_guard.insert(old_handle.0, new_item);
+                        // 插入新的item
+                        handle_guard.insert(old_handle, new_item);
 
-                    new_socket
-                };
-                // kdebug!("tcp accept: new socket: {:?}", new_socket);
-                drop(sockets);
-                poll_ifaces();
+                        new_socket
+                    };
+                    // kdebug!("tcp accept: new socket: {:?}", new_socket);
+                    drop(sockets);
+                    poll_ifaces();
 
-                return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
+                    return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
+                }
             }
             // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
 
             drop(sockets);
-            SocketHandleItem::sleep(handle.0, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
+            SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
             // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
         }
     }
@@ -887,7 +927,8 @@ impl Socket for TcpSocket {
             let sockets = SOCKET_SET.lock_irqsave();
             // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len());
 
-            let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0);
+            let socket =
+                sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
             if let Some(ep) = socket.local_endpoint() {
                 result = Some(Endpoint::Ip(Some(ep)));
             }
@@ -899,7 +940,8 @@ impl Socket for TcpSocket {
         let sockets = SOCKET_SET.lock_irqsave();
         // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
 
-        let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0);
+        let socket =
+            sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
     }
 
@@ -911,10 +953,10 @@ impl Socket for TcpSocket {
         Box::new(self.clone())
     }
 
-    fn socket_handle(&self) -> SocketHandle {
+    fn socket_handle(&self) -> GlobalSocketHandle {
         // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
 
-        self.handles.get(0).unwrap().0
+        *self.handles.get(0).unwrap()
     }
 
     fn as_any_ref(&self) -> &dyn core::any::Any {

+ 15 - 35
kernel/src/net/socket/mod.rs

@@ -9,7 +9,7 @@ use alloc::{
 };
 use hashbrown::HashMap;
 use smoltcp::{
-    iface::{SocketHandle, SocketSet},
+    iface::SocketSet,
     socket::{self, tcp, udp},
 };
 use system_error::SystemError;
@@ -29,16 +29,17 @@ use crate::{
 };
 
 use self::{
+    handle::GlobalSocketHandle,
     inet::{RawSocket, TcpSocket, UdpSocket},
     unix::{SeqpacketSocket, StreamSocket},
 };
 
 use super::{
     event_poll::{EPollEventType, EPollItem, EventPoll},
-    net_core::poll_ifaces,
     Endpoint, Protocol, ShutdownType,
 };
 
+pub mod handle;
 pub mod inet;
 pub mod unix;
 
@@ -48,7 +49,7 @@ lazy_static! {
     pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
     /// SocketHandle表,每个SocketHandle对应一个SocketHandleItem,
     /// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁
-    pub static ref HANDLE_MAP: RwLock<HashMap<SocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new());
+    pub static ref HANDLE_MAP: RwLock<HashMap<GlobalSocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new());
     /// 端口管理器
     pub static ref PORT_MANAGER: PortManager = PortManager::new();
 }
@@ -83,6 +84,11 @@ pub(super) fn new_socket(
             return Err(SystemError::EAFNOSUPPORT);
         }
     };
+
+    let handle_item = SocketHandleItem::new();
+    HANDLE_MAP
+        .write_irqsave()
+        .insert(socket.socket_handle(), handle_item);
     Ok(socket)
 }
 
@@ -224,9 +230,7 @@ pub trait Socket: Sync + Send + Debug + Any {
         Ok(())
     }
 
-    fn socket_handle(&self) -> SocketHandle {
-        todo!()
-    }
+    fn socket_handle(&self) -> GlobalSocketHandle;
 
     fn write_buffer(&self, _buf: &[u8]) -> Result<usize, SystemError> {
         todo!()
@@ -272,6 +276,8 @@ pub trait Socket: Sync + Send + Debug + Any {
 
         Ok(())
     }
+
+    fn close(&mut self);
 }
 
 impl Clone for Box<dyn Socket> {
@@ -329,6 +335,7 @@ impl IndexNode for SocketInode {
                 .write_irqsave()
                 .remove(&socket.socket_handle())
                 .unwrap();
+            socket.close();
         }
 
         Ok(())
@@ -409,9 +416,9 @@ impl SocketHandleItem {
 
     /// ## 在socket的等待队列上睡眠
     pub fn sleep(
-        socket_handle: SocketHandle,
+        socket_handle: GlobalSocketHandle,
         events: u64,
-        handle_map_guard: RwLockReadGuard<'_, HashMap<SocketHandle, SocketHandleItem>>,
+        handle_map_guard: RwLockReadGuard<'_, HashMap<GlobalSocketHandle, SocketHandleItem>>,
     ) {
         unsafe {
             handle_map_guard
@@ -544,33 +551,6 @@ impl PortManager {
     }
 }
 
-/// # socket的句柄管理组件
-/// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。
-/// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。
-#[derive(Debug)]
-pub struct GlobalSocketHandle(SocketHandle);
-
-impl GlobalSocketHandle {
-    pub fn new(handle: SocketHandle) -> Arc<Self> {
-        return Arc::new(Self(handle));
-    }
-}
-
-impl Clone for GlobalSocketHandle {
-    fn clone(&self) -> Self {
-        Self(self.0)
-    }
-}
-
-impl Drop for GlobalSocketHandle {
-    fn drop(&mut self) {
-        let mut socket_set_guard = SOCKET_SET.lock_irqsave();
-        socket_set_guard.remove(self.0); // 删除的时候,会发送一条FINISH的信息?
-        drop(socket_set_guard);
-        poll_ifaces();
-    }
-}
-
 /// @brief socket的类型
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub enum SocketType {

+ 19 - 1
kernel/src/net/socket/unix.rs

@@ -3,13 +3,16 @@ use system_error::SystemError;
 
 use crate::{libs::spinlock::SpinLock, net::Endpoint};
 
-use super::{Socket, SocketInode, SocketMetadata, SocketOptions, SocketType};
+use super::{
+    handle::GlobalSocketHandle, Socket, SocketInode, SocketMetadata, SocketOptions, SocketType,
+};
 
 #[derive(Debug, Clone)]
 pub struct StreamSocket {
     metadata: SocketMetadata,
     buffer: Arc<SpinLock<Vec<u8>>>,
     peer_inode: Option<Arc<SocketInode>>,
+    handle: GlobalSocketHandle,
 }
 
 impl StreamSocket {
@@ -37,11 +40,18 @@ impl StreamSocket {
             metadata,
             buffer,
             peer_inode: None,
+            handle: GlobalSocketHandle::new_kernel_handle(),
         }
     }
 }
 
 impl Socket for StreamSocket {
+    fn socket_handle(&self) -> GlobalSocketHandle {
+        self.handle
+    }
+
+    fn close(&mut self) {}
+
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
         let mut buffer = self.buffer.lock_irqsave();
 
@@ -110,6 +120,7 @@ pub struct SeqpacketSocket {
     metadata: SocketMetadata,
     buffer: Arc<SpinLock<Vec<u8>>>,
     peer_inode: Option<Arc<SocketInode>>,
+    handle: GlobalSocketHandle,
 }
 
 impl SeqpacketSocket {
@@ -137,11 +148,14 @@ impl SeqpacketSocket {
             metadata,
             buffer,
             peer_inode: None,
+            handle: GlobalSocketHandle::new_kernel_handle(),
         }
     }
 }
 
 impl Socket for SeqpacketSocket {
+    fn close(&mut self) {}
+
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
         let mut buffer = self.buffer.lock_irqsave();
 
@@ -188,6 +202,10 @@ impl Socket for SeqpacketSocket {
         Ok(len)
     }
 
+    fn socket_handle(&self) -> GlobalSocketHandle {
+        self.handle
+    }
+
     fn metadata(&self) -> SocketMetadata {
         self.metadata.clone()
     }

+ 1 - 8
kernel/src/net/syscall.rs

@@ -19,7 +19,7 @@ use crate::{
 };
 
 use super::{
-    socket::{new_socket, PosixSocketType, Socket, SocketHandleItem, SocketInode, HANDLE_MAP},
+    socket::{new_socket, PosixSocketType, Socket, SocketInode},
     Endpoint, Protocol, ShutdownType,
 };
 
@@ -44,13 +44,6 @@ impl Syscall {
 
         let socket = new_socket(address_family, socket_type, protocol)?;
 
-        if address_family != AddressFamily::Unix {
-            let handle_item = SocketHandleItem::new();
-            HANDLE_MAP
-                .write_irqsave()
-                .insert(socket.socket_handle(), handle_item);
-        }
-
         let socketinode: Arc<SocketInode> = SocketInode::new(socket);
         let f = File::new(socketinode, FileMode::O_RDWR)?;
         // 把socket添加到当前进程的文件描述符表中