瀏覽代碼

fix: tcp poll没有正确处理posix socket的listen状态的问题 (#859)

LoGin 7 月之前
父節點
當前提交
634349e0eb
共有 5 個文件被更改,包括 222 次插入122 次删除
  1. 5 0
      kernel/src/net/event_poll/mod.rs
  2. 11 14
      kernel/src/net/net_core.rs
  3. 95 44
      kernel/src/net/socket/inet.rs
  4. 95 63
      kernel/src/net/socket/mod.rs
  5. 16 1
      kernel/src/net/socket/unix.rs

+ 5 - 0
kernel/src/net/event_poll/mod.rs

@@ -436,6 +436,7 @@ impl EventPoll {
             }
             // 判断epoll上有没有就绪事件
             let mut available = epoll_guard.ep_events_available();
+
             drop(epoll_guard);
             loop {
                 if available {
@@ -759,6 +760,7 @@ impl EventPoll {
 /// 与C兼容的Epoll事件结构体
 #[derive(Copy, Clone, Default)]
 #[repr(packed)]
+#[repr(C)]
 pub struct EPollEvent {
     /// 表示触发的事件
     events: u32,
@@ -870,5 +872,8 @@ bitflags! {
 
         /// 表示epoll已经被释放,但是在目前的设计中未用到
         const POLLFREE = 0x4000;
+
+        /// listen状态的socket可以接受连接
+        const EPOLL_LISTEN_CAN_ACCEPT = Self::EPOLLIN.bits | Self::EPOLLRDNORM.bits;
     }
 }

+ 11 - 14
kernel/src/net/net_core.rs

@@ -191,25 +191,25 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
     for (handle, socket_type) in sockets.iter() {
         let handle_guard = HANDLE_MAP.read_irqsave();
         let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
-        let item = handle_guard.get(&global_handle);
+        let item: Option<&super::socket::SocketHandleItem> = handle_guard.get(&global_handle);
         if item.is_none() {
             continue;
         }
 
         let handle_item = item.unwrap();
+        let posix_item = handle_item.posix_item();
+        if posix_item.is_none() {
+            continue;
+        }
+        let posix_item = posix_item.unwrap();
 
         // 获取socket上的事件
-        let mut events =
-            SocketPollMethod::poll(socket_type, handle_item.shutdown_type()).bits() as u64;
+        let mut events = SocketPollMethod::poll(socket_type, handle_item).bits() as u64;
 
         // 分发到相应类型socket处理
         match socket_type {
             smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
-                handle_guard
-                    .get(&global_handle)
-                    .unwrap()
-                    .wait_queue
-                    .wakeup_any(events);
+                posix_item.wakeup_any(events);
             }
             smoltcp::socket::Socket::Icmp(_) => unimplemented!("Icmp socket hasn't unimplemented"),
             smoltcp::socket::Socket::Tcp(inner_socket) => {
@@ -222,17 +222,14 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
                 if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait {
                     events |= EPollEventType::EPOLLHUP.bits() as u64;
                 }
-                handle_guard
-                    .get(&global_handle)
-                    .unwrap()
-                    .wait_queue
-                    .wakeup_any(events);
+
+                posix_item.wakeup_any(events);
             }
             smoltcp::socket::Socket::Dhcpv4(_) => {}
             smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"),
         }
         EventPoll::wakeup_epoll(
-            &handle_item.epitems,
+            &posix_item.epitems,
             EPollEventType::from_bits_truncate(events as u32),
         )?;
         drop(handle_guard);

+ 95 - 44
kernel/src/net/socket/inet.rs

@@ -16,8 +16,8 @@ use crate::{
 };
 
 use super::{
-    handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions,
-    SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
+    handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketHandleItem, SocketMetadata,
+    SocketOptions, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
 };
 
 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
@@ -32,6 +32,7 @@ pub struct RawSocket {
     header_included: bool,
     /// socket的metadata
     metadata: SocketMetadata,
+    posix_item: Arc<PosixSocketHandleItem>,
 }
 
 impl RawSocket {
@@ -76,15 +77,22 @@ impl RawSocket {
             options,
         );
 
+        let posix_item = Arc::new(PosixSocketHandleItem::new(None));
+
         return Self {
             handle,
             header_included: false,
             metadata,
+            posix_item,
         };
     }
 }
 
 impl Socket for RawSocket {
+    fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
+        self.posix_item.clone()
+    }
+
     fn close(&mut self) {
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
         if let smoltcp::socket::Socket::Udp(mut sock) =
@@ -123,11 +131,7 @@ impl Socket for RawSocket {
                 }
             }
             drop(socket_set_guard);
-            SocketHandleItem::sleep(
-                self.socket_handle(),
-                EPollEventType::EPOLLIN.bits() as u64,
-                HANDLE_MAP.read_irqsave(),
-            );
+            self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64);
         }
     }
 
@@ -240,6 +244,7 @@ pub struct UdpSocket {
     pub handle: GlobalSocketHandle,
     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
     metadata: SocketMetadata,
+    posix_item: Arc<PosixSocketHandleItem>,
 }
 
 impl UdpSocket {
@@ -278,10 +283,13 @@ impl UdpSocket {
             options,
         );
 
+        let posix_item = Arc::new(PosixSocketHandleItem::new(None));
+
         return Self {
             handle,
             remote_endpoint: None,
             metadata,
+            posix_item,
         };
     }
 
@@ -311,6 +319,10 @@ impl UdpSocket {
 }
 
 impl Socket for UdpSocket {
+    fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
+        self.posix_item.clone()
+    }
+
     fn close(&mut self) {
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
         if let smoltcp::socket::Socket::Udp(mut sock) =
@@ -344,11 +356,7 @@ impl Socket for UdpSocket {
                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
             }
             drop(socket_set_guard);
-            SocketHandleItem::sleep(
-                self.socket_handle(),
-                EPollEventType::EPOLLIN.bits() as u64,
-                HANDLE_MAP.read_irqsave(),
-            );
+            self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64);
         }
     }
 
@@ -484,6 +492,7 @@ pub struct TcpSocket {
     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
     is_listening: bool,
     metadata: SocketMetadata,
+    posix_item: Arc<PosixSocketHandleItem>,
 }
 
 impl TcpSocket {
@@ -516,6 +525,7 @@ impl TcpSocket {
             Self::DEFAULT_METADATA_BUF_SIZE,
             options,
         );
+        let posix_item = Arc::new(PosixSocketHandleItem::new(None));
         // debug!("when there's a new tcp socket,its'len: {}",handles.len());
 
         return Self {
@@ -523,6 +533,7 @@ impl TcpSocket {
             local_endpoint: None,
             is_listening: false,
             metadata,
+            posix_item,
         };
     }
 
@@ -532,10 +543,8 @@ impl TcpSocket {
         local_endpoint: wire::IpEndpoint,
     ) -> Result<(), SystemError> {
         let listen_result = if local_endpoint.addr.is_unspecified() {
-            // debug!("Tcp Socket Listen on port {}", local_endpoint.port);
             socket.listen(local_endpoint.port)
         } else {
-            // debug!("Tcp Socket Listen on {local_endpoint}");
             socket.listen(local_endpoint)
         };
         return match listen_result {
@@ -561,9 +570,33 @@ impl TcpSocket {
         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
         tcp::Socket::new(rx_buffer, tx_buffer)
     }
+
+    /// listening状态的posix socket是需要特殊处理的
+    fn tcp_poll_listening(&self) -> EPollEventType {
+        let socketset_guard = SOCKET_SET.lock_irqsave();
+
+        let can_accept = self.handles.iter().any(|h| {
+            if let Some(sh) = h.smoltcp_handle() {
+                let socket = socketset_guard.get::<tcp::Socket>(sh);
+                socket.is_active()
+            } else {
+                false
+            }
+        });
+
+        if can_accept {
+            return EPollEventType::EPOLL_LISTEN_CAN_ACCEPT;
+        } else {
+            return EPollEventType::empty();
+        }
+    }
 }
 
 impl Socket for TcpSocket {
+    fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
+        self.posix_item.clone()
+    }
+
     fn close(&mut self) {
         for handle in self.handles.iter() {
             {
@@ -641,11 +674,8 @@ impl Socket for TcpSocket {
                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
             }
             drop(socket_set_guard);
-            SocketHandleItem::sleep(
-                self.socket_handle(),
-                (EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64,
-                HANDLE_MAP.read_irqsave(),
-            );
+            self.posix_item
+                .sleep((EPollEventType::EPOLLIN | EPollEventType::EPOLLHUP).bits() as u64);
         }
     }
 
@@ -688,24 +718,31 @@ impl Socket for TcpSocket {
     }
 
     fn poll(&self) -> EPollEventType {
+        // 处理listen的快速路径
+        if self.is_listening {
+            return self.tcp_poll_listening();
+        }
+        // 由于上面处理了listening状态,所以这里只处理非listening状态,这种情况下只有一个handle
+
+        assert!(self.handles.len() == 1);
+
         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
         // debug!("tcp socket:poll, socket'len={}",self.handle.len());
 
         let socket = socket_set_guard
             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
-        return SocketPollMethod::tcp_poll(
-            socket,
-            HANDLE_MAP
-                .read_irqsave()
-                .get(&self.socket_handle())
-                .unwrap()
-                .shutdown_type(),
-        );
+        let handle_map_guard = HANDLE_MAP.read_irqsave();
+        let handle_item = handle_map_guard.get(&self.socket_handle()).unwrap();
+        let shutdown_type = handle_item.shutdown_type();
+        let is_posix_listen = handle_item.is_posix_listen;
+        drop(handle_map_guard);
+
+        return SocketPollMethod::tcp_poll(socket, shutdown_type, is_posix_listen);
     }
 
     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
         let mut sockets = SOCKET_SET.lock_irqsave();
-        // debug!("tcp socket:connect, socket'len={}",self.handle.len());
+        // debug!("tcp socket:connect, socket'len={}", self.handles.len());
 
         let socket =
             sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
@@ -739,11 +776,7 @@ impl Socket for TcpSocket {
                             }
                             tcp::State::SynSent => {
                                 drop(sockets);
-                                SocketHandleItem::sleep(
-                                    self.socket_handle(),
-                                    Self::CAN_CONNECT,
-                                    HANDLE_MAP.read_irqsave(),
-                                );
+                                self.posix_item.sleep(Self::CAN_CONNECT);
                             }
                             _ => {
                                 return Err(SystemError::ECONNREFUSED);
@@ -772,6 +805,11 @@ impl Socket for TcpSocket {
             return Ok(());
         }
 
+        // debug!(
+        //     "tcp socket:listen, socket'len={}, backlog = {backlog}",
+        //     self.handles.len()
+        // );
+
         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
         let mut sockets = SOCKET_SET.lock_irqsave();
         // 获取handle的数量
@@ -781,16 +819,19 @@ impl Socket for TcpSocket {
         // 添加剩余需要构建的socket
         // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len());
         let mut handle_guard = HANDLE_MAP.write_irqsave();
-        let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue);
+        let socket_handle_item_0 = handle_guard.get_mut(&self.socket_handle()).unwrap();
+        socket_handle_item_0.is_posix_listen = true;
 
         self.handles.extend((handlen..backlog).map(|_| {
             let socket = Self::create_new_socket();
             let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
-            let handle_item = SocketHandleItem::new(Some(wait_queue.clone()));
+            let mut handle_item = SocketHandleItem::new(Arc::downgrade(&self.posix_item));
+            handle_item.is_posix_listen = true;
             handle_guard.insert(handle, handle_item);
             handle
         }));
-        // debug!("tcp socket:listen, socket'len={}",self.handle.len());
+
+        // debug!("tcp socket:listen, socket'len={}", self.handles.len());
         // debug!("tcp socket:listen, backlog={backlog}");
 
         // 监听所有的socket
@@ -805,6 +846,7 @@ impl Socket for TcpSocket {
             }
             // debug!("Tcp Socket  before listen, open={}", socket.is_open());
         }
+
         return Ok(());
     }
 
@@ -820,6 +862,7 @@ impl Socket for TcpSocket {
 
             self.local_endpoint = Some(ip);
             self.is_listening = false;
+
             return Ok(());
         }
         return Err(SystemError::EINVAL);
@@ -862,8 +905,7 @@ impl Socket for TcpSocket {
                     .remote_endpoint()
                     .ok_or(SystemError::ENOTCONN)?;
 
-                let mut tcp_socket = Self::create_new_socket();
-                self.do_listen(&mut tcp_socket, endpoint)?;
+                let tcp_socket = Self::create_new_socket();
 
                 let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket));
 
@@ -883,31 +925,40 @@ impl Socket for TcpSocket {
                     local_endpoint: self.local_endpoint,
                     is_listening: false,
                     metadata,
+                    posix_item: Arc::new(PosixSocketHandleItem::new(None)),
                 });
 
                 {
                     let mut handle_guard = HANDLE_MAP.write_irqsave();
                     // 先删除原来的
                     let item = handle_guard.remove(&old_handle).unwrap();
+                    item.reset_shutdown_type();
+                    assert!(item.is_posix_listen);
 
                     // 按照smoltcp行为,将新的handle绑定到原来的item
-                    let new_item = SocketHandleItem::new(None);
+                    let new_item = SocketHandleItem::new(Arc::downgrade(&sock_ret.posix_item));
                     handle_guard.insert(old_handle, new_item);
                     // 插入新的item
                     handle_guard.insert(new_handle, item);
+
+                    let socket = sockset.get_mut::<tcp::Socket>(
+                        self.handles[handle_index].smoltcp_handle().unwrap(),
+                    );
+
+                    if !socket.is_listening() {
+                        self.do_listen(socket, endpoint)?;
+                    }
+
                     drop(handle_guard);
                 }
+
                 return Ok((sock_ret, Endpoint::Ip(Some(remote_ep))));
             }
 
             drop(sockset);
 
             // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap());
-            SocketHandleItem::sleep(
-                self.socket_handle(), // NOTICE
-                Self::CAN_ACCPET,
-                HANDLE_MAP.read_irqsave(),
-            );
+            self.posix_item.sleep(Self::CAN_ACCPET);
             // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
         }
     }

+ 95 - 63
kernel/src/net/socket/mod.rs

@@ -22,7 +22,7 @@ use crate::{
         Metadata,
     },
     libs::{
-        rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard},
+        rwlock::{RwLock, RwLockWriteGuard},
         spinlock::{SpinLock, SpinLockGuard},
         wait_queue::EventWaitQueue,
     },
@@ -87,7 +87,7 @@ pub(super) fn new_socket(
         }
     };
 
-    let handle_item = SocketHandleItem::new(None);
+    let handle_item = SocketHandleItem::new(Arc::downgrade(&socket.posix_item()));
     HANDLE_MAP
         .write_irqsave()
         .insert(socket.socket_handle(), handle_item);
@@ -243,36 +243,26 @@ pub trait Socket: Sync + Send + Debug + Any {
     fn as_any_mut(&mut self) -> &mut dyn Any;
 
     fn add_epoll(&mut self, epitem: Arc<EPollItem>) -> Result<(), SystemError> {
-        HANDLE_MAP
-            .write_irqsave()
-            .get_mut(&self.socket_handle())
-            .unwrap()
-            .add_epoll(epitem);
+        let posix_item = self.posix_item();
+        posix_item.add_epoll(epitem);
         Ok(())
     }
 
     fn remove_epoll(&mut self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
-        HANDLE_MAP
-            .write_irqsave()
-            .get_mut(&self.socket_handle())
-            .unwrap()
-            .remove_epoll(epoll)?;
+        let posix_item = self.posix_item();
+        posix_item.remove_epoll(epoll)?;
 
         Ok(())
     }
 
     fn clear_epoll(&mut self) -> Result<(), SystemError> {
-        let mut handle_map_guard = HANDLE_MAP.write_irqsave();
-        let handle_item = handle_map_guard.get_mut(&self.socket_handle()).unwrap();
+        let posix_item = self.posix_item();
 
-        for epitem in handle_item.epitems.lock_irqsave().iter() {
+        for epitem in posix_item.epitems.lock_irqsave().iter() {
             let epoll = epitem.epoll();
-            if epoll.upgrade().is_some() {
-                EventPoll::ep_remove(
-                    &mut epoll.upgrade().unwrap().lock_irqsave(),
-                    epitem.fd(),
-                    None,
-                )?;
+
+            if let Some(epoll) = epoll.upgrade() {
+                EventPoll::ep_remove(&mut epoll.lock_irqsave(), epitem.fd(), None)?;
             }
         }
 
@@ -280,6 +270,8 @@ pub trait Socket: Sync + Send + Debug + Any {
     }
 
     fn close(&mut self);
+
+    fn posix_item(&self) -> Arc<PosixSocketHandleItem>;
 }
 
 impl Clone for Box<dyn Socket> {
@@ -410,54 +402,35 @@ impl IndexNode for SocketInode {
 }
 
 #[derive(Debug)]
-pub struct SocketHandleItem {
-    /// shutdown状态
-    pub shutdown_type: RwLock<ShutdownType>,
+pub struct PosixSocketHandleItem {
     /// socket的waitqueue
-    pub wait_queue: Arc<EventWaitQueue>,
-    /// epitems,考虑写在这是否是最优解?
+    wait_queue: Arc<EventWaitQueue>,
+
     pub epitems: SpinLock<LinkedList<Arc<EPollItem>>>,
 }
 
-impl SocketHandleItem {
+impl PosixSocketHandleItem {
     pub fn new(wait_queue: Option<Arc<EventWaitQueue>>) -> Self {
         Self {
-            shutdown_type: RwLock::new(ShutdownType::empty()),
             wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())),
             epitems: SpinLock::new(LinkedList::new()),
         }
     }
-
     /// ## 在socket的等待队列上睡眠
-    pub fn sleep(
-        socket_handle: GlobalSocketHandle,
-        events: u64,
-        handle_map_guard: RwLockReadGuard<'_, HashMap<GlobalSocketHandle, SocketHandleItem>>,
-    ) {
+    pub fn sleep(&self, events: u64) {
         unsafe {
-            handle_map_guard
-                .get(&socket_handle)
-                .unwrap()
-                .wait_queue
-                .sleep_without_schedule(events)
-        };
-        drop(handle_map_guard);
+            ProcessManager::preempt_disable();
+            self.wait_queue.sleep_without_schedule(events);
+            ProcessManager::preempt_enable();
+        }
         schedule(SchedMode::SM_NONE);
     }
 
-    pub fn shutdown_type(&self) -> ShutdownType {
-        *self.shutdown_type.read()
-    }
-
-    pub fn shutdown_type_writer(&mut self) -> RwLockWriteGuard<ShutdownType> {
-        self.shutdown_type.write_irqsave()
-    }
-
-    pub fn add_epoll(&mut self, epitem: Arc<EPollItem>) {
+    pub fn add_epoll(&self, epitem: Arc<EPollItem>) {
         self.epitems.lock_irqsave().push_back(epitem)
     }
 
-    pub fn remove_epoll(&mut self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
+    pub fn remove_epoll(&self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
         let is_remove = !self
             .epitems
             .lock_irqsave()
@@ -471,6 +444,50 @@ impl SocketHandleItem {
 
         Err(SystemError::ENOENT)
     }
+
+    /// ### 唤醒该队列上等待events的进程
+    ///
+    ///  ### 参数
+    /// - events: 发生的事件
+    ///
+    /// 需要注意的是,只要触发了events中的任意一件事件,进程都会被唤醒
+    pub fn wakeup_any(&self, events: u64) {
+        self.wait_queue.wakeup_any(events);
+    }
+}
+#[derive(Debug)]
+pub struct SocketHandleItem {
+    /// 对应的posix socket是否为listen的
+    pub is_posix_listen: bool,
+    /// shutdown状态
+    pub shutdown_type: RwLock<ShutdownType>,
+    pub posix_item: Weak<PosixSocketHandleItem>,
+}
+
+impl SocketHandleItem {
+    pub fn new(posix_item: Weak<PosixSocketHandleItem>) -> Self {
+        Self {
+            is_posix_listen: false,
+            shutdown_type: RwLock::new(ShutdownType::empty()),
+            posix_item,
+        }
+    }
+
+    pub fn shutdown_type(&self) -> ShutdownType {
+        *self.shutdown_type.read()
+    }
+
+    pub fn shutdown_type_writer(&mut self) -> RwLockWriteGuard<ShutdownType> {
+        self.shutdown_type.write_irqsave()
+    }
+
+    pub fn reset_shutdown_type(&self) {
+        *self.shutdown_type.write() = ShutdownType::empty();
+    }
+
+    pub fn posix_item(&self) -> Option<Arc<PosixSocketHandleItem>> {
+        self.posix_item.upgrade()
+    }
 }
 
 /// # TCP 和 UDP 的端口管理器。
@@ -763,33 +780,47 @@ impl TryFrom<u8> for PosixSocketType {
 pub struct SocketPollMethod;
 
 impl SocketPollMethod {
-    pub fn poll(socket: &socket::Socket, shutdown: ShutdownType) -> EPollEventType {
+    pub fn poll(socket: &socket::Socket, handle_item: &SocketHandleItem) -> EPollEventType {
+        let shutdown = handle_item.shutdown_type();
         match socket {
             socket::Socket::Udp(udp) => Self::udp_poll(udp, shutdown),
-            socket::Socket::Tcp(tcp) => Self::tcp_poll(tcp, shutdown),
+            socket::Socket::Tcp(tcp) => Self::tcp_poll(tcp, shutdown, handle_item.is_posix_listen),
             socket::Socket::Raw(raw) => Self::raw_poll(raw, shutdown),
             _ => todo!(),
         }
     }
 
-    pub fn tcp_poll(socket: &tcp::Socket, shutdown: ShutdownType) -> EPollEventType {
+    pub fn tcp_poll(
+        socket: &tcp::Socket,
+        shutdown: ShutdownType,
+        is_posix_listen: bool,
+    ) -> EPollEventType {
         let mut events = EPollEventType::empty();
-        if socket.is_listening() && socket.is_active() {
-            events.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM);
+        // debug!("enter tcp_poll! is_posix_listen:{}", is_posix_listen);
+        // 处理listen的socket
+        if is_posix_listen {
+            // 如果是listen的socket,那么只有EPOLLIN和EPOLLRDNORM
+            if socket.is_active() {
+                events.insert(EPollEventType::EPOLL_LISTEN_CAN_ACCEPT);
+            }
+
+            // debug!("tcp_poll listen socket! events:{:?}", events);
             return events;
         }
 
-        // socket已经关闭
-        if !socket.is_open() {
-            events.insert(EPollEventType::EPOLLHUP)
+        let state = socket.state();
+
+        if shutdown == ShutdownType::SHUTDOWN_MASK || state == tcp::State::Closed {
+            events.insert(EPollEventType::EPOLLHUP);
         }
+
         if shutdown.contains(ShutdownType::RCV_SHUTDOWN) {
             events.insert(
                 EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM | EPollEventType::EPOLLRDHUP,
             );
         }
 
-        let state = socket.state();
+        // Connected or passive Fast Open socket?
         if state != tcp::State::SynSent && state != tcp::State::SynReceived {
             // socket有可读数据
             if socket.can_recv() {
@@ -797,12 +828,12 @@ impl SocketPollMethod {
             }
 
             if !(shutdown.contains(ShutdownType::SEND_SHUTDOWN)) {
-                // 缓冲区可写
+                // 缓冲区可写(这里判断可写的逻辑好像跟linux不太一样)
                 if socket.send_queue() < socket.send_capacity() {
                     events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM);
                 } else {
-                    // TODO:触发缓冲区已满的信号
-                    todo!("A signal that the buffer is full needs to be sent");
+                    // TODO:触发缓冲区已满的信号SIGIO
+                    todo!("A signal SIGIO that the buffer is full needs to be sent");
                 }
             } else {
                 // 如果我们的socket关闭了SEND_SHUTDOWN,epoll事件就是EPOLLOUT
@@ -813,6 +844,7 @@ impl SocketPollMethod {
         }
 
         // socket发生错误
+        // TODO: 这里的逻辑可能有问题,需要进一步验证是否is_active()==false就代表socket发生错误
         if !socket.is_active() {
             events.insert(EPollEventType::EPOLLERR);
         }

+ 16 - 1
kernel/src/net/socket/unix.rs

@@ -4,7 +4,8 @@ use system_error::SystemError;
 use crate::{libs::spinlock::SpinLock, net::Endpoint};
 
 use super::{
-    handle::GlobalSocketHandle, Socket, SocketInode, SocketMetadata, SocketOptions, SocketType,
+    handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketInode, SocketMetadata,
+    SocketOptions, SocketType,
 };
 
 #[derive(Debug, Clone)]
@@ -13,6 +14,7 @@ pub struct StreamSocket {
     buffer: Arc<SpinLock<Vec<u8>>>,
     peer_inode: Option<Arc<SocketInode>>,
     handle: GlobalSocketHandle,
+    posix_item: Arc<PosixSocketHandleItem>,
 }
 
 impl StreamSocket {
@@ -36,16 +38,22 @@ impl StreamSocket {
             options,
         );
 
+        let posix_item = Arc::new(PosixSocketHandleItem::new(None));
+
         Self {
             metadata,
             buffer,
             peer_inode: None,
             handle: GlobalSocketHandle::new_kernel_handle(),
+            posix_item,
         }
     }
 }
 
 impl Socket for StreamSocket {
+    fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
+        self.posix_item.clone()
+    }
     fn socket_handle(&self) -> GlobalSocketHandle {
         self.handle
     }
@@ -121,6 +129,7 @@ pub struct SeqpacketSocket {
     buffer: Arc<SpinLock<Vec<u8>>>,
     peer_inode: Option<Arc<SocketInode>>,
     handle: GlobalSocketHandle,
+    posix_item: Arc<PosixSocketHandleItem>,
 }
 
 impl SeqpacketSocket {
@@ -144,16 +153,22 @@ impl SeqpacketSocket {
             options,
         );
 
+        let posix_item = Arc::new(PosixSocketHandleItem::new(None));
+
         Self {
             metadata,
             buffer,
             peer_inode: None,
             handle: GlobalSocketHandle::new_kernel_handle(),
+            posix_item,
         }
     }
 }
 
 impl Socket for SeqpacketSocket {
+    fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
+        self.posix_item.clone()
+    }
     fn close(&mut self) {}
 
     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {