Browse Source

feat(net): 实现unix抽象地址空间 (#1017)

Cai Junyuan 4 months ago
parent
commit
fad1c09757

+ 39 - 1
kernel/src/net/posix.rs

@@ -35,8 +35,10 @@ impl PosixArgsSocketType {
     }
 }
 
+use alloc::string::String;
 use alloc::sync::Arc;
 use core::ffi::CStr;
+use unix::ns::abs::{alloc_abs_addr, look_up_abs_addr};
 
 use crate::{
     filesystem::vfs::{FileType, IndexNode, ROOT_INODE, VFS_MAX_FOLLOW_SYMLINK_TIMES},
@@ -45,7 +47,7 @@ use crate::{
     process::ProcessManager,
 };
 use smoltcp;
-use system_error::SystemError;
+use system_error::SystemError::{self};
 
 // 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32
 #[repr(C)]
@@ -146,6 +148,42 @@ impl SockAddr {
                 AddressFamily::Unix => {
                     let addr_un: SockAddrUn = addr.addr_un;
 
+                    if addr_un.sun_path[0] == 0 {
+                        // 抽象地址空间,与文件系统没有关系
+                        let path = CStr::from_bytes_until_nul(&addr_un.sun_path[1..])
+                            .map_err(|_| {
+                                log::error!("CStr::from_bytes_until_nul fail");
+                                SystemError::EINVAL
+                            })?
+                            .to_str()
+                            .map_err(|_| {
+                                log::error!("CStr::to_str fail");
+                                SystemError::EINVAL
+                            })?;
+
+                        // 向抽象地址管理器申请或查找抽象地址
+                        let spath = String::from(path);
+                        log::debug!("abs path: {}", spath);
+                        let abs_find = match look_up_abs_addr(&spath) {
+                            Ok(result) => result,
+                            Err(_) => {
+                                //未找到尝试分配abs
+                                match alloc_abs_addr(spath.clone()) {
+                                    Ok(result) => {
+                                        log::debug!("alloc abs addr success!");
+                                        return Ok(result);
+                                    }
+                                    Err(e) => {
+                                        log::debug!("alloc abs addr failed!");
+                                        return Err(e);
+                                    }
+                                };
+                            }
+                        };
+                        log::debug!("find alloc abs addr success!");
+                        return Ok(abs_find);
+                    }
+
                     let path = CStr::from_bytes_until_nul(&addr_un.sun_path)
                         .map_err(|_| {
                             log::error!("CStr::from_bytes_until_nul fail");

+ 4 - 0
kernel/src/net/socket/endpoint.rs

@@ -3,6 +3,8 @@ use alloc::{string::String, sync::Arc};
 
 pub use smoltcp::wire::IpEndpoint;
 
+use super::unix::ns::abs::AbsHandle;
+
 #[derive(Debug, Clone)]
 pub enum Endpoint {
     /// 链路层端点
@@ -13,6 +15,8 @@ pub enum Endpoint {
     Inode((Arc<socket::Inode>, String)),
     /// Unix传递id索引和path所用的端点
     Unixpath((InodeId, String)),
+    /// Unix抽象端点
+    Abspath((AbsHandle, String)),
 }
 
 /// @brief 链路层端点

+ 2 - 1
kernel/src/net/socket/unix/mod.rs

@@ -1,5 +1,6 @@
+pub mod ns;
 pub(crate) mod seqpacket;
-mod stream;
+pub mod stream;
 use crate::{filesystem::vfs::InodeId, libs::rwlock::RwLock, net::socket::*};
 use alloc::sync::Arc;
 use hashbrown::HashMap;

+ 172 - 0
kernel/src/net/socket/unix/ns/abs.rs

@@ -0,0 +1,172 @@
+use core::fmt;
+
+use crate::libs::spinlock::SpinLock;
+use crate::net::socket::Endpoint;
+use alloc::string::String;
+use hashbrown::HashMap;
+use ida::IdAllocator;
+use system_error::SystemError;
+
+lazy_static! {
+    pub static ref ABSHANDLE_MAP: AbsHandleMap = AbsHandleMap::new();
+}
+
+lazy_static! {
+    pub static ref ABS_INODE_MAP: SpinLock<HashMap<usize, Endpoint>> =
+        SpinLock::new(HashMap::new());
+}
+
+static ABS_ADDRESS_ALLOCATOR: SpinLock<IdAllocator> =
+    SpinLock::new(IdAllocator::new(0, (1 << 20) as usize).unwrap());
+
+#[derive(Debug, Clone)]
+pub struct AbsHandle(usize);
+
+impl AbsHandle {
+    pub fn new(name: usize) -> Self {
+        Self(name)
+    }
+
+    pub fn name(&self) -> usize {
+        self.0
+    }
+}
+
+impl fmt::Display for AbsHandle {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:05x}", self.0)
+    }
+}
+
+/// 抽象地址映射表
+///
+/// 负责管理抽象命名空间内的地址
+pub struct AbsHandleMap {
+    abs_handle_map: SpinLock<HashMap<String, Endpoint>>,
+}
+
+impl AbsHandleMap {
+    pub fn new() -> Self {
+        Self {
+            abs_handle_map: SpinLock::new(HashMap::new()),
+        }
+    }
+
+    /// 插入新的地址映射
+    pub fn insert(&self, name: String) -> Result<Endpoint, SystemError> {
+        let mut guard = self.abs_handle_map.lock();
+
+        //检查name是否被占用
+        if guard.contains_key(&name) {
+            return Err(SystemError::ENOMEM);
+        }
+
+        let ads_addr = match self.alloc(name.clone()) {
+            Some(addr) => addr.clone(),
+            None => return Err(SystemError::ENOMEM),
+        };
+        guard.insert(name, ads_addr.clone());
+        return Ok(ads_addr);
+    }
+
+    /// 抽象空间地址分配器
+    ///
+    /// ## 返回
+    ///
+    /// 分配到的可用的抽象端点
+    pub fn alloc(&self, name: String) -> Option<Endpoint> {
+        let abs_addr = match ABS_ADDRESS_ALLOCATOR.lock().alloc() {
+            Some(addr) => addr,
+            //地址被分配
+            None => return None,
+        };
+
+        let result = Some(Endpoint::Abspath((AbsHandle::new(abs_addr), name)));
+
+        return result;
+    }
+
+    /// 进行地址映射
+    ///
+    /// ## 参数
+    ///
+    /// name:用户定义的地址
+    pub fn look_up(&self, name: &String) -> Option<Endpoint> {
+        let guard = self.abs_handle_map.lock();
+        return guard.get(name).cloned();
+    }
+
+    /// 移除绑定的地址
+    ///
+    /// ## 参数
+    ///
+    /// name:待删除的地址
+    pub fn remove(&self, name: &String) -> Result<(), SystemError> {
+        let abs_addr = match look_up_abs_addr(name) {
+            Ok(result) => match result {
+                Endpoint::Abspath((abshandle, _)) => abshandle.name(),
+                _ => return Err(SystemError::EINVAL),
+            },
+            Err(_) => return Err(SystemError::EINVAL),
+        };
+
+        //释放abs地址分配实例
+        ABS_ADDRESS_ALLOCATOR.lock().free(abs_addr);
+
+        //释放entry
+        let mut guard = self.abs_handle_map.lock();
+        guard.remove(name);
+
+        Ok(())
+    }
+}
+
+/// 分配抽象地址
+///
+/// ## 返回
+///
+/// 分配到的抽象地址
+pub fn alloc_abs_addr(name: String) -> Result<Endpoint, SystemError> {
+    ABSHANDLE_MAP.insert(name)
+}
+
+/// 查找抽象地址
+///
+/// ## 参数
+///
+/// name:用户socket字符地址
+///
+/// ## 返回
+///
+/// 查询到的抽象地址
+pub fn look_up_abs_addr(name: &String) -> Result<Endpoint, SystemError> {
+    match ABSHANDLE_MAP.look_up(name) {
+        Some(result) => return Ok(result),
+        None => return Err(SystemError::EINVAL),
+    }
+}
+
+/// 删除抽象地址
+///
+/// ## 参数
+/// name:待删除的地址
+///
+/// ## 返回
+/// 删除的抽象地址
+pub fn remove_abs_addr(name: &String) -> Result<(), SystemError> {
+    let abs_addr = match look_up_abs_addr(name) {
+        Ok(addr) => match addr {
+            Endpoint::Abspath((addr, _)) => addr,
+            _ => return Err(SystemError::EINVAL),
+        },
+        Err(_) => return Err(SystemError::EINVAL),
+    };
+
+    match ABS_INODE_MAP.lock_irqsave().remove(&abs_addr.name()) {
+        Some(_) => log::debug!("free abs inode"),
+        None => log::debug!("not free abs inode"),
+    }
+    ABSHANDLE_MAP.remove(name)?;
+    log::debug!("free abs!");
+    Ok(())
+}

+ 1 - 0
kernel/src/net/socket/unix/ns/mod.rs

@@ -0,0 +1 @@
+pub mod abs;

+ 44 - 1
kernel/src/net/socket/unix/seqpacket/mod.rs

@@ -4,6 +4,7 @@ use alloc::{
     sync::{Arc, Weak},
 };
 use core::sync::atomic::{AtomicBool, Ordering};
+use unix::ns::abs::{remove_abs_addr, ABS_INODE_MAP};
 
 use crate::sched::SchedMode;
 use crate::{libs::rwlock::RwLock, net::socket::*};
@@ -136,6 +137,23 @@ impl Socket for SeqpacketSocket {
                     _ => return Err(SystemError::EINVAL),
                 }
             }
+            Endpoint::Abspath((abs_addr, _)) => {
+                let inode_guard = ABS_INODE_MAP.lock_irqsave();
+                let inode = match inode_guard.get(&abs_addr.name()) {
+                    Some(inode) => inode,
+                    None => {
+                        log::debug!("can not find inode from absInodeMap");
+                        return Err(SystemError::EINVAL);
+                    }
+                };
+                match inode {
+                    Endpoint::Inode((inode, _)) => inode.clone(),
+                    _ => {
+                        log::debug!("when connect, find inode failed!");
+                        return Err(SystemError::EINVAL);
+                    }
+                }
+            }
             _ => return Err(SystemError::EINVAL),
         };
         // 远端为服务端
@@ -197,6 +215,17 @@ impl Socket for SeqpacketSocket {
                 INODE_MAP.write_irqsave().insert(inodeid, inode);
                 Ok(())
             }
+            Endpoint::Abspath((abshandle, path)) => {
+                let inode = match &mut *self.inner.write() {
+                    Inner::Init(init) => init.bind_path(path)?,
+                    _ => {
+                        log::error!("socket has listen or connected");
+                        return Err(SystemError::EINVAL);
+                    }
+                };
+                ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode);
+                Ok(())
+            }
             _ => return Err(SystemError::EINVAL),
         }
     }
