Browse Source

feat(socket): 添加shutdown方法并实现ShutdownTemp的TryFrom转换

xiaolin2004 3 months ago
parent
commit
bbea79ec19

+ 18 - 3
kernel/src/net/socket/common/shutdown.rs

@@ -1,4 +1,6 @@
-use core::sync::atomic::AtomicU8;
+use core::{default, sync::atomic::AtomicU8};
+
+use system_error::SystemError;
 
 bitflags! {
     /// @brief 用于指定socket的关闭类型
@@ -101,8 +103,8 @@ impl ShutdownTemp {
         self.bit == 0
     }
 
-    pub fn from_how(how: usize) -> Self {
-        Self { bit: how as u8 + 1 }
+    pub fn bits(&self) -> ShutdownBit {
+        ShutdownBit { bits: self.bit }
     }
 }
 
@@ -116,3 +118,16 @@ impl From<ShutdownBit> for ShutdownTemp {
         }
     }
 }
+
+impl TryFrom<usize> for ShutdownTemp {
+    type Error = SystemError;
+
+    fn try_from(value: usize) -> Result<Self, Self::Error> {
+        match value {
+            0 | 1 | 2 => Ok(ShutdownTemp {
+                bit: value as u8 + 1,
+            }),
+            _ => Err(SystemError::EINVAL),
+        }
+    }
+}

+ 22 - 0
kernel/src/net/socket/inet/stream/mod.rs

@@ -338,6 +338,28 @@ impl Socket for TcpSocket {
             .recv_buffer_size()
     }
 
+
+    fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
+        let self_shutdown = self.shutdown.get().bits();
+        let diff = how.bits().difference(self_shutdown);
+        match diff.is_empty(){
+            true => {
+                return Ok(())
+            },
+            false => {
+                if diff.contains(ShutdownBit::SHUT_RD){
+                    self.shutdown.recv_shutdown();
+                    // TODO 协议栈处理
+                }
+                if diff.contains(ShutdownBit::SHUT_WR){
+                    self.shutdown.send_shutdown();
+                    // TODO 协议栈处理
+                }
+            },
+        }
+        Ok(()) 
+    }
+
     fn close(&self) -> Result<(), SystemError> {
         let inner = self.inner
             .write()

+ 1 - 1
kernel/src/net/syscall.rs

@@ -367,7 +367,7 @@ impl Syscall {
         let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
             .get_socket(fd as i32)
             .ok_or(SystemError::EBADF)?;
-        socket.shutdown(socket::ShutdownTemp::from_how(how))?;
+        socket.shutdown(how.try_into()?)?;
         return Ok(0);
     }