Samuka007 5 months ago
parent
commit
8fe49e190e

+ 1 - 1
kernel/src/arch/x86_64/syscall/mod.rs

@@ -133,7 +133,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
                 show &= false;
             }
         }
-        show = false;
+        show &= false;
         if show {
             debug!("[SYS] [Pid: {:?}] [Call: {:?}]", pid, to_print);
         }

+ 14 - 10
kernel/src/driver/base/block/block_device.rs

@@ -1,16 +1,20 @@
 /// 引入Module
-use crate::{driver::{
-    base::{
-        device::{
-            device_number::{DeviceNumber, Major}, Device, DeviceError, IdTable, BLOCKDEVS
-        },
-        map::{
-            DeviceStruct, DEV_MAJOR_DYN_END, DEV_MAJOR_DYN_EXT_END, DEV_MAJOR_DYN_EXT_START,
-            DEV_MAJOR_HASH_SIZE, DEV_MAJOR_MAX,
+use crate::{
+    driver::{
+        base::{
+            device::{
+                device_number::{DeviceNumber, Major},
+                Device, DeviceError, IdTable, BLOCKDEVS,
+            },
+            map::{
+                DeviceStruct, DEV_MAJOR_DYN_END, DEV_MAJOR_DYN_EXT_END, DEV_MAJOR_DYN_EXT_START,
+                DEV_MAJOR_HASH_SIZE, DEV_MAJOR_MAX,
+            },
         },
+        block::cache::{cached_block_device::BlockCache, BlockCacheError, BLOCK_SIZE},
     },
-    block::cache::{cached_block_device::BlockCache, BlockCacheError, BLOCK_SIZE},
-}, filesystem::sysfs::AttributeGroup};
+    filesystem::sysfs::AttributeGroup,
+};
 
 use alloc::{string::String, sync::Arc, vec::Vec};
 use core::{any::Any, fmt::Display, ops::Deref};

+ 8 - 5
kernel/src/driver/base/uevent/mod.rs

@@ -157,10 +157,14 @@ impl Attribute for UeventAttr {
                 writeln!(&mut uevent_content, "DEVTYPE=char").unwrap();
             }
             DeviceType::Net => {
-                let net_device = device.clone().cast::<dyn Iface>().map_err(|e: Arc<dyn Device>| {
-                    warn!("device:{:?} is not a net device!", e);
-                    SystemError::EINVAL
-                })?;
+                let net_device =
+                    device
+                        .clone()
+                        .cast::<dyn Iface>()
+                        .map_err(|e: Arc<dyn Device>| {
+                            warn!("device:{:?} is not a net device!", e);
+                            SystemError::EINVAL
+                        })?;
                 let iface_id = net_device.nic_id();
                 let device_name = device.name();
                 writeln!(&mut uevent_content, "INTERFACE={}", device_name).unwrap();
@@ -200,7 +204,6 @@ impl Attribute for UeventAttr {
     }
 }
 
-
 /// 将设备的基本信息写入 uevent 文件
 fn sysfs_emit_str(buf: &mut [u8], content: &str) -> Result<usize, SystemError> {
     log::info!("sysfs_emit_str");

+ 6 - 6
kernel/src/driver/net/mod.rs

@@ -255,12 +255,12 @@ impl IfaceCommon {
             }
         });
 
-            // let closed_sockets = self
-            //     .closing_sockets
-            //     .lock_irq_disabled()
-            //     .extract_if(|closing_socket| closing_socket.is_closed())
-            //     .collect::<Vec<_>>();
-            // drop(closed_sockets);
+        // let closed_sockets = self
+        //     .closing_sockets
+        //     .lock_irq_disabled()
+        //     .extract_if(|closing_socket| closing_socket.is_closed())
+        //     .collect::<Vec<_>>();
+        // drop(closed_sockets);
         // }
     }
 

+ 5 - 4
kernel/src/net/socket/inet/stream/inner.rs

@@ -3,6 +3,7 @@ use core::sync::atomic::{AtomicU32, AtomicUsize};
 use crate::libs::rwlock::RwLock;
 use crate::net::socket::EPollEventType;
 use crate::net::socket::{self, inet::Types};
+use alloc::boxed::Box;
 use alloc::vec::Vec;
 use smoltcp;
 use system_error::SystemError::{self, *};
@@ -30,13 +31,13 @@ where
 
 #[derive(Debug)]
 pub enum Init {
-    Unbound(smoltcp::socket::tcp::Socket<'static>),
+    Unbound(Box<smoltcp::socket::tcp::Socket<'static>>),
     Bound((socket::inet::BoundInner, smoltcp::wire::IpEndpoint)),
 }
 
 impl Init {
     pub(super) fn new() -> Self {
-        Init::Unbound(new_smoltcp_socket())
+        Init::Unbound(Box::new(new_smoltcp_socket()))
     }
 
     /// 传入一个已经绑定的socket
@@ -55,7 +56,7 @@ impl Init {
     ) -> Result<Self, SystemError> {
         match self {
             Init::Unbound(socket) => {
-                let bound = socket::inet::BoundInner::bind(socket, &local_endpoint.addr)?;
+                let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr)?;
                 bound
                     .port_manager()
                     .bind_port(Types::Tcp, local_endpoint.port)?;
@@ -73,7 +74,7 @@ impl Init {
         match self {
             Init::Unbound(socket) => {
                 let (bound, address) =
-                    socket::inet::BoundInner::bind_ephemeral(socket, remote_endpoint.addr)
+                    socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr)
                         .map_err(|err| (Self::new(), err))?;
                 let bound_port = bound
                     .port_manager()

+ 14 - 10
kernel/src/net/socket/inet/stream/mod.rs

@@ -185,15 +185,19 @@ impl TcpSocket {
     }
 
     pub fn try_recv(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
-        self.inner.read().as_ref().map(|inner| {
-            inner.iface().unwrap().poll();
-            let result = match inner {
-                Inner::Established(inner) => inner.recv_slice(buf),
-                _ => Err(EINVAL),
-            };
-            inner.iface().unwrap().poll();
-            result
-        }).unwrap()
+        self.inner
+            .read()
+            .as_ref()
+            .map(|inner| {
+                inner.iface().unwrap().poll();
+                let result = match inner {
+                    Inner::Established(inner) => inner.recv_slice(buf),
+                    _ => Err(EINVAL),
+                };
+                inner.iface().unwrap().poll();
+                result
+            })
+            .unwrap()
     }
 
     pub fn try_send(&self, buf: &[u8]) -> Result<usize, SystemError> {
@@ -238,7 +242,7 @@ impl Socket for TcpSocket {
     fn get_name(&self) -> Result<Endpoint, SystemError> {
         match self.inner.read().as_ref().expect("Tcp Inner is None") {
             Inner::Init(Init::Unbound(_)) => Ok(Endpoint::Ip(UNSPECIFIED_LOCAL_ENDPOINT)),
-            Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(local.clone())),
+            Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)),
             Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
             Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())),
             Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())),

+ 3 - 7
kernel/src/net/socket/unix/seqpacket/inner.rs

@@ -62,11 +62,7 @@ pub(super) struct Listener {
 impl Listener {
     pub(super) fn new(inode: Endpoint, backlog: usize) -> Self {
         log::debug!("backlog {}", backlog);
-        let back = if backlog > 1024 {
-            1024 as usize
-        } else {
-            backlog
-        };
+        let back = if backlog > 1024 { 1024_usize } else { backlog };
         return Self {
             inode,
             backlog: AtomicUsize::new(back),
@@ -82,7 +78,7 @@ impl Listener {
         log::debug!(" incom len {}", incoming_conns.len());
         let conn = incoming_conns
             .pop_front()
-            .ok_or_else(|| SystemError::EAGAIN_OR_EWOULDBLOCK)?;
+            .ok_or(SystemError::EAGAIN_OR_EWOULDBLOCK)?;
         let socket =
             Arc::downcast::<SeqpacketSocket>(conn.inner()).map_err(|_| SystemError::EINVAL)?;
         let peer = match &*socket.inner.read() {
@@ -190,7 +186,7 @@ impl Connected {
         if self.can_send()? {
             return self.send_slice(buf);
         } else {
-            log::debug!("can not send {:?}", String::from_utf8_lossy(&buf[..]));
+            log::debug!("can not send {:?}", String::from_utf8_lossy(buf));
             return Err(SystemError::ENOBUFS);
         }
     }

+ 3 - 7
kernel/src/net/socket/unix/seqpacket/mod.rs

@@ -230,11 +230,7 @@ impl Socket for SeqpacketSocket {
         if !self.is_nonblocking() {
             loop {
                 wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
-                match self
-                    .try_accept()
-                    .map(|(seqpacket_socket, remote_endpoint)| {
-                        (seqpacket_socket, Endpoint::from(remote_endpoint))
-                    }) {
+                match self.try_accept() {
                     Ok((socket, epoint)) => return Ok((socket, epoint)),
                     Err(_) => continue,
                 }
@@ -274,7 +270,7 @@ impl Socket for SeqpacketSocket {
         };
 
         if let Some(endpoint) = endpoint {
-            return Ok(Endpoint::from(endpoint));
+            return Ok(endpoint);
         } else {
             return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
         }
@@ -289,7 +285,7 @@ impl Socket for SeqpacketSocket {
         };
 
         if let Some(endpoint) = endpoint {
-            return Ok(Endpoint::from(endpoint));
+            return Ok(endpoint);
         } else {
             return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
         }

+ 1 - 4
kernel/src/net/socket/unix/stream/mod.rs

@@ -231,10 +231,7 @@ impl Socket for StreamSocket {
         //目前只实现了阻塞式实现
         loop {
             wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
-            match self
-                .try_accept()
-                .map(|(stream_socket, remote_endpoint)| (stream_socket, remote_endpoint))
-            {
+            match self.try_accept() {
                 Ok((socket, endpoint)) => {
                     debug!("server accept!:{:?}", endpoint);
                     return Ok((socket, endpoint));

+ 1 - 1
kernel/src/net/syscall_util.rs

@@ -312,7 +312,7 @@ impl From<Endpoint> for SockAddr {
                 }
                 let addr_un = SockAddrUn {
                     sun_family: AddressFamily::Unix as u16,
-                    sun_path: sun_path,
+                    sun_path,
                 };
                 return SockAddr { addr_un };
             }

+ 1 - 1
user/apps/ping/src/ping.rs

@@ -101,7 +101,7 @@ impl Ping {
 
         for i in 0..this.config.count {
             let _this = this.clone();
-            let handle = thread::spawn(move||{
+            let handle = thread::spawn(move || {
                 _this.ping(i).unwrap();
             });
             _send.fetch_add(1, Ordering::SeqCst);

+ 31 - 8
user/apps/test-uevent/src/main.rs

@@ -1,7 +1,10 @@
-use libc::{sockaddr, sockaddr_storage, recvfrom, bind, sendto, socket, AF_NETLINK, SOCK_DGRAM, SOCK_CLOEXEC, getpid, c_void};
+use libc::{
+    bind, c_void, getpid, recvfrom, sendto, sockaddr, sockaddr_storage, socket, AF_NETLINK,
+    SOCK_CLOEXEC, SOCK_DGRAM,
+};
 use nix::libc;
 use std::os::unix::io::RawFd;
-use std::{ mem, io};
+use std::{io, mem};
 
 #[repr(C)]
 struct Nlmsghdr {
@@ -14,7 +17,11 @@ struct Nlmsghdr {
 
 fn create_netlink_socket() -> io::Result<RawFd> {
     let sockfd = unsafe {
-        socket(AF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, libc::NETLINK_KOBJECT_UEVENT)
+        socket(
+            AF_NETLINK,
+            SOCK_DGRAM | SOCK_CLOEXEC,
+            libc::NETLINK_KOBJECT_UEVENT,
+        )
     };
 
     if sockfd < 0 {
@@ -33,7 +40,11 @@ fn bind_netlink_socket(sock: RawFd) -> io::Result<()> {
     addr.nl_groups = 0;
 
     let ret = unsafe {
-        bind(sock, &addr as *const _ as *const sockaddr, mem::size_of::<libc::sockaddr_nl>() as u32)
+        bind(
+            sock,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of::<libc::sockaddr_nl>() as u32,
+        )
     };
 
     if ret < 0 {
@@ -90,7 +101,10 @@ fn receive_uevent(sock: RawFd) -> io::Result<String> {
     // 检查套接字文件描述符是否有效
     if sock < 0 {
         println!("Invalid socket file descriptor: {}", sock);
-        return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid socket file descriptor"));
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "Invalid socket file descriptor",
+        ));
     }
 
     let mut buf = [0u8; 1024];
@@ -100,7 +114,10 @@ fn receive_uevent(sock: RawFd) -> io::Result<String> {
     // 检查缓冲区指针和长度是否有效
     if buf.is_empty() {
         println!("Buffer is empty");
-        return Err(io::Error::new(io::ErrorKind::InvalidInput, "Buffer is empty"));
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "Buffer is empty",
+        ));
     }
     let len = unsafe {
         recvfrom(
@@ -122,13 +139,19 @@ fn receive_uevent(sock: RawFd) -> io::Result<String> {
     let nlmsghdr_size = mem::size_of::<Nlmsghdr>();
     if (len as usize) < nlmsghdr_size {
         println!("Received message is too short");
-        return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is too short"));
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "Received message is too short",
+        ));
     }
 
     let nlmsghdr = unsafe { &*(buf.as_ptr() as *const Nlmsghdr) };
     if nlmsghdr.nlmsg_len as isize > len {
         println!("Received message is incomplete");
-        return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is incomplete"));
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "Received message is incomplete",
+        ));
     }
 
     let message_data = &buf[nlmsghdr_size..nlmsghdr.nlmsg_len as usize];

+ 3 - 3
user/apps/test_seqpacket/src/main.rs

@@ -1,8 +1,8 @@
-mod seq_socket;
 mod seq_pair;
+mod seq_socket;
 
-use seq_socket::test_seq_socket;
 use seq_pair::test_seq_pair;
+use seq_socket::test_seq_socket;
 
 fn main() -> Result<(), std::io::Error> {
     if let Err(e) = test_seq_socket() {
@@ -187,4 +187,4 @@ fn main() -> Result<(), std::io::Error> {
 //     let len = socket1.read(&mut buf)?;
 //     println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len]));
 //     Ok(())
-// }
+// }

+ 5 - 4
user/apps/test_seqpacket/src/seq_pair.rs

@@ -1,16 +1,17 @@
 use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType};
 use std::fs::File;
-use std::io::{Read, Write,Error};
+use std::io::{Error, Read, Write};
 use std::os::fd::FromRawFd;
 
-pub fn test_seq_pair()->Result<(),Error>{
+pub fn test_seq_pair() -> Result<(), Error> {
     // 创建 socket pair
     let (sock1, sock2) = socketpair(
         AddressFamily::Unix,
         SockType::SeqPacket, // 使用 SeqPacket 类型
         None,                // 协议默认
         SockFlag::empty(),
-    ).expect("Failed to create socket pair");
+    )
+    .expect("Failed to create socket pair");
 
     let mut socket1 = unsafe { File::from_raw_fd(sock1) };
     let mut socket2 = unsafe { File::from_raw_fd(sock2) };
@@ -36,4 +37,4 @@ pub fn test_seq_pair()->Result<(),Error>{
     let len = socket1.read(&mut buf)?;
     println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len]));
     Ok(())
-}
+}

+ 95 - 69
user/apps/test_seqpacket/src/seq_socket.rs

@@ -1,16 +1,14 @@
-
 use libc::*;
-use std::{fs, str};
 use std::ffi::CString;
 use std::io::Error;
 use std::mem;
 use std::os::unix::io::RawFd;
+use std::{fs, str};
 
 const SOCKET_PATH: &str = "/test.seqpacket";
 const MSG1: &str = "Hello, Unix SEQPACKET socket from Client!";
 const MSG2: &str = "Hello, Unix SEQPACKET socket from Server!";
 
-
 fn create_seqpacket_socket() -> Result<RawFd, Error> {
     unsafe {
         let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
@@ -33,7 +31,12 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
             addr.sun_path[i] = byte as i8;
         }
 
-        if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
+        if bind(
+            fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
             return Err(Error::last_os_error());
         }
     }
@@ -68,7 +71,13 @@ fn accept_connection(fd: RawFd) -> Result<RawFd, Error> {
 fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
     unsafe {
         let msg_bytes = msg.as_bytes();
-        if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0) == -1 {
+        if send(
+            fd,
+            msg_bytes.as_ptr() as *const libc::c_void,
+            msg_bytes.len(),
+            0,
+        ) == -1
+        {
             return Err(Error::last_os_error());
         }
     }
@@ -78,7 +87,12 @@ fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
 fn receive_message(fd: RawFd) -> Result<String, Error> {
     let mut buffer = [0; 1024];
     unsafe {
-        let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(), 0);
+        let len = recv(
+            fd,
+            buffer.as_mut_ptr() as *mut libc::c_void,
+            buffer.len(),
+            0,
+        );
         if len == -1 {
             return Err(Error::last_os_error());
         }
@@ -86,70 +100,82 @@ fn receive_message(fd: RawFd) -> Result<String, Error> {
     }
 }
 
-pub fn test_seq_socket() ->Result<(), Error>{
-        // Create and bind the server socket
-        fs::remove_file(&SOCKET_PATH).ok();
-
-        let server_fd = create_seqpacket_socket()?;
-        bind_socket(server_fd)?;
-        listen_socket(server_fd)?;
-
-        // Accept connection in a separate thread
-        let server_thread = std::thread::spawn(move || {
-            let client_fd = accept_connection(server_fd).expect("Failed to accept connection");
-    
-            // Receive and print message
-            let received_msg = receive_message(client_fd).expect("Failed to receive message");
-            println!("Server: Received message: {}", received_msg);
-            
-            send_message(client_fd, MSG2).expect("Failed to send message");
-    
-            // Close client connection
-            unsafe { close(client_fd) };
-        });
-    
-        // Create and connect the client socket
-        let client_fd = create_seqpacket_socket()?;
-        unsafe {
-            let mut addr = sockaddr_un {
-                sun_family: AF_UNIX as u16,
-                sun_path: [0; 108],
-            };
-            let path_cstr = CString::new(SOCKET_PATH).unwrap();
-            let path_bytes = path_cstr.as_bytes();
-            // Convert u8 to i8
-            for (i, &byte) in path_bytes.iter().enumerate() {
-                addr.sun_path[i] = byte as i8;
-            }
-            if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
-                return Err(Error::last_os_error());
-            }
-        }
-        send_message(client_fd, MSG1)?;
-        let received_msg = receive_message(client_fd).expect("Failed to receive message");
-        println!("Client: Received message: {}", received_msg);
-        // get peer_name
-        unsafe {
-            let mut addrss = sockaddr_un {
-                sun_family: AF_UNIX as u16,
-                sun_path: [0; 108],
-            };
-            let mut len = mem::size_of_val(&addrss) as socklen_t;
-            let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
-            if res == -1 {
-                return Err(Error::last_os_error());
-            }
-            let sun_path = addrss.sun_path.clone();
-            let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::<Vec<u8>>().try_into().unwrap();
-            println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path));
+pub fn test_seq_socket() -> Result<(), Error> {
+    // Create and bind the server socket
+    fs::remove_file(&SOCKET_PATH).ok();
 
-        }
-            
-        server_thread.join().expect("Server thread panicked");
+    let server_fd = create_seqpacket_socket()?;
+    bind_socket(server_fd)?;
+    listen_socket(server_fd)?;
+
+    // Accept connection in a separate thread
+    let server_thread = std::thread::spawn(move || {
+        let client_fd = accept_connection(server_fd).expect("Failed to accept connection");
+
+        // Receive and print message
         let received_msg = receive_message(client_fd).expect("Failed to receive message");
-        println!("Client: Received message: {}", received_msg);
+        println!("Server: Received message: {}", received_msg);
+
+        send_message(client_fd, MSG2).expect("Failed to send message");
+
         // Close client connection
         unsafe { close(client_fd) };
-        fs::remove_file(&SOCKET_PATH).ok();
-        Ok(())
-}
+    });
+
+    // Create and connect the client socket
+    let client_fd = create_seqpacket_socket()?;
+    unsafe {
+        let mut addr = sockaddr_un {
+            sun_family: AF_UNIX as u16,
+            sun_path: [0; 108],
+        };
+        let path_cstr = CString::new(SOCKET_PATH).unwrap();
+        let path_bytes = path_cstr.as_bytes();
+        // Convert u8 to i8
+        for (i, &byte) in path_bytes.iter().enumerate() {
+            addr.sun_path[i] = byte as i8;
+        }
+        if connect(
+            client_fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
+            return Err(Error::last_os_error());
+        }
+    }
+    send_message(client_fd, MSG1)?;
+    let received_msg = receive_message(client_fd).expect("Failed to receive message");
+    println!("Client: Received message: {}", received_msg);
+    // get peer_name
+    unsafe {
+        let mut addrss = sockaddr_un {
+            sun_family: AF_UNIX as u16,
+            sun_path: [0; 108],
+        };
+        let mut len = mem::size_of_val(&addrss) as socklen_t;
+        let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
+        if res == -1 {
+            return Err(Error::last_os_error());
+        }
+        let sun_path = addrss.sun_path.clone();
+        let peer_path: [u8; 108] = sun_path
+            .iter()
+            .map(|&x| x as u8)
+            .collect::<Vec<u8>>()
+            .try_into()
+            .unwrap();
+        println!(
+            "Client: Connected to server at path: {}",
+            String::from_utf8_lossy(&peer_path)
+        );
+    }
+
+    server_thread.join().expect("Server thread panicked");
+    let received_msg = receive_message(client_fd).expect("Failed to receive message");
+    println!("Client: Received message: {}", received_msg);
+    // Close client connection
+    unsafe { close(client_fd) };
+    fs::remove_file(&SOCKET_PATH).ok();
+    Ok(())
+}

+ 46 - 18
user/apps/test_unix_stream_socket/src/main.rs

@@ -1,19 +1,19 @@
-use std::io::Error;
-use std::os::fd::RawFd;
-use std::fs;
 use libc::*;
 use std::ffi::CString;
+use std::fs;
+use std::io::Error;
 use std::mem;
+use std::os::fd::RawFd;
 
 const SOCKET_PATH: &str = "/test.stream";
 const MSG1: &str = "Hello, unix stream socket from Client!";
 const MSG2: &str = "Hello, unix stream socket from Server!";
 
-fn create_stream_socket() -> Result<RawFd, Error>{
+fn create_stream_socket() -> Result<RawFd, Error> {
     unsafe {
         let fd = socket(AF_UNIX, SOCK_STREAM, 0);
         if fd == -1 {
-            return Err(Error::last_os_error())
+            return Err(Error::last_os_error());
         }
         Ok(fd)
     }
@@ -31,7 +31,12 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
             addr.sun_path[i] = byte as i8;
         }
 
-        if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
+        if bind(
+            fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
             return Err(Error::last_os_error());
         }
     }
@@ -61,7 +66,13 @@ fn accept_conn(fd: RawFd) -> Result<RawFd, Error> {
 fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
     unsafe {
         let msg_bytes = msg.as_bytes();
-        if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0)== -1 {
+        if send(
+            fd,
+            msg_bytes.as_ptr() as *const libc::c_void,
+            msg_bytes.len(),
+            0,
+        ) == -1
+        {
             return Err(Error::last_os_error());
         }
     }
@@ -71,7 +82,12 @@ fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
 fn recv_message(fd: RawFd) -> Result<String, Error> {
     let mut buffer = [0; 1024];
     unsafe {
-        let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(),0);
+        let len = recv(
+            fd,
+            buffer.as_mut_ptr() as *mut libc::c_void,
+            buffer.len(),
+            0,
+        );
         if len == -1 {
             return Err(Error::last_os_error());
         }
@@ -82,7 +98,7 @@ fn recv_message(fd: RawFd) -> Result<String, Error> {
 fn test_stream() -> Result<(), Error> {
     fs::remove_file(&SOCKET_PATH).ok();
 
-    let server_fd =  create_stream_socket()?;
+    let server_fd = create_stream_socket()?;
     bind_socket(server_fd)?;
     listen_socket(server_fd)?;
 
@@ -95,7 +111,7 @@ fn test_stream() -> Result<(), Error> {
         send_message(client_fd, MSG2).expect("Failed to send message");
         println!("Server send finish");
 
-        unsafe {close(client_fd)};
+        unsafe { close(client_fd) };
     });
 
     let client_fd = create_stream_socket()?;
@@ -111,9 +127,14 @@ fn test_stream() -> Result<(), Error> {
             addr.sun_path[i] = byte as i8;
         }
 
-        if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
+        if connect(
+            client_fd,
+            &addr as *const _ as *const sockaddr,
+            mem::size_of_val(&addr) as socklen_t,
+        ) == -1
+        {
             return Err(Error::last_os_error());
-        } 
+        }
     }
 
     send_message(client_fd, MSG1)?;
@@ -129,9 +150,16 @@ fn test_stream() -> Result<(), Error> {
             return Err(Error::last_os_error());
         }
         let sun_path = addrss.sun_path.clone();
-        let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::<Vec<u8>>().try_into().unwrap();
-        println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path));
-
+        let peer_path: [u8; 108] = sun_path
+            .iter()
+            .map(|&x| x as u8)
+            .collect::<Vec<u8>>()
+            .try_into()
+            .unwrap();
+        println!(
+            "Client: Connected to server at path: {}",
+            String::from_utf8_lossy(&peer_path)
+        );
     }
 
     server_thread.join().expect("Server thread panicked");
@@ -139,7 +167,7 @@ fn test_stream() -> Result<(), Error> {
     let recv_msg = recv_message(client_fd).expect("Failed to receive message from server");
     println!("Client Received message: {}", recv_msg);
 
-    unsafe {close(client_fd)};
+    unsafe { close(client_fd) };
     fs::remove_file(&SOCKET_PATH).ok();
 
     Ok(())
@@ -148,6 +176,6 @@ fn test_stream() -> Result<(), Error> {
 fn main() {
     match test_stream() {
         Ok(_) => println!("test for unix stream success"),
-        Err(_) => println!("test for unix stream failed")
+        Err(_) => println!("test for unix stream failed"),
     }
-}
+}