Browse Source

fix(net): misc of resources release (#1096)

* fix: TCP socket miss activation after close

* fix: TCP socket miss activation after close (#1085)

* fix: loopback, udp resource aquire
- remove tcp useless status update
- enable smoltcp medium-ip feature
- change loopback device use ip for addressing, avoid arp procedure
- fix udp couldn't close bug
- fix udp resource aquire didn't lock port
- remove useless Timer in network initialization

* fmt: format

* fix: loopback and udp resource problem (#1086)

* fix: loopback, udp resource aquire
- remove tcp useless status update
- enable smoltcp medium-ip feature
- change loopback device use ip for addressing, avoid arp procedure
- fix udp couldn't close bug
- fix udp resource aquire didn't lock port
- remove useless Timer in network initialization

* fix(net): Unix 资源释放 (#1087)

* unix socket 相关资源释放 #991
* 完善streamsocket资源释放
* 解决inode和id不匹配

* fix TCP socketset release (#1095)

* fix: TCP socket miss activation after close

* fix: loopback, udp resource aquire
- remove tcp useless status update
- enable smoltcp medium-ip feature
- change loopback device use ip for addressing, avoid arp procedure
- fix udp couldn't close bug
- fix udp resource aquire didn't lock port
- remove useless Timer in network initialization

---------

Co-authored-by: YuLong Huang <[email protected]>
Samuel Dai 2 weeks ago
parent
commit
69dde46586

+ 1 - 1
kernel/Cargo.toml

@@ -52,7 +52,7 @@ linkme = "=0.3.27"
 num = { version = "=0.4.0", default-features = false }
 num-derive = "=0.3"
 num-traits = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/num-traits.git", rev="1597c1c", default-features = false }
-smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc",  "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6"]}
+smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc",  "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6", "medium-ip"]}
 system_error = { path = "crates/system_error" }
 uefi = { version = "=0.26.0", features = ["alloc"] }
 uefi-raw = "=0.5.0"

+ 6 - 4
kernel/src/driver/net/loopback.rs

@@ -204,7 +204,7 @@ impl phy::Device for LoopbackDriver {
         let mut result = phy::DeviceCapabilities::default();
         result.max_transmission_unit = 65535;
         result.max_burst_size = Some(1);
-        result.medium = smoltcp::phy::Medium::Ethernet;
+        result.medium = smoltcp::phy::Medium::Ip;
         return result;
     }
     /// ## Loopback驱动处理接受数据事件
@@ -284,9 +284,11 @@ impl LoopbackInterface {
     pub fn new(mut driver: LoopbackDriver) -> Arc<Self> {
         let iface_id = generate_iface_id();
 
-        let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([
-            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-        ]));
+        // let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([
+        //     0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+        // ]));
+
+        let hardware_addr = HardwareAddress::Ip;
 
         let mut iface_config = smoltcp::iface::Config::new(hardware_addr);
 

+ 8 - 0
kernel/src/driver/net/mod.rs

@@ -285,6 +285,14 @@ impl IfaceCommon {
         self.bounds.write().push(socket);
     }
 
+    pub fn unbind_socket(&self, socket: Arc<dyn InetSocket>) {
+        let mut bounds = self.bounds.write();
+        if let Some(index) = bounds.iter().position(|s| Arc::ptr_eq(s, &socket)) {
+            bounds.remove(index);
+            log::debug!("unbind socket success");
+        }
+    }
+
     // TODO: 需要在inet实现多网卡监听或路由子系统实现后移除
     pub fn is_default_iface(&self) -> bool {
         self.default_iface

+ 15 - 35
kernel/src/net/net_core.rs

@@ -1,4 +1,4 @@
-use alloc::{boxed::Box, collections::BTreeMap, sync::Arc};
+use alloc::{collections::BTreeMap, sync::Arc};
 use log::{debug, info, warn};
 use smoltcp::{socket::dhcpv4, wire};
 use system_error::SystemError;
@@ -7,45 +7,23 @@ use crate::{
     driver::net::{Iface, Operstate},
     libs::rwlock::RwLockReadGuard,
     net::NET_DEVICES,
-    time::{
-        sleep::nanosleep,
-        timer::{next_n_ms_timer_jiffies, Timer, TimerFunction},
-        PosixTimeSpec,
-    },
+    time::{sleep::nanosleep, PosixTimeSpec},
 };
 
-/// The network poll function, which will be called by timer.
-///
-/// The main purpose of this function is to poll all network interfaces.
-#[derive(Debug)]
-#[allow(dead_code)]
-struct NetWorkPollFunc;
-
-impl TimerFunction for NetWorkPollFunc {
-    fn run(&mut self) -> Result<(), SystemError> {
-        poll_ifaces();
-        let next_time = next_n_ms_timer_jiffies(10);
-        let timer = Timer::new(Box::new(NetWorkPollFunc), next_time);
-        timer.activate();
-        return Ok(());
-    }
-}
-
 pub fn net_init() -> Result<(), SystemError> {
-    dhcp_query()?;
-    // Init poll timer function
-    // let next_time = next_n_ms_timer_jiffies(5);
-    // let timer = Timer::new(Box::new(NetWorkPollFunc), next_time);
-    // timer.activate();
-    return Ok(());
+    dhcp_query()
 }
 
 fn dhcp_query() -> Result<(), SystemError> {
     let binding = NET_DEVICES.write_irqsave();
-    // log::debug!("binding: {:?}", *binding);
-    //由于现在os未实现在用户态为网卡动态分配内存,而lo网卡的id最先分配且ip固定不能被分配
-    //所以特判取用id为1的网卡(也就是virtio_net)
-    let net_face = binding.get(&1).ok_or(SystemError::ENODEV)?.clone();
+
+    // Default iface, misspelled to net_face
+    let net_face = binding
+        .iter()
+        .find(|(_, iface)| iface.common().is_default_iface())
+        .unwrap()
+        .1
+        .clone();
 
     drop(binding);
 
@@ -60,8 +38,10 @@ fn dhcp_query() -> Result<(), SystemError> {
 
     let sockets = || net_face.sockets().lock_irqsave();
 
-    // let dhcp_handle = SOCKET_SET.lock_irqsave().add(dhcp_socket);
     let dhcp_handle = sockets().add(dhcp_socket);
+    defer::defer!({
+        sockets().remove(dhcp_handle);
+    });
 
     const DHCP_TRY_ROUND: u8 = 100;
     for i in 0..DHCP_TRY_ROUND {
@@ -147,7 +127,7 @@ fn dhcp_query() -> Result<(), SystemError> {
 }
 
 pub fn poll_ifaces() {
-    log::debug!("poll_ifaces");
+    // log::debug!("poll_ifaces");
     let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn Iface>>> = NET_DEVICES.read_irqsave();
     if guard.len() == 0 {
         warn!("poll_ifaces: No net driver found!");

+ 1 - 1
kernel/src/net/socket/common/shutdown.rs

@@ -124,7 +124,7 @@ impl TryFrom<usize> for ShutdownTemp {
 
     fn try_from(value: usize) -> Result<Self, Self::Error> {
         match value {
-            0 | 1 | 2 => Ok(ShutdownTemp {
+            0..2 => Ok(ShutdownTemp {
                 bit: value as u8 + 1,
             }),
             _ => Err(SystemError::EINVAL),

+ 2 - 2
kernel/src/net/socket/inet/common/mod.rs

@@ -53,11 +53,11 @@ impl BoundInner {
                 })
                 .expect("No default interface");
 
-            let handle = iface.sockets().lock_no_preempt().add(socket);
+            let handle = iface.sockets().lock_irqsave().add(socket);
             return Ok(Self { handle, iface });
         } else {
             let iface = get_iface_to_bind(address).ok_or(ENODEV)?;
-            let handle = iface.sockets().lock_no_preempt().add(socket);
+            let handle = iface.sockets().lock_irqsave().add(socket);
             return Ok(Self { handle, iface });
         }
     }

+ 3 - 9
kernel/src/net/socket/inet/datagram/inner.rs

@@ -33,16 +33,14 @@ impl UnboundUdp {
     }
 
     pub fn bind(self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<BoundUdp, SystemError> {
-        // let (addr, port) = (local_endpoint.addr, local_endpoint.port);
-        // if self.socket.bind(local_endpoint).is_err() {
-        //     log::debug!("bind failed!");
-        //     return Err(EINVAL);
-        // }
         let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?;
         let bind_addr = local_endpoint.addr;
         let bind_port = if local_endpoint.port == 0 {
             inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?
         } else {
+            inner
+                .port_manager()
+                .bind_port(InetTypes::Udp, local_endpoint.port)?;
             local_endpoint.port
         };
 
@@ -77,10 +75,6 @@ impl UnboundUdp {
             remote: SpinLock::new(Some(endpoint)),
         })
     }
-
-    pub fn close(&mut self) {
-        self.socket.close();
-    }
 }
 
 #[derive(Debug)]

+ 20 - 12
kernel/src/net/socket/inet/datagram/mod.rs

@@ -78,28 +78,31 @@ impl UdpSocket {
             bound.close();
             inner.take();
         }
+        // unbound socket just drop (only need to free memory)
     }
 
     pub fn try_recv(
         &self,
         buf: &mut [u8],
     ) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> {
-        let received = match self.inner.read().as_ref().expect("Udp Inner is None") {
-            UdpInner::Bound(bound) => bound.try_recv(buf),
+        match self.inner.read().as_ref().expect("Udp Inner is None") {
+            UdpInner::Bound(bound) => {
+                let ret = bound.try_recv(buf);
+                poll_ifaces();
+                ret
+            }
             _ => Err(ENOTCONN),
-        };
-        poll_ifaces();
-        return received;
+        }
     }
 
     #[inline]
     pub fn can_recv(&self) -> bool {
-        self.on_events().contains(EP::EPOLLIN)
+        self.event().contains(EP::EPOLLIN)
     }
 
     #[inline]
     pub fn can_send(&self) -> bool {
-        self.on_events().contains(EP::EPOLLOUT)
+        self.event().contains(EP::EPOLLOUT)
     }
 
     pub fn try_send(
@@ -138,7 +141,7 @@ impl UdpSocket {
         }
     }
 
-    pub fn on_events(&self) -> EPollEventType {
+    pub fn event(&self) -> EPollEventType {
         let mut event = EPollEventType::empty();
         match self.inner.read().as_ref().unwrap() {
             UdpInner::Unbound(_) => {
@@ -154,8 +157,6 @@ impl UdpSocket {
 
                 if can_send {
                     event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND);
-                } else {
-                    todo!("缓冲区空间不够,需要使用信号处理");
                 }
             }
         }
@@ -169,7 +170,7 @@ impl Socket for UdpSocket {
     }
 
     fn poll(&self) -> usize {
-        self.on_events().bits() as usize
+        self.event().bits() as usize
     }
 
     fn bind(&self, local_endpoint: Endpoint) -> Result<(), SystemError> {
@@ -195,7 +196,9 @@ impl Socket for UdpSocket {
 
     fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
         if let Endpoint::Ip(remote) = endpoint {
-            self.bind_emphemeral(remote.addr)?;
+            if !self.is_bound() {
+                self.bind_emphemeral(remote.addr)?;
+            }
             if let UdpInner::Bound(inner) = self.inner.read().as_ref().expect("UDP Inner disappear")
             {
                 inner.connect(remote);
@@ -272,6 +275,11 @@ impl Socket for UdpSocket {
         }
         .map(|(len, remote)| (len, Endpoint::Ip(remote)));
     }
+
+    fn close(&self) -> Result<(), SystemError> {
+        self.close();
+        Ok(())
+    }
 }
 
 impl InetSocket for UdpSocket {

+ 18 - 6
kernel/src/net/socket/inet/stream/inner.rs

@@ -268,6 +268,15 @@ impl Connecting {
                     .expect("A Connecting Tcp With No Local Endpoint")
             })
     }
+
+    pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint {
+        self.inner
+            .with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
+                socket
+                    .remote_endpoint()
+                    .expect("A Connecting Tcp With No Remote Endpoint")
+            })
+    }
 }
 
 #[derive(Debug)]
@@ -355,6 +364,13 @@ impl Listening {
             .port_manager()
             .unbind_port(Types::Tcp, port);
     }
+
+    pub fn release(&self) {
+        // log::debug!("Release Listening Socket");
+        for inner in self.inners.iter() {
+            inner.release();
+        }
+    }
 }
 
 #[derive(Debug)]
@@ -370,10 +386,6 @@ impl Established {
         self.inner.with_mut(f)
     }
 
-    pub fn with<R, F: Fn(&smoltcp::socket::tcp::Socket<'static>) -> R>(&self, f: F) -> R {
-        self.inner.with(f)
-    }
-
     pub fn close(&self) {
         self.inner
             .with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.close());
@@ -384,13 +396,13 @@ impl Established {
         self.inner.release();
     }
 
-    pub fn local_endpoint(&self) -> smoltcp::wire::IpEndpoint {
+    pub fn get_name(&self) -> smoltcp::wire::IpEndpoint {
         self.inner
             .with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.local_endpoint())
             .unwrap()
     }
 
-    pub fn remote_endpoint(&self) -> smoltcp::wire::IpEndpoint {
+    pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint {
         self.inner
             .with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.remote_endpoint().unwrap())
     }

+ 26 - 18
kernel/src/net/socket/inet/stream/mod.rs

@@ -99,7 +99,6 @@ impl TcpSocket {
     }
 
     pub fn try_accept(&self) -> Result<(Arc<TcpSocket>, smoltcp::wire::IpEndpoint), SystemError> {
-        // poll_ifaces();
         match self.inner.write().as_mut().expect("Tcp Inner is None") {
             Inner::Listening(listening) => listening.accept().map(|(stream, remote)| {
                 (
@@ -227,16 +226,9 @@ impl TcpSocket {
         }
     }
 
-    fn in_notify(&self) -> bool {
-        self.update_events();
-        // shouldn't pollee but just get the status of the socket
+    fn incoming(&self) -> bool {
         EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN)
     }
-
-    fn out_notify(&self) -> bool {
-        self.update_events();
-        EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLOUT)
-    }
 }
 
 impl Socket for TcpSocket {
@@ -252,16 +244,25 @@ impl Socket for TcpSocket {
             })),
             Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)),
             Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
-            Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())),
+            Inner::Established(established) => Ok(Endpoint::Ip(established.get_name())),
             Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())),
         }
     }
 
+    fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
+        match self.inner.read().as_ref().expect("Tcp Inner is None") {
+            Inner::Init(_) => Err(ENOTCONN),
+            Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_peer_name())),
+            Inner::Established(established) => Ok(Endpoint::Ip(established.get_peer_name())),
+            Inner::Listening(_) => Err(ENOTCONN),
+        }
+    }
+
     fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
         if let Endpoint::Ip(addr) = endpoint {
             return self.do_bind(addr);
         }
-        log::warn!("TcpSocket::bind: invalid endpoint");
+        log::debug!("TcpSocket::bind: invalid endpoint");
         return Err(EINVAL);
     }
 
@@ -295,7 +296,7 @@ impl Socket for TcpSocket {
             loop {
                 match self.try_accept() {
                     Err(EAGAIN_OR_EWOULDBLOCK) => {
-                        wq_wait_event_interruptible!(self.wait_queue, self.in_notify(), {})?;
+                        wq_wait_event_interruptible!(self.wait_queue, self.incoming(), {})?;
                     }
                     result => break result,
                 }
@@ -348,7 +349,15 @@ impl Socket for TcpSocket {
     }
 
     fn close(&self) -> Result<(), SystemError> {
-        let inner = self.inner.write().take().unwrap();
+        let Some(inner) = self.inner.write().take() else {
+            log::warn!("TcpSocket::close: already closed, unexpected");
+            return Ok(());
+        };
+        if let Some(iface) = inner.iface() {
+            iface
+                .common()
+                .unbind_socket(self.self_ref.upgrade().unwrap());
+        }
 
         match inner {
             // complete connecting socket close logic
@@ -356,22 +365,21 @@ impl Socket for TcpSocket {
                 let conn = unsafe { conn.into_established() };
                 conn.close();
                 conn.release();
-                Ok(())
             }
             Inner::Established(es) => {
                 es.close();
                 es.release();
-                Ok(())
             }
             Inner::Listening(ls) => {
                 ls.close();
-                Ok(())
+                ls.release();
             }
             Inner::Init(init) => {
                 init.close();
-                Ok(())
             }
-        }
+        };
+
+        Ok(())
     }
 
     fn set_option(&self, level: PSOL, name: usize, val: &[u8]) -> Result<(), SystemError> {

+ 46 - 7
kernel/src/net/socket/unix/seqpacket/mod.rs

@@ -290,19 +290,58 @@ impl Socket for SeqpacketSocket {
         self.shutdown.recv_shutdown();
         self.shutdown.send_shutdown();
 
-        let path = match self.get_name()? {
+        let endpoint = self.get_name()?;
+        let path = match &endpoint {
             Endpoint::Inode((_, path)) => path,
+            Endpoint::Unixpath((_, path)) => path,
+            Endpoint::Abspath((_, path)) => path,
             _ => return Err(SystemError::EINVAL),
         };
 
-        //如果path是空的说明没有bind,不用释放相关映射资源
         if path.is_empty() {
             return Ok(());
         }
-        // TODO: 释放INODE_MAP相关资源
 
-        // 尝试释放相关抽象地址资源
-        let _ = remove_abs_addr(&path);
+        match &endpoint {
+            Endpoint::Unixpath((inode_id, _)) => {
+                let mut inode_guard = INODE_MAP.write_irqsave();
+                inode_guard.remove(inode_id);
+            }
+            Endpoint::Inode((current_inode, current_path)) => {
+                let mut inode_guard = INODE_MAP.write_irqsave();
+                // 遍历查找匹配的条目
+                let target_entry = inode_guard
+                    .iter()
+                    .find(|(_, ep)| {
+                        if let Endpoint::Inode((map_inode, map_path)) = ep {
+                            // 通过指针相等性比较确保是同一对象
+                            Arc::ptr_eq(map_inode, current_inode) && map_path == current_path
+                        } else {
+                            log::debug!("not match");
+                            false
+                        }
+                    })
+                    .map(|(id, _)| *id);
+
+                if let Some(id) = target_entry {
+                    inode_guard.remove(&id).ok_or(SystemError::EINVAL)?;
+                }
+            }
+            Endpoint::Abspath((abshandle, _)) => {
+                let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave();
+                abs_inode_map.remove(&abshandle.name());
+            }
+            _ => {
+                log::error!("invalid endpoint type");
+                return Err(SystemError::EINVAL);
+            }
+        }
+
+        *self.inner.write() = Inner::Init(Init::new());
+        self.wait_queue.wakeup(None);
+
+        let _ = remove_abs_addr(path);
+
         return Ok(());
     }
 
@@ -471,12 +510,12 @@ impl Socket for SeqpacketSocket {
     }
 
     fn send_buffer_size(&self) -> usize {
-        log::warn!("using default buffer size");
+        // log::warn!("using default buffer size");
         SeqpacketSocket::DEFAULT_BUF_SIZE
     }
 
     fn recv_buffer_size(&self) -> usize {
-        log::warn!("using default buffer size");
+        // log::warn!("using default buffer size");
         SeqpacketSocket::DEFAULT_BUF_SIZE
     }
 

+ 45 - 6
kernel/src/net/socket/unix/stream/mod.rs

@@ -322,20 +322,59 @@ impl Socket for StreamSocket {
         self.shutdown.recv_shutdown();
         self.shutdown.send_shutdown();
 
-        let path = match self.get_name()? {
+        let endpoint = self.get_name()?;
+        let path = match &endpoint {
             Endpoint::Inode((_, path)) => path,
+            Endpoint::Unixpath((_, path)) => path,
+            Endpoint::Abspath((_, path)) => path,
             _ => return Err(SystemError::EINVAL),
         };
 
-        //如果path是空的说明没有bind,不用释放相关映射资源
         if path.is_empty() {
             return Ok(());
         }
-        // TODO: 释放INODE_MAP相关资源
 
-        // 尝试释放相关抽象地址资源
-        let _ = remove_abs_addr(&path);
-        return Ok(());
+        match &endpoint {
+            Endpoint::Unixpath((inode_id, _)) => {
+                let mut inode_guard = INODE_MAP.write_irqsave();
+                inode_guard.remove(inode_id);
+            }
+            Endpoint::Inode((current_inode, current_path)) => {
+                let mut inode_guard = INODE_MAP.write_irqsave();
+                // 遍历查找匹配的条目
+                let target_entry = inode_guard
+                    .iter()
+                    .find(|(_, ep)| {
+                        if let Endpoint::Inode((map_inode, map_path)) = ep {
+                            // 通过指针相等性比较确保是同一对象
+                            Arc::ptr_eq(map_inode, current_inode) && map_path == current_path
+                        } else {
+                            log::debug!("not match");
+                            false
+                        }
+                    })
+                    .map(|(id, _)| *id);
+
+                if let Some(id) = target_entry {
+                    inode_guard.remove(&id).ok_or(SystemError::EINVAL)?;
+                }
+            }
+            Endpoint::Abspath((abshandle, _)) => {
+                let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave();
+                abs_inode_map.remove(&abshandle.name());
+            }
+            _ => {
+                log::error!("invalid endpoint type");
+                return Err(SystemError::EINVAL);
+            }
+        }
+
+        *self.inner.write() = Inner::Init(Init::new());
+        self.wait_queue.wakeup(None);
+
+        let _ = remove_abs_addr(path);
+
+        Ok(())
     }
 
     fn get_peer_name(&self) -> Result<Endpoint, SystemError> {

+ 2 - 0
user/apps/test_unix_stream_socket/src/main.rs

@@ -138,7 +138,9 @@ fn test_stream() -> Result<(), Error> {
         send_message(client_fd, MSG2).expect("Failed to send message");
         println!("Server send finish");
 
+        println!("Server begin close!");
         unsafe { close(server_fd) };
+        println!("Server close finish!");
     });
 
     let client_fd = create_stream_socket()?;