@@ -260,7 +289,21 @@ impl Socket for SeqpacketSocket {
         // log::debug!("seqpacket close");
         self.shutdown.recv_shutdown();
         self.shutdown.send_shutdown();
-        Ok(())
+
+        let path = match self.get_name()? {
+            Endpoint::Inode((_, path)) => path,
+            _ => return Err(SystemError::EINVAL),
+        };
+
+        //如果path是空的说明没有bind,不用释放相关映射资源
+        if path.is_empty() {
+            return Ok(());
+        }
+        // TODO: 释放INODE_MAP相关资源
+
+        // 尝试释放相关抽象地址资源
+        let _ = remove_abs_addr(&path);
+        return Ok(());
     }
 
     fn get_peer_name(&self) -> Result<Endpoint, SystemError> {

+ 2 - 1
kernel/src/net/socket/unix/stream/inner.rs

@@ -49,8 +49,9 @@ impl Init {
         }
         if let Some(Endpoint::Inode((inode, mut path))) = self.addr.take() {
             path = sun_path;
-            let epoint = Endpoint::Inode((inode, path));
+            let epoint = Endpoint::Inode((inode, path.clone()));
             self.addr.replace(epoint.clone());
+            log::debug!("bind path in inode : {:?}", path);
             return Ok(epoint);
         };
 

+ 47 - 2
kernel/src/net/socket/unix/stream/mod.rs

@@ -6,7 +6,10 @@ use alloc::{
 use inner::{Connected, Init, Inner, Listener};
 use log::debug;
 use system_error::SystemError;
-use unix::INODE_MAP;
+use unix::{
+    ns::abs::{remove_abs_addr, ABSHANDLE_MAP, ABS_INODE_MAP},
+    INODE_MAP,
+};
 
 use crate::{
     libs::rwlock::RwLock,
@@ -157,6 +160,23 @@ impl Socket for StreamSocket {
                     _ => return Err(SystemError::EINVAL),
                 }
             }
+            Endpoint::Abspath((abs_addr, path)) => {
+                let inode_guard = ABS_INODE_MAP.lock_irqsave();
+                let inode = match inode_guard.get(&abs_addr.name()) {
+                    Some(inode) => inode,
+                    None => {
+                        log::debug!("can not find inode from absInodeMap");
+                        return Err(SystemError::EINVAL);
+                    }
+                };
+                match inode {
+                    Endpoint::Inode((inode, _)) => (inode.clone(), path),
+                    _ => {
+                        debug!("when connect, find inode failed!");
+                        return Err(SystemError::EINVAL);
+                    }
+                }
+            }
             _ => return Err(SystemError::EINVAL),
         };
 
@@ -200,6 +220,17 @@ impl Socket for StreamSocket {
                 INODE_MAP.write_irqsave().insert(inodeid, inode);
                 Ok(())
             }
+            Endpoint::Abspath((abshandle, path)) => {
+                let inode = match &mut *self.inner.write() {
+                    Inner::Init(init) => init.bind_path(path)?,
+                    _ => {
+                        log::error!("socket has listen or connected");
+                        return Err(SystemError::EINVAL);
+                    }
+                };
+                ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode);
+                Ok(())
+            }
             _ => return Err(SystemError::EINVAL),
         }
     }
@@ -290,7 +321,21 @@ impl Socket for StreamSocket {
     fn close(&self) -> Result<(), SystemError> {
         self.shutdown.recv_shutdown();
         self.shutdown.send_shutdown();
-        Ok(())
+
+        let path = match self.get_name()? {
+            Endpoint::Inode((_, path)) => path,
+            _ => return Err(SystemError::EINVAL),
+        };
+
+        //如果path是空的说明没有bind,不用释放相关映射资源
+        if path.is_empty() {
+            return Ok(());
+        }
+        // TODO: 释放INODE_MAP相关资源
+
+        // 尝试释放相关抽象地址资源
+        let _ = remove_abs_addr(&path);
+        return Ok(());
     }
 
     fn get_peer_name(&self) -> Result<Endpoint, SystemError> {

+ 144 - 2
user/apps/test_unix_stream_socket/src/main.rs

@@ -5,7 +5,8 @@ use std::io::Error;
 use std::mem;
 use std::os::fd::RawFd;
 
-const SOCKET_PATH: &str = "/test.stream";
+const SOCKET_PATH: &str = "./test.stream";
+const SOCKET_ABSTRUCT_PATH: &str = "/abs.stream";
 const MSG1: &str = "Hello, unix stream socket from Client!";
 const MSG2: &str = "Hello, unix stream socket from Server!";
 
@@ -44,6 +45,32 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
     Ok(())
 }
 
+fn bind_abstruct_socket(fd: RawFd) -> Result<(), Error> {
+    unsafe {
+        let mut addr = sockaddr_un {
+            sun_family: AF_UNIX as u16,
+            sun_path: [0; 108],
+        };
+        addr.sun_path[0] = 0;
+        let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
+        let path_bytes = path_cstr.as_bytes();
+        for (i, &byte) in path_bytes.iter().enumerate() {
+            addr.sun_path[i + 1] = byte as i8;
+        }
+
+        if bind(
+            fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
+            return Err(Error::last_os_error());
+        }
+    }
+
+    Ok(())
+}
+
 fn listen_socket(fd: RawFd) -> Result<(), Error> {
     unsafe {
         if listen(fd, 5) == -1 {
@@ -111,7 +138,7 @@ fn test_stream() -> Result<(), Error> {
         send_message(client_fd, MSG2).expect("Failed to send message");
         println!("Server send finish");
 
-        unsafe { close(client_fd) };
+        unsafe { close(server_fd) };
     });
 
     let client_fd = create_stream_socket()?;
@@ -173,9 +200,124 @@ fn test_stream() -> Result<(), Error> {
     Ok(())
 }
 
+fn test_abstruct_namespace() -> Result<(), Error> {
+    let server_fd = create_stream_socket()?;
+    bind_abstruct_socket(server_fd)?;
+    listen_socket(server_fd)?;
+
+    let server_thread = std::thread::spawn(move || {
+        let client_fd = accept_conn(server_fd).expect("Failed to accept connection");
+        println!("accept success!");
+        let recv_msg = recv_message(client_fd).expect("Failed to receive message");
+
+        println!("Server: Received message: {}", recv_msg);
+        send_message(client_fd, MSG2).expect("Failed to send message");
+        println!("Server send finish");
+
+        unsafe { close(server_fd) }
+    });
+
+    let client_fd = create_stream_socket()?;
+    unsafe {
+        let mut addr = sockaddr_un {
+            sun_family: AF_UNIX as u16,
+            sun_path: [0; 108],
+        };
+        addr.sun_path[0] = 0;
+        let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
+        let path_bytes = path_cstr.as_bytes();
+
+        for (i, &byte) in path_bytes.iter().enumerate() {
+            addr.sun_path[i + 1] = byte as i8;
+        }
+
+        if connect(
+            client_fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
+            return Err(Error::last_os_error());
+        }
+    }
+
+    send_message(client_fd, MSG1)?;
+    // get peer_name
+    unsafe {
+        let mut addrss = sockaddr_un {
+            sun_family: AF_UNIX as u16,
+            sun_path: [0; 108],
+        };
+        let mut len = mem::size_of_val(&addrss) as socklen_t;
+        let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
+        if res == -1 {
+            return Err(Error::last_os_error());
+        }
+        let sun_path = addrss.sun_path.clone();
+        let peer_path: [u8; 108] = sun_path
+            .iter()
+            .map(|&x| x as u8)
+            .collect::<Vec<u8>>()
+            .try_into()
+            .unwrap();
+        println!(
+            "Client: Connected to server at path: {}",
+            String::from_utf8_lossy(&peer_path)
+        );
+    }
+
+    server_thread.join().expect("Server thread panicked");
+    println!("Client try recv!");
+    let recv_msg = recv_message(client_fd).expect("Failed to receive message from server");
+    println!("Client Received message: {}", recv_msg);
+
+    unsafe { close(client_fd) };
+    Ok(())
+}
+
+fn test_recourse_free() -> Result<(), Error> {
+    let client_fd = create_stream_socket()?;
+    unsafe {
+        let mut addr = sockaddr_un {
+            sun_family: AF_UNIX as u16,
+            sun_path: [0; 108],
+        };
+        addr.sun_path[0] = 0;
+        let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
+        let path_bytes = path_cstr.as_bytes();
+
+        for (i, &byte) in path_bytes.iter().enumerate() {
+            addr.sun_path[i + 1] = byte as i8;
+        }
+
+        if connect(
+            client_fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
+            return Err(Error::last_os_error());
+        }
+    }
+
+    send_message(client_fd, MSG1)?;
+    unsafe { close(client_fd) };
+    Ok(())
+}
+
 fn main() {
     match test_stream() {
         Ok(_) => println!("test for unix stream success"),
         Err(_) => println!("test for unix stream failed"),
     }
+
+    match test_abstruct_namespace() {
+        Ok(_) => println!("test for unix abstruct namespace success"),
+        Err(_) => println!("test for unix abstruct namespace failed"),
+    }
+
+    match test_recourse_free() {
+        Ok(_) => println!("not free!"),
+        Err(_) => println!("free!"),
+    }
 }