|
@@ -1,5 +1,6 @@
|
|
|
#![allow(dead_code)]
|
|
|
use alloc::{boxed::Box, sync::Arc, vec::Vec};
|
|
|
+use hashbrown::HashMap;
|
|
|
use smoltcp::{
|
|
|
iface::{SocketHandle, SocketSet},
|
|
|
socket::{raw, tcp, udp},
|
|
@@ -25,6 +26,100 @@ lazy_static! {
|
|
|
/// TODO: 优化这里,自己实现SocketSet!!!现在这样的话,不管全局有多少个网卡,每个时间点都只会有1个进程能够访问socket
|
|
|
pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
|
|
|
pub static ref SOCKET_WAITQUEUE: WaitQueue = WaitQueue::INIT;
|
|
|
+ /// 端口管理器
|
|
|
+ pub static ref PORT_MANAGER: PortManager = PortManager::new();
|
|
|
+}
|
|
|
+
|
|
|
+/// @brief TCP 和 UDP 的端口管理器。
|
|
|
+/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。
|
|
|
+pub struct PortManager {
|
|
|
+ // TCP 端口记录表
|
|
|
+ tcp_port_table: SpinLock<HashMap<u16, Arc<GlobalSocketHandle>>>,
|
|
|
+ // UDP 端口记录表
|
|
|
+ udp_port_table: SpinLock<HashMap<u16, Arc<GlobalSocketHandle>>>,
|
|
|
+}
|
|
|
+
|
|
|
+impl PortManager {
|
|
|
+ pub fn new() -> Self {
|
|
|
+ return Self {
|
|
|
+ tcp_port_table: SpinLock::new(HashMap::new()),
|
|
|
+ udp_port_table: SpinLock::new(HashMap::new()),
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ /// @brief 自动分配一个相对应协议中未被使用的PORT,如果动态端口均已被占用,返回错误码 EADDRINUSE
|
|
|
+ pub fn get_ephemeral_port(&self, socket_type: SocketType) -> Result<u16, SystemError> {
|
|
|
+ // TODO selects non-conflict high port
|
|
|
+
|
|
|
+ static mut EPHEMERAL_PORT: u16 = 0;
|
|
|
+ unsafe {
|
|
|
+ if EPHEMERAL_PORT == 0 {
|
|
|
+ EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数
|
|
|
+ let mut port: u16;
|
|
|
+ while remaining > 0 {
|
|
|
+ unsafe {
|
|
|
+ if EPHEMERAL_PORT == 65535 {
|
|
|
+ EPHEMERAL_PORT = 49152;
|
|
|
+ } else {
|
|
|
+ EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
|
|
|
+ }
|
|
|
+ port = EPHEMERAL_PORT;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 使用 ListenTable 检查端口是否被占用
|
|
|
+ let listen_table_guard = match socket_type {
|
|
|
+ SocketType::UdpSocket => self.udp_port_table.lock(),
|
|
|
+ SocketType::TcpSocket => self.tcp_port_table.lock(),
|
|
|
+ SocketType::RawSocket => todo!(),
|
|
|
+ };
|
|
|
+ if let None = listen_table_guard.get(&port) {
|
|
|
+ drop(listen_table_guard);
|
|
|
+ return Ok(port);
|
|
|
+ }
|
|
|
+ remaining -= 1;
|
|
|
+ }
|
|
|
+ return Err(SystemError::EADDRINUSE);
|
|
|
+ }
|
|
|
+
|
|
|
+ /// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
|
|
|
+ ///
|
|
|
+ /// TODO: 增加支持端口复用的逻辑
|
|
|
+ pub fn get_port(
|
|
|
+ &self,
|
|
|
+ socket_type: SocketType,
|
|
|
+ port: u16,
|
|
|
+ handle: Arc<GlobalSocketHandle>,
|
|
|
+ ) -> Result<(), SystemError> {
|
|
|
+ if port > 0 {
|
|
|
+ let mut listen_table_guard = match socket_type {
|
|
|
+ SocketType::UdpSocket => self.udp_port_table.lock(),
|
|
|
+ SocketType::TcpSocket => self.tcp_port_table.lock(),
|
|
|
+ SocketType::RawSocket => panic!("RawSocket cann't bind a port"),
|
|
|
+ };
|
|
|
+ match listen_table_guard.get(&port) {
|
|
|
+ Some(_) => return Err(SystemError::EADDRINUSE),
|
|
|
+ None => listen_table_guard.insert(port, handle),
|
|
|
+ };
|
|
|
+ drop(listen_table_guard);
|
|
|
+ }
|
|
|
+ return Ok(());
|
|
|
+ }
|
|
|
+
|
|
|
+ /// @brief 在对应的端口记录表中将端口和 socket 解绑
|
|
|
+ pub fn unbind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> {
|
|
|
+ let mut listen_table_guard = match socket_type {
|
|
|
+ SocketType::UdpSocket => self.udp_port_table.lock(),
|
|
|
+ SocketType::TcpSocket => self.tcp_port_table.lock(),
|
|
|
+ SocketType::RawSocket => return Ok(()),
|
|
|
+ };
|
|
|
+ listen_table_guard.remove(&port);
|
|
|
+ drop(listen_table_guard);
|
|
|
+ return Ok(());
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/* For setsockopt(2) */
|
|
@@ -38,8 +133,8 @@ pub const SOL_SOCKET: u8 = 1;
|
|
|
pub struct GlobalSocketHandle(SocketHandle);
|
|
|
|
|
|
impl GlobalSocketHandle {
|
|
|
- pub fn new(handle: SocketHandle) -> Self {
|
|
|
- Self(handle)
|
|
|
+ pub fn new(handle: SocketHandle) -> Arc<Self> {
|
|
|
+ return Arc::new(Self(handle));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -59,7 +154,7 @@ impl Drop for GlobalSocketHandle {
|
|
|
}
|
|
|
|
|
|
/// @brief socket的类型
|
|
|
-#[derive(Debug)]
|
|
|
+#[derive(Debug, Clone, Copy)]
|
|
|
pub enum SocketType {
|
|
|
/// 原始的socket
|
|
|
RawSocket,
|
|
@@ -86,7 +181,7 @@ bitflags! {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-#[derive(Debug)]
|
|
|
+#[derive(Debug, Clone)]
|
|
|
/// @brief 在trait Socket的metadata函数中返回该结构体供外部使用
|
|
|
pub struct SocketMetadata {
|
|
|
/// socket的类型
|
|
@@ -101,18 +196,36 @@ pub struct SocketMetadata {
|
|
|
pub options: SocketOptions,
|
|
|
}
|
|
|
|
|
|
+impl SocketMetadata {
|
|
|
+ fn new(
|
|
|
+ socket_type: SocketType,
|
|
|
+ send_buf_size: usize,
|
|
|
+ recv_buf_size: usize,
|
|
|
+ metadata_buf_size: usize,
|
|
|
+ options: SocketOptions,
|
|
|
+ ) -> Self {
|
|
|
+ Self {
|
|
|
+ socket_type,
|
|
|
+ send_buf_size,
|
|
|
+ recv_buf_size,
|
|
|
+ metadata_buf_size,
|
|
|
+ options,
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
/// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
|
|
|
///
|
|
|
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
|
|
|
#[derive(Debug, Clone)]
|
|
|
pub struct RawSocket {
|
|
|
- handle: GlobalSocketHandle,
|
|
|
+ handle: Arc<GlobalSocketHandle>,
|
|
|
/// 用户发送的数据包是否包含了IP头.
|
|
|
/// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
|
|
|
/// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
|
|
|
header_included: bool,
|
|
|
- /// socket的选项
|
|
|
- options: SocketOptions,
|
|
|
+ /// socket的metadata
|
|
|
+ metadata: SocketMetadata,
|
|
|
}
|
|
|
|
|
|
impl RawSocket {
|
|
@@ -147,12 +260,21 @@ impl RawSocket {
|
|
|
);
|
|
|
|
|
|
// 把socket添加到socket集合中,并得到socket的句柄
|
|
|
- let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
|
|
|
+ let handle: Arc<GlobalSocketHandle> =
|
|
|
+ GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
|
|
|
+
|
|
|
+ let metadata = SocketMetadata::new(
|
|
|
+ SocketType::RawSocket,
|
|
|
+ Self::DEFAULT_RX_BUF_SIZE,
|
|
|
+ Self::DEFAULT_TX_BUF_SIZE,
|
|
|
+ Self::DEFAULT_METADATA_BUF_SIZE,
|
|
|
+ options,
|
|
|
+ );
|
|
|
|
|
|
return Self {
|
|
|
handle,
|
|
|
header_included: false,
|
|
|
- options,
|
|
|
+ metadata,
|
|
|
};
|
|
|
}
|
|
|
}
|
|
@@ -177,7 +299,7 @@ impl Socket for RawSocket {
|
|
|
);
|
|
|
}
|
|
|
Err(smoltcp::socket::raw::RecvError::Exhausted) => {
|
|
|
- if !self.options.contains(SocketOptions::BLOCK) {
|
|
|
+ if !self.metadata.options.contains(SocketOptions::BLOCK) {
|
|
|
// 如果是非阻塞的socket,就返回错误
|
|
|
return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
|
|
|
}
|
|
@@ -271,7 +393,7 @@ impl Socket for RawSocket {
|
|
|
}
|
|
|
|
|
|
fn metadata(&self) -> Result<SocketMetadata, SystemError> {
|
|
|
- todo!()
|
|
|
+ Ok(self.metadata.clone())
|
|
|
}
|
|
|
|
|
|
fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
|
|
@@ -284,9 +406,9 @@ impl Socket for RawSocket {
|
|
|
/// https://man7.org/linux/man-pages/man7/udp.7.html
|
|
|
#[derive(Debug, Clone)]
|
|
|
pub struct UdpSocket {
|
|
|
- pub handle: GlobalSocketHandle,
|
|
|
+ pub handle: Arc<GlobalSocketHandle>,
|
|
|
remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
|
|
|
- options: SocketOptions,
|
|
|
+ metadata: SocketMetadata,
|
|
|
}
|
|
|
|
|
|
impl UdpSocket {
|
|
@@ -315,17 +437,29 @@ impl UdpSocket {
|
|
|
let socket = udp::Socket::new(tx_buffer, rx_buffer);
|
|
|
|
|
|
// 把socket添加到socket集合中,并得到socket的句柄
|
|
|
- let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
|
|
|
+ let handle: Arc<GlobalSocketHandle> =
|
|
|
+ GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
|
|
|
+
|
|
|
+ let metadata = SocketMetadata::new(
|
|
|
+ SocketType::UdpSocket,
|
|
|
+ Self::DEFAULT_RX_BUF_SIZE,
|
|
|
+ Self::DEFAULT_TX_BUF_SIZE,
|
|
|
+ Self::DEFAULT_METADATA_BUF_SIZE,
|
|
|
+ options,
|
|
|
+ );
|
|
|
|
|
|
return Self {
|
|
|
handle,
|
|
|
remote_endpoint: None,
|
|
|
- options,
|
|
|
+ metadata,
|
|
|
};
|
|
|
}
|
|
|
|
|
|
fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
|
|
|
if let Endpoint::Ip(Some(ip)) = endpoint {
|
|
|
+ // 检测端口是否已被占用
|
|
|
+ PORT_MANAGER.get_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
|
|
|
+
|
|
|
let bind_res = if ip.addr.is_unspecified() {
|
|
|
socket.bind(ip.port)
|
|
|
} else {
|
|
@@ -388,7 +522,7 @@ impl Socket for UdpSocket {
|
|
|
// kdebug!("is open()={}", socket.is_open());
|
|
|
// kdebug!("socket endpoint={:?}", socket.endpoint());
|
|
|
if socket.endpoint().port == 0 {
|
|
|
- let temp_port = get_ephemeral_port();
|
|
|
+ let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
|
|
|
|
|
|
let local_ep = match remote_endpoint.addr {
|
|
|
// 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧
|
|
@@ -461,7 +595,7 @@ impl Socket for UdpSocket {
|
|
|
todo!()
|
|
|
}
|
|
|
fn metadata(&self) -> Result<SocketMetadata, SystemError> {
|
|
|
- todo!()
|
|
|
+ Ok(self.metadata.clone())
|
|
|
}
|
|
|
|
|
|
fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
|
|
@@ -499,10 +633,10 @@ impl Socket for UdpSocket {
|
|
|
/// https://man7.org/linux/man-pages/man7/tcp.7.html
|
|
|
#[derive(Debug, Clone)]
|
|
|
pub struct TcpSocket {
|
|
|
- handle: GlobalSocketHandle,
|
|
|
+ handle: Arc<GlobalSocketHandle>,
|
|
|
local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
|
|
|
is_listening: bool,
|
|
|
- options: SocketOptions,
|
|
|
+ metadata: SocketMetadata,
|
|
|
}
|
|
|
|
|
|
impl TcpSocket {
|
|
@@ -525,13 +659,22 @@ impl TcpSocket {
|
|
|
let socket = tcp::Socket::new(tx_buffer, rx_buffer);
|
|
|
|
|
|
// 把socket添加到socket集合中,并得到socket的句柄
|
|
|
- let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
|
|
|
+ let handle: Arc<GlobalSocketHandle> =
|
|
|
+ GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
|
|
|
+
|
|
|
+ let metadata = SocketMetadata::new(
|
|
|
+ SocketType::TcpSocket,
|
|
|
+ Self::DEFAULT_RX_BUF_SIZE,
|
|
|
+ Self::DEFAULT_TX_BUF_SIZE,
|
|
|
+ Self::DEFAULT_METADATA_BUF_SIZE,
|
|
|
+ options,
|
|
|
+ );
|
|
|
|
|
|
return Self {
|
|
|
handle,
|
|
|
local_endpoint: None,
|
|
|
is_listening: false,
|
|
|
- options,
|
|
|
+ metadata,
|
|
|
};
|
|
|
}
|
|
|
fn do_listen(
|
|
@@ -546,7 +689,7 @@ impl TcpSocket {
|
|
|
// kdebug!("Tcp Socket Listen on {local_endpoint}");
|
|
|
socket.listen(local_endpoint)
|
|
|
};
|
|
|
- // todo: 增加端口占用检查
|
|
|
+ // TODO: 增加端口占用检查
|
|
|
return match listen_result {
|
|
|
Ok(()) => {
|
|
|
// kdebug!(
|
|
@@ -668,7 +811,7 @@ impl Socket for TcpSocket {
|
|
|
let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
|
|
|
|
|
|
if let Endpoint::Ip(Some(ip)) = endpoint {
|
|
|
- let temp_port = get_ephemeral_port();
|
|
|
+ let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
|
|
|
// kdebug!("temp_port: {}", temp_port);
|
|
|
let iface: Arc<dyn NetDriver> = NET_DRIVERS.write().get(&0).unwrap().clone();
|
|
|
let mut inner_iface = iface.inner_iface().lock();
|
|
@@ -737,9 +880,12 @@ impl Socket for TcpSocket {
|
|
|
fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
|
|
|
if let Endpoint::Ip(Some(mut ip)) = endpoint {
|
|
|
if ip.port == 0 {
|
|
|
- ip.port = get_ephemeral_port();
|
|
|
+ ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
|
|
|
}
|
|
|
|
|
|
+ // 检测端口是否已被占用
|
|
|
+ PORT_MANAGER.get_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
|
|
|
+
|
|
|
self.local_endpoint = Some(ip);
|
|
|
self.is_listening = false;
|
|
|
return Ok(());
|
|
@@ -785,11 +931,19 @@ impl Socket for TcpSocket {
|
|
|
let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
|
|
|
let old_handle = ::core::mem::replace(&mut self.handle, new_handle);
|
|
|
|
|
|
+ let metadata = SocketMetadata {
|
|
|
+ socket_type: SocketType::TcpSocket,
|
|
|
+ send_buf_size: Self::DEFAULT_RX_BUF_SIZE,
|
|
|
+ recv_buf_size: Self::DEFAULT_TX_BUF_SIZE,
|
|
|
+ metadata_buf_size: Self::DEFAULT_METADATA_BUF_SIZE,
|
|
|
+ options: self.metadata.options,
|
|
|
+ };
|
|
|
+
|
|
|
Box::new(TcpSocket {
|
|
|
handle: old_handle,
|
|
|
local_endpoint: self.local_endpoint,
|
|
|
is_listening: false,
|
|
|
- options: self.options,
|
|
|
+ metadata,
|
|
|
})
|
|
|
};
|
|
|
// kdebug!("tcp accept: new socket: {:?}", new_socket);
|
|
@@ -825,7 +979,7 @@ impl Socket for TcpSocket {
|
|
|
}
|
|
|
|
|
|
fn metadata(&self) -> Result<SocketMetadata, SystemError> {
|
|
|
- todo!()
|
|
|
+ Ok(self.metadata.clone())
|
|
|
}
|
|
|
|
|
|
fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
|
|
@@ -833,26 +987,6 @@ impl Socket for TcpSocket {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-/// @breif 自动分配一个未被使用的PORT
|
|
|
-///
|
|
|
-/// TODO: 增加ListenTable, 用于检查端口是否被占用
|
|
|
-pub fn get_ephemeral_port() -> u16 {
|
|
|
- // TODO selects non-conflict high port
|
|
|
-
|
|
|
- static mut EPHEMERAL_PORT: u16 = 0;
|
|
|
- unsafe {
|
|
|
- if EPHEMERAL_PORT == 0 {
|
|
|
- EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
|
|
|
- }
|
|
|
- if EPHEMERAL_PORT == 65535 {
|
|
|
- EPHEMERAL_PORT = 49152;
|
|
|
- } else {
|
|
|
- EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
|
|
|
- }
|
|
|
- EPHEMERAL_PORT
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
/// @brief 地址族的枚举
|
|
|
///
|
|
|
/// 参考:https://opengrok.ringotek.cn/xref/linux-5.19.10/include/linux/socket.h#180
|
|
@@ -1012,6 +1146,10 @@ impl IndexNode for SocketInode {
|
|
|
&self,
|
|
|
_data: &mut crate::filesystem::vfs::FilePrivateData,
|
|
|
) -> Result<(), SystemError> {
|
|
|
+ let socket = self.0.lock();
|
|
|
+ if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() {
|
|
|
+ PORT_MANAGER.unbind_port(socket.metadata().unwrap().socket_type, ip.port)?;
|
|
|
+ }
|
|
|
return Ok(());
|
|
|
}
|
|
|
|