socket.rs 42 KB


  1. #![allow(dead_code)]
  2. use alloc::{boxed::Box, sync::Arc, vec::Vec};
  3. use hashbrown::HashMap;
  4. use smoltcp::{
  5. iface::{SocketHandle, SocketSet},
  6. socket::{raw, tcp, udp},
  7. wire,
  8. };
  9. use crate::{
  10. arch::rand::rand,
  11. driver::net::NetDriver,
  12. filesystem::vfs::{syscall::ModeType, FileType, IndexNode, Metadata, PollStatus},
  13. kerror, kwarn,
  14. libs::{
  15. spinlock::{SpinLock, SpinLockGuard},
  16. wait_queue::WaitQueue,
  17. },
  18. syscall::SystemError,
  19. };
  20. use super::{net_core::poll_ifaces, Endpoint, Protocol, Socket, NET_DRIVERS};
lazy_static! {
    /// The set of all sockets.
    /// TODO: optimize this by implementing our own SocketSet! As it stands, no matter how many
    /// network cards exist globally, only one process can access the sockets at any point in time.
    pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
    /// Wait queue that the blocking socket operations in this file sleep on.
    pub static ref SOCKET_WAITQUEUE: WaitQueue = WaitQueue::INIT;
    /// The port manager.
    pub static ref PORT_MANAGER: PortManager = PortManager::new();
}
/// @brief Port manager for TCP and UDP.
/// When a TCP/UDP socket binds a port, the binding is recorded in the matching table so that
/// port conflicts can be detected.
pub struct PortManager {
    // TCP port record table: bound port -> handle of the socket that owns it
    tcp_port_table: SpinLock<HashMap<u16, Arc<GlobalSocketHandle>>>,
    // UDP port record table: bound port -> handle of the socket that owns it
    udp_port_table: SpinLock<HashMap<u16, Arc<GlobalSocketHandle>>>,
}
  37. impl PortManager {
  38. pub fn new() -> Self {
  39. return Self {
  40. tcp_port_table: SpinLock::new(HashMap::new()),
  41. udp_port_table: SpinLock::new(HashMap::new()),
  42. };
  43. }
  44. /// @brief 自动分配一个相对应协议中未被使用的PORT,如果动态端口均已被占用,返回错误码 EADDRINUSE
  45. pub fn get_ephemeral_port(&self, socket_type: SocketType) -> Result<u16, SystemError> {
  46. // TODO: selects non-conflict high port
  47. static mut EPHEMERAL_PORT: u16 = 0;
  48. unsafe {
  49. if EPHEMERAL_PORT == 0 {
  50. EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
  51. }
  52. }
  53. let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数
  54. let mut port: u16;
  55. while remaining > 0 {
  56. unsafe {
  57. if EPHEMERAL_PORT == 65535 {
  58. EPHEMERAL_PORT = 49152;
  59. } else {
  60. EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
  61. }
  62. port = EPHEMERAL_PORT;
  63. }
  64. // 使用 ListenTable 检查端口是否被占用
  65. let listen_table_guard = match socket_type {
  66. SocketType::UdpSocket => self.udp_port_table.lock(),
  67. SocketType::TcpSocket => self.tcp_port_table.lock(),
  68. SocketType::RawSocket => panic!("RawSocket cann't get a port"),
  69. };
  70. if let None = listen_table_guard.get(&port) {
  71. drop(listen_table_guard);
  72. return Ok(port);
  73. }
  74. remaining -= 1;
  75. }
  76. return Err(SystemError::EADDRINUSE);
  77. }
  78. /// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
  79. ///
  80. /// TODO: 增加支持端口复用的逻辑
  81. pub fn bind_port(
  82. &self,
  83. socket_type: SocketType,
  84. port: u16,
  85. handle: Arc<GlobalSocketHandle>,
  86. ) -> Result<(), SystemError> {
  87. if port > 0 {
  88. let mut listen_table_guard = match socket_type {
  89. SocketType::UdpSocket => self.udp_port_table.lock(),
  90. SocketType::TcpSocket => self.tcp_port_table.lock(),
  91. SocketType::RawSocket => panic!("RawSocket cann't bind a port"),
  92. };
  93. match listen_table_guard.get(&port) {
  94. Some(_) => return Err(SystemError::EADDRINUSE),
  95. None => listen_table_guard.insert(port, handle),
  96. };
  97. drop(listen_table_guard);
  98. }
  99. return Ok(());
  100. }
  101. /// @brief 在对应的端口记录表中将端口和 socket 解绑
  102. pub fn unbind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> {
  103. let mut listen_table_guard = match socket_type {
  104. SocketType::UdpSocket => self.udp_port_table.lock(),
  105. SocketType::TcpSocket => self.tcp_port_table.lock(),
  106. SocketType::RawSocket => return Ok(()),
  107. };
  108. listen_table_guard.remove(&port);
  109. drop(listen_table_guard);
  110. return Ok(());
  111. }
  112. }
/* For setsockopt(2) */
// See: linux-5.19.10/include/uapi/asm-generic/socket.h#9
/// Value of `SOL_SOCKET`, taken from the Linux header referenced above.
pub const SOL_SOCKET: u8 = 1;
/// @brief Handle-management component for sockets.
/// It wraps smoltcp's `SocketHandle` in one more layer to add functionality, for example
/// automatically releasing the socket's resources and notifying other system components
/// when the socket is closed.
#[derive(Debug)]
pub struct GlobalSocketHandle(SocketHandle);
  121. impl GlobalSocketHandle {
  122. pub fn new(handle: SocketHandle) -> Arc<Self> {
  123. return Arc::new(Self(handle));
  124. }
  125. }
  126. impl Clone for GlobalSocketHandle {
  127. fn clone(&self) -> Self {
  128. Self(self.0)
  129. }
  130. }
  131. impl Drop for GlobalSocketHandle {
  132. fn drop(&mut self) {
  133. let mut socket_set_guard = SOCKET_SET.lock();
  134. socket_set_guard.remove(self.0); // 删除的时候,会发送一条FINISH的信息?
  135. drop(socket_set_guard);
  136. poll_ifaces();
  137. }
  138. }
/// @brief The kind of a socket
#[derive(Debug, Clone, Copy)]
pub enum SocketType {
    /// Raw socket
    RawSocket,
    /// Socket used for TCP communication
    TcpSocket,
    /// Socket used for UDP communication
    UdpSocket,
}
bitflags! {
    /// @brief Socket options
    #[derive(Default)]
    pub struct SocketOptions: u32 {
        /// Whether the socket is blocking
        const BLOCK = 1 << 0;
        /// Whether broadcast is allowed
        const BROADCAST = 1 << 1;
        /// Whether multicast is allowed
        const MULTICAST = 1 << 2;
        /// Whether address reuse is allowed
        const REUSEADDR = 1 << 3;
        /// Whether port reuse is allowed
        const REUSEPORT = 1 << 4;
    }
}
#[derive(Debug, Clone)]
/// @brief Structure returned by the `metadata` function of the `Socket` trait for external use
pub struct SocketMetadata {
    /// Kind of the socket
    pub socket_type: SocketType,
    /// Size of the send buffer
    pub send_buf_size: usize,
    /// Size of the receive buffer
    pub recv_buf_size: usize,
    /// Size of the metadata buffer
    pub metadata_buf_size: usize,
    /// Options of the socket
    pub options: SocketOptions,
}
  179. impl SocketMetadata {
  180. fn new(
  181. socket_type: SocketType,
  182. send_buf_size: usize,
  183. recv_buf_size: usize,
  184. metadata_buf_size: usize,
  185. options: SocketOptions,
  186. ) -> Self {
  187. Self {
  188. socket_type,
  189. send_buf_size,
  190. recv_buf_size,
  191. metadata_buf_size,
  192. options,
  193. }
  194. }
  195. }
/// @brief A raw socket. Raw sockets bypass the transport-layer protocols (such as TCP or UDP)
/// and provide direct access to the network-layer protocol (such as IP).
///
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[derive(Debug, Clone)]
pub struct RawSocket {
    handle: Arc<GlobalSocketHandle>,
    /// Whether packets sent by the user already include the IP header.
    /// If true, packets sent by the user must contain the IP header (the user fills in IP header + data).
    /// If false, packets sent by the user do not contain the IP header (the user only supplies the data).
    header_included: bool,
    /// Metadata of the socket
    metadata: SocketMetadata,
}
  209. impl RawSocket {
  210. /// 元数据的缓冲区的大小
  211. pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
  212. /// 默认的发送缓冲区的大小 transmiss
  213. pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
  214. /// 默认的接收缓冲区的大小 receive
  215. pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
  216. /// @brief 创建一个原始的socket
  217. ///
  218. /// @param protocol 协议号
  219. /// @param options socket的选项
  220. ///
  221. /// @return 返回创建的原始的socket
  222. pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
  223. let tx_buffer = raw::PacketBuffer::new(
  224. vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
  225. vec![0; Self::DEFAULT_TX_BUF_SIZE],
  226. );
  227. let rx_buffer = raw::PacketBuffer::new(
  228. vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
  229. vec![0; Self::DEFAULT_RX_BUF_SIZE],
  230. );
  231. let protocol: u8 = protocol.into();
  232. let socket = raw::Socket::new(
  233. smoltcp::wire::IpVersion::Ipv4,
  234. wire::IpProtocol::from(protocol),
  235. tx_buffer,
  236. rx_buffer,
  237. );
  238. // 把socket添加到socket集合中,并得到socket的句柄
  239. let handle: Arc<GlobalSocketHandle> =
  240. GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
  241. let metadata = SocketMetadata::new(
  242. SocketType::RawSocket,
  243. Self::DEFAULT_RX_BUF_SIZE,
  244. Self::DEFAULT_TX_BUF_SIZE,
  245. Self::DEFAULT_METADATA_BUF_SIZE,
  246. options,
  247. );
  248. return Self {
  249. handle,
  250. header_included: false,
  251. metadata,
  252. };
  253. }
  254. }
  255. impl Socket for RawSocket {
  256. fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
  257. poll_ifaces();
  258. loop {
  259. // 如何优化这里?
  260. let mut socket_set_guard = SOCKET_SET.lock();
  261. let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
  262. match socket.recv_slice(buf) {
  263. Ok(len) => {
  264. let packet = wire::Ipv4Packet::new_unchecked(buf);
  265. return (
  266. Ok(len),
  267. Endpoint::Ip(Some(smoltcp::wire::IpEndpoint {
  268. addr: wire::IpAddress::Ipv4(packet.src_addr()),
  269. port: 0,
  270. })),
  271. );
  272. }
  273. Err(smoltcp::socket::raw::RecvError::Exhausted) => {
  274. if !self.metadata.options.contains(SocketOptions::BLOCK) {
  275. // 如果是非阻塞的socket,就返回错误
  276. return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
  277. }
  278. }
  279. }
  280. drop(socket);
  281. drop(socket_set_guard);
  282. SOCKET_WAITQUEUE.sleep();
  283. }
  284. }
  285. fn write(&self, buf: &[u8], to: Option<super::Endpoint>) -> Result<usize, SystemError> {
  286. // 如果用户发送的数据包,包含IP头,则直接发送
  287. if self.header_included {
  288. let mut socket_set_guard = SOCKET_SET.lock();
  289. let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
  290. match socket.send_slice(buf) {
  291. Ok(_len) => {
  292. return Ok(buf.len());
  293. }
  294. Err(smoltcp::socket::raw::SendError::BufferFull) => {
  295. return Err(SystemError::ENOBUFS);
  296. }
  297. }
  298. } else {
  299. // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
  300. if let Some(Endpoint::Ip(Some(endpoint))) = to {
  301. let mut socket_set_guard = SOCKET_SET.lock();
  302. let socket: &mut raw::Socket =
  303. socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
  304. // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
  305. let iface = NET_DRIVERS.read().get(&0).unwrap().clone();
  306. // 构造IP头
  307. let ipv4_src_addr: Option<smoltcp::wire::Ipv4Address> =
  308. iface.inner_iface().lock().ipv4_addr();
  309. if ipv4_src_addr.is_none() {
  310. return Err(SystemError::ENETUNREACH);
  311. }
  312. let ipv4_src_addr = ipv4_src_addr.unwrap();
  313. if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
  314. let len = buf.len();
  315. // 创建20字节的IPv4头部
  316. let mut buffer: Vec<u8> = vec![0u8; len + 20];
  317. let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
  318. wire::Ipv4Packet::new_unchecked(&mut buffer);
  319. // 封装ipv4 header
  320. packet.set_version(4);
  321. packet.set_header_len(20);
  322. packet.set_total_len((20 + len) as u16);
  323. packet.set_src_addr(ipv4_src_addr);
  324. packet.set_dst_addr(ipv4_dst);
  325. // 设置ipv4 header的protocol字段
  326. packet.set_next_header(socket.ip_protocol().into());
  327. // 获取IP数据包的负载字段
  328. let payload: &mut [u8] = packet.payload_mut();
  329. payload.copy_from_slice(buf);
  330. // 填充checksum字段
  331. packet.fill_checksum();
  332. // 发送数据包
  333. socket.send_slice(&buffer).unwrap();
  334. drop(socket);
  335. iface.poll(&mut socket_set_guard).ok();
  336. drop(socket_set_guard);
  337. return Ok(len);
  338. } else {
  339. kwarn!("Unsupport Ip protocol type!");
  340. return Err(SystemError::EINVAL);
  341. }
  342. } else {
  343. // 如果没有指定目的地址,则返回错误
  344. return Err(SystemError::ENOTCONN);
  345. }
  346. }
  347. }
  348. fn connect(&mut self, _endpoint: super::Endpoint) -> Result<(), SystemError> {
  349. return Ok(());
  350. }
  351. fn metadata(&self) -> Result<SocketMetadata, SystemError> {
  352. Ok(self.metadata.clone())
  353. }
  354. fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
  355. return Box::new(self.clone());
  356. }
  357. }
/// @brief A UDP socket
///
/// https://man7.org/linux/man-pages/man7/udp.7.html
#[derive(Debug, Clone)]
pub struct UdpSocket {
    pub handle: Arc<GlobalSocketHandle>,
    remote_endpoint: Option<Endpoint>, // Remote endpoint recorded by connect(); should be an IP address.
    metadata: SocketMetadata,
}
  367. impl UdpSocket {
  368. /// 元数据的缓冲区的大小
  369. pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
  370. /// 默认的发送缓冲区的大小 transmiss
  371. pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
  372. /// 默认的接收缓冲区的大小 receive
  373. pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
  374. /// @brief 创建一个原始的socket
  375. ///
  376. /// @param protocol 协议号
  377. /// @param options socket的选项
  378. ///
  379. /// @return 返回创建的原始的socket
  380. pub fn new(options: SocketOptions) -> Self {
  381. let tx_buffer = udp::PacketBuffer::new(
  382. vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
  383. vec![0; Self::DEFAULT_TX_BUF_SIZE],
  384. );
  385. let rx_buffer = udp::PacketBuffer::new(
  386. vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
  387. vec![0; Self::DEFAULT_RX_BUF_SIZE],
  388. );
  389. let socket = udp::Socket::new(tx_buffer, rx_buffer);
  390. // 把socket添加到socket集合中,并得到socket的句柄
  391. let handle: Arc<GlobalSocketHandle> =
  392. GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
  393. let metadata = SocketMetadata::new(
  394. SocketType::UdpSocket,
  395. Self::DEFAULT_RX_BUF_SIZE,
  396. Self::DEFAULT_TX_BUF_SIZE,
  397. Self::DEFAULT_METADATA_BUF_SIZE,
  398. options,
  399. );
  400. return Self {
  401. handle,
  402. remote_endpoint: None,
  403. metadata,
  404. };
  405. }
  406. fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
  407. if let Endpoint::Ip(Some(ip)) = endpoint {
  408. // 检测端口是否已被占用
  409. PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
  410. let bind_res = if ip.addr.is_unspecified() {
  411. socket.bind(ip.port)
  412. } else {
  413. socket.bind(ip)
  414. };
  415. match bind_res {
  416. Ok(()) => return Ok(()),
  417. Err(_) => return Err(SystemError::EINVAL),
  418. }
  419. } else {
  420. return Err(SystemError::EINVAL);
  421. };
  422. }
  423. }
  424. impl Socket for UdpSocket {
  425. /// @brief 在read函数执行之前,请先bind到本地的指定端口
  426. fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
  427. loop {
  428. // kdebug!("Wait22 to Read");
  429. poll_ifaces();
  430. let mut socket_set_guard = SOCKET_SET.lock();
  431. let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
  432. // kdebug!("Wait to Read");
  433. if socket.can_recv() {
  434. if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) {
  435. drop(socket);
  436. drop(socket_set_guard);
  437. poll_ifaces();
  438. return (Ok(size), Endpoint::Ip(Some(remote_endpoint)));
  439. }
  440. } else {
  441. // 如果socket没有连接,则忙等
  442. // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
  443. }
  444. drop(socket);
  445. drop(socket_set_guard);
  446. SOCKET_WAITQUEUE.sleep();
  447. }
  448. }
  449. fn write(&self, buf: &[u8], to: Option<super::Endpoint>) -> Result<usize, SystemError> {
  450. // kdebug!("udp to send: {:?}, len={}", to, buf.len());
  451. let remote_endpoint: &wire::IpEndpoint = {
  452. if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
  453. endpoint
  454. } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
  455. endpoint
  456. } else {
  457. return Err(SystemError::ENOTCONN);
  458. }
  459. };
  460. // kdebug!("udp write: remote = {:?}", remote_endpoint);
  461. let mut socket_set_guard = SOCKET_SET.lock();
  462. let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
  463. // kdebug!("is open()={}", socket.is_open());
  464. // kdebug!("socket endpoint={:?}", socket.endpoint());
  465. if socket.endpoint().port == 0 {
  466. let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
  467. let local_ep = match remote_endpoint.addr {
  468. // 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧
  469. // 否则就用 self.endpoint().addr.unwrap()
  470. wire::IpAddress::Ipv4(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
  471. smoltcp::wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED),
  472. temp_port,
  473. ))),
  474. wire::IpAddress::Ipv6(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
  475. smoltcp::wire::IpAddress::Ipv6(wire::Ipv6Address::UNSPECIFIED),
  476. temp_port,
  477. ))),
  478. };
  479. // kdebug!("udp write: local_ep = {:?}", local_ep);
  480. self.do_bind(socket, local_ep)?;
  481. }
  482. // kdebug!("is open()={}", socket.is_open());
  483. if socket.can_send() {
  484. // kdebug!("udp write: can send");
  485. match socket.send_slice(&buf, *remote_endpoint) {
  486. Ok(()) => {
  487. // kdebug!("udp write: send ok");
  488. drop(socket);
  489. drop(socket_set_guard);
  490. poll_ifaces();
  491. return Ok(buf.len());
  492. }
  493. Err(_) => {
  494. // kdebug!("udp write: send err");
  495. return Err(SystemError::ENOBUFS);
  496. }
  497. }
  498. } else {
  499. // kdebug!("udp write: can not send");
  500. return Err(SystemError::ENOBUFS);
  501. };
  502. }
  503. fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
  504. let mut sockets = SOCKET_SET.lock();
  505. let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
  506. // kdebug!("UDP Bind to {:?}", endpoint);
  507. return self.do_bind(socket, endpoint);
  508. }
  509. fn poll(&self) -> (bool, bool, bool) {
  510. let sockets = SOCKET_SET.lock();
  511. let socket = sockets.get::<udp::Socket>(self.handle.0);
  512. return (socket.can_send(), socket.can_recv(), false);
  513. }
  514. /// @brief
  515. fn connect(&mut self, endpoint: super::Endpoint) -> Result<(), SystemError> {
  516. if let Endpoint::Ip(_) = endpoint {
  517. self.remote_endpoint = Some(endpoint);
  518. return Ok(());
  519. } else {
  520. return Err(SystemError::EINVAL);
  521. };
  522. }
  523. fn ioctl(
  524. &self,
  525. _cmd: usize,
  526. _arg0: usize,
  527. _arg1: usize,
  528. _arg2: usize,
  529. ) -> Result<usize, SystemError> {
  530. todo!()
  531. }
  532. fn metadata(&self) -> Result<SocketMetadata, SystemError> {
  533. Ok(self.metadata.clone())
  534. }
  535. fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
  536. return Box::new(self.clone());
  537. }
  538. fn endpoint(&self) -> Option<Endpoint> {
  539. let sockets = SOCKET_SET.lock();
  540. let socket = sockets.get::<udp::Socket>(self.handle.0);
  541. let listen_endpoint = socket.endpoint();
  542. if listen_endpoint.port == 0 {
  543. return None;
  544. } else {
  545. // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
  546. // 这里假设所有的地址都是ipv4
  547. // TODO: 支持ipv6
  548. let result = wire::IpEndpoint::new(
  549. listen_endpoint
  550. .addr
  551. .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
  552. listen_endpoint.port,
  553. );
  554. return Some(Endpoint::Ip(Some(result)));
  555. }
  556. }
  557. fn peer_endpoint(&self) -> Option<Endpoint> {
  558. return self.remote_endpoint.clone();
  559. }
  560. }
/// @brief A TCP socket
///
/// https://man7.org/linux/man-pages/man7/tcp.7.html
#[derive(Debug, Clone)]
pub struct TcpSocket {
    handle: Arc<GlobalSocketHandle>,
    local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
    is_listening: bool,
    metadata: SocketMetadata,
}
  571. impl TcpSocket {
  572. /// 元数据的缓冲区的大小
  573. pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
  574. /// 默认的发送缓冲区的大小 transmiss
  575. pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
  576. /// 默认的接收缓冲区的大小 receive
  577. pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
  578. /// @brief 创建一个原始的socket
  579. ///
  580. /// @param protocol 协议号
  581. /// @param options socket的选项
  582. ///
  583. /// @return 返回创建的原始的socket
  584. pub fn new(options: SocketOptions) -> Self {
  585. let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
  586. let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
  587. let socket = tcp::Socket::new(tx_buffer, rx_buffer);
  588. // 把socket添加到socket集合中,并得到socket的句柄
  589. let handle: Arc<GlobalSocketHandle> =
  590. GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
  591. let metadata = SocketMetadata::new(
  592. SocketType::TcpSocket,
  593. Self::DEFAULT_RX_BUF_SIZE,
  594. Self::DEFAULT_TX_BUF_SIZE,
  595. Self::DEFAULT_METADATA_BUF_SIZE,
  596. options,
  597. );
  598. return Self {
  599. handle,
  600. local_endpoint: None,
  601. is_listening: false,
  602. metadata,
  603. };
  604. }
  605. fn do_listen(
  606. &mut self,
  607. socket: &mut smoltcp::socket::tcp::Socket,
  608. local_endpoint: smoltcp::wire::IpEndpoint,
  609. ) -> Result<(), SystemError> {
  610. let listen_result = if local_endpoint.addr.is_unspecified() {
  611. // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
  612. socket.listen(local_endpoint.port)
  613. } else {
  614. // kdebug!("Tcp Socket Listen on {local_endpoint}");
  615. socket.listen(local_endpoint)
  616. };
  617. // TODO: 增加端口占用检查
  618. return match listen_result {
  619. Ok(()) => {
  620. // kdebug!(
  621. // "Tcp Socket Listen on {local_endpoint}, open?:{}",
  622. // socket.is_open()
  623. // );
  624. self.is_listening = true;
  625. Ok(())
  626. }
  627. Err(_) => Err(SystemError::EINVAL),
  628. };
  629. }
  630. }
  631. impl Socket for TcpSocket {
  632. fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
  633. // kdebug!("tcp socket: read, buf len={}", buf.len());
  634. loop {
  635. poll_ifaces();
  636. let mut socket_set_guard = SOCKET_SET.lock();
  637. let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
  638. // 如果socket已经关闭,返回错误
  639. if !socket.is_active() {
  640. // kdebug!("Tcp Socket Read Error, socket is closed");
  641. return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
  642. }
  643. if socket.may_recv() {
  644. let recv_res = socket.recv_slice(buf);
  645. if let Ok(size) = recv_res {
  646. if size > 0 {
  647. let endpoint = if let Some(p) = socket.remote_endpoint() {
  648. p
  649. } else {
  650. return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
  651. };
  652. drop(socket);
  653. drop(socket_set_guard);
  654. poll_ifaces();
  655. return (Ok(size), Endpoint::Ip(Some(endpoint)));
  656. }
  657. } else {
  658. let err = recv_res.unwrap_err();
  659. match err {
  660. tcp::RecvError::InvalidState => {
  661. kwarn!("Tcp Socket Read Error, InvalidState");
  662. return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
  663. }
  664. tcp::RecvError::Finished => {
  665. return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
  666. }
  667. }
  668. }
  669. } else {
  670. return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
  671. }
  672. drop(socket);
  673. drop(socket_set_guard);
  674. SOCKET_WAITQUEUE.sleep();
  675. }
  676. }
  677. fn write(&self, buf: &[u8], _to: Option<super::Endpoint>) -> Result<usize, SystemError> {
  678. let mut socket_set_guard = SOCKET_SET.lock();
  679. let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
  680. if socket.is_open() {
  681. if socket.can_send() {
  682. match socket.send_slice(buf) {
  683. Ok(size) => {
  684. drop(socket);
  685. drop(socket_set_guard);
  686. poll_ifaces();
  687. return Ok(size);
  688. }
  689. Err(e) => {
  690. kerror!("Tcp Socket Write Error {e:?}");
  691. return Err(SystemError::ENOBUFS);
  692. }
  693. }
  694. } else {
  695. return Err(SystemError::ENOBUFS);
  696. }
  697. }
  698. return Err(SystemError::ENOTCONN);
  699. }
  700. fn poll(&self) -> (bool, bool, bool) {
  701. let mut socket_set_guard = SOCKET_SET.lock();
  702. let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
  703. let mut input = false;
  704. let mut output = false;
  705. let mut error = false;
  706. if self.is_listening && socket.is_active() {
  707. input = true;
  708. } else if !socket.is_open() {
  709. error = true;
  710. } else {
  711. if socket.may_recv() {
  712. input = true;
  713. }
  714. if socket.can_send() {
  715. output = true;
  716. }
  717. }
  718. return (input, output, error);
  719. }
  720. fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
  721. let mut sockets = SOCKET_SET.lock();
  722. let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
  723. if let Endpoint::Ip(Some(ip)) = endpoint {
  724. let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
  725. // 检测端口是否被占用
  726. PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.handle.clone())?;
  727. // kdebug!("temp_port: {}", temp_port);
  728. let iface: Arc<dyn NetDriver> = NET_DRIVERS.write().get(&0).unwrap().clone();
  729. let mut inner_iface = iface.inner_iface().lock();
  730. // kdebug!("to connect: {ip:?}");
  731. match socket.connect(&mut inner_iface.context(), ip, temp_port) {
  732. Ok(()) => {
  733. // avoid deadlock
  734. drop(inner_iface);
  735. drop(iface);
  736. drop(socket);
  737. drop(sockets);
  738. loop {
  739. poll_ifaces();
  740. let mut sockets = SOCKET_SET.lock();
  741. let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
  742. match socket.state() {
  743. tcp::State::Established => {
  744. return Ok(());
  745. }
  746. tcp::State::SynSent => {
  747. drop(socket);
  748. drop(sockets);
  749. SOCKET_WAITQUEUE.sleep();
  750. }
  751. _ => {
  752. return Err(SystemError::ECONNREFUSED);
  753. }
  754. }
  755. }
  756. }
  757. Err(e) => {
  758. // kerror!("Tcp Socket Connect Error {e:?}");
  759. match e {
  760. tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
  761. tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
  762. }
  763. }
  764. }
  765. } else {
  766. return Err(SystemError::EINVAL);
  767. }
  768. }
  769. /// @brief tcp socket 监听 local_endpoint 端口
  770. ///
  771. /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
  772. fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
  773. if self.is_listening {
  774. return Ok(());
  775. }
  776. let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
  777. let mut sockets = SOCKET_SET.lock();
  778. let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
  779. if socket.is_listening() {
  780. // kdebug!("Tcp Socket is already listening on {local_endpoint}");
  781. return Ok(());
  782. }
  783. // kdebug!("Tcp Socket before listen, open={}", socket.is_open());
  784. return self.do_listen(socket, local_endpoint);
  785. }
  786. fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
  787. if let Endpoint::Ip(Some(mut ip)) = endpoint {
  788. if ip.port == 0 {
  789. ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
  790. }
  791. // 检测端口是否已被占用
  792. PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
  793. self.local_endpoint = Some(ip);
  794. self.is_listening = false;
  795. return Ok(());
  796. }
  797. return Err(SystemError::EINVAL);
  798. }
  799. fn shutdown(&self, _type: super::ShutdownType) -> Result<(), SystemError> {
  800. let mut sockets = SOCKET_SET.lock();
  801. let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
  802. socket.close();
  803. return Ok(());
  804. }
  805. fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
  806. let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
  807. loop {
  808. // kdebug!("tcp accept: poll_ifaces()");
  809. poll_ifaces();
  810. let mut sockets = SOCKET_SET.lock();
  811. let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
  812. if socket.is_active() {
  813. // kdebug!("tcp accept: socket.is_active()");
  814. let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
  815. drop(socket);
  816. let new_socket = {
  817. // Initialize the TCP socket's buffers.
  818. let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
  819. let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
  820. // The new TCP socket used for sending and receiving data.
  821. let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer);
  822. self.do_listen(&mut tcp_socket, endpoint)
  823. .expect("do_listen failed");
  824. // tcp_socket.listen(endpoint).unwrap();
  825. // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
  826. // 因此需要再为当前的socket分配一个新的handle
  827. let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
  828. let old_handle = ::core::mem::replace(&mut self.handle, new_handle.clone());
  829. // 更新端口与 handle 的绑定
  830. if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
  831. PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
  832. PORT_MANAGER.bind_port(
  833. self.metadata.socket_type,
  834. ip.port,
  835. new_handle.clone(),
  836. )?;
  837. }
  838. let metadata = SocketMetadata::new(
  839. SocketType::TcpSocket,
  840. Self::DEFAULT_RX_BUF_SIZE,
  841. Self::DEFAULT_TX_BUF_SIZE,
  842. Self::DEFAULT_METADATA_BUF_SIZE,
  843. self.metadata.options,
  844. );
  845. Box::new(TcpSocket {
  846. handle: old_handle,
  847. local_endpoint: self.local_endpoint,
  848. is_listening: false,
  849. metadata,
  850. })
  851. };
  852. // kdebug!("tcp accept: new socket: {:?}", new_socket);
  853. drop(sockets);
  854. poll_ifaces();
  855. return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
  856. }
  857. drop(socket);
  858. drop(sockets);
  859. SOCKET_WAITQUEUE.sleep();
  860. }
  861. }
  862. fn endpoint(&self) -> Option<Endpoint> {
  863. let mut result: Option<Endpoint> =
  864. self.local_endpoint.clone().map(|x| Endpoint::Ip(Some(x)));
  865. if result.is_none() {
  866. let sockets = SOCKET_SET.lock();
  867. let socket = sockets.get::<tcp::Socket>(self.handle.0);
  868. if let Some(ep) = socket.local_endpoint() {
  869. result = Some(Endpoint::Ip(Some(ep)));
  870. }
  871. }
  872. return result;
  873. }
  874. fn peer_endpoint(&self) -> Option<Endpoint> {
  875. let sockets = SOCKET_SET.lock();
  876. let socket = sockets.get::<tcp::Socket>(self.handle.0);
  877. return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
  878. }
  879. fn metadata(&self) -> Result<SocketMetadata, SystemError> {
  880. Ok(self.metadata.clone())
  881. }
  882. fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
  883. return Box::new(self.clone());
  884. }
  885. }
/// @brief Enumeration of address families
///
/// Reference: https://opengrok.ringotek.cn/xref/linux-5.19.10/include/linux/socket.h#180
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum AddressFamily {
    /// AF_UNSPEC: address family not specified
    Unspecified = 0,
    /// AF_UNIX: Unix domain sockets (same as AF_LOCAL)
    Unix = 1,
    /// AF_INET: IPv4 sockets
    INet = 2,
    /// AF_AX25: AMPR AX.25 sockets
    AX25 = 3,
    /// AF_IPX: IPX sockets
    IPX = 4,
    /// AF_APPLETALK: Appletalk sockets
    Appletalk = 5,
    /// AF_NETROM: AMPR NET/ROM sockets
    Netrom = 6,
    /// AF_BRIDGE: multiprotocol bridge sockets
    Bridge = 7,
    /// AF_ATMPVC: ATM PVC sockets
    Atmpvc = 8,
    /// AF_X25: X.25 sockets
    X25 = 9,
    /// AF_INET6: IPv6 sockets
    INet6 = 10,
    /// AF_ROSE: AMPR ROSE sockets
    Rose = 11,
    /// AF_DECnet Reserved for DECnet project
    Decnet = 12,
    /// AF_NETBEUI Reserved for 802.2LLC project
    Netbeui = 13,
    /// AF_SECURITY: pseudo address family for security callbacks
    Security = 14,
    /// AF_KEY: Key management API
    Key = 15,
    /// AF_NETLINK: Netlink sockets
    Netlink = 16,
    /// AF_PACKET: Low level packet interface
    Packet = 17,
    /// AF_ASH: Ash
    Ash = 18,
    /// AF_ECONET: Acorn Econet
    Econet = 19,
    /// AF_ATMSVC: ATM SVCs
    Atmsvc = 20,
    /// AF_RDS: Reliable Datagram Sockets
    Rds = 21,
    /// AF_SNA: Linux SNA Project
    Sna = 22,
    /// AF_IRDA: IRDA sockets
    Irda = 23,
    /// AF_PPPOX: PPPoX sockets
    Pppox = 24,
    /// AF_WANPIPE: WANPIPE API sockets
    WanPipe = 25,
    /// AF_LLC: Linux LLC
    Llc = 26,
    /// AF_IB: Native InfiniBand address
    /// Introduction: https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks
    Ib = 27,
    /// AF_MPLS: MPLS
    Mpls = 28,
    /// AF_CAN: Controller Area Network
    Can = 29,
    /// AF_TIPC: TIPC sockets
    Tipc = 30,
    /// AF_BLUETOOTH: Bluetooth sockets
    Bluetooth = 31,
    /// AF_IUCV: IUCV sockets
    Iucv = 32,
    /// AF_RXRPC: RxRPC sockets
    Rxrpc = 33,
    /// AF_ISDN: mISDN sockets
    Isdn = 34,
    /// AF_PHONET: Phonet sockets
    Phonet = 35,
    /// AF_IEEE802154: IEEE 802.15.4 sockets
    Ieee802154 = 36,
    /// AF_CAIF: CAIF sockets
    Caif = 37,
    /// AF_ALG: Algorithm sockets
    Alg = 38,
    /// AF_NFC: NFC sockets
    Nfc = 39,
    /// AF_VSOCK: vSockets
    Vsock = 40,
    /// AF_KCM: Kernel Connection Multiplexor
    Kcm = 41,
    /// AF_QIPCRTR: Qualcomm IPC Router
    Qipcrtr = 42,
    /// AF_SMC: SMC-R sockets.
    /// reserve number for PF_SMC protocol family that reuses AF_INET address family
    Smc = 43,
    /// AF_XDP: XDP sockets
    Xdp = 44,
    /// AF_MCTP: Management Component Transport Protocol
    Mctp = 45,
    /// AF_MAX: the largest address family value
    Max = 46,
}
  988. impl TryFrom<u16> for AddressFamily {
  989. type Error = SystemError;
  990. fn try_from(x: u16) -> Result<Self, Self::Error> {
  991. use num_traits::FromPrimitive;
  992. return <Self as FromPrimitive>::from_u16(x).ok_or_else(|| SystemError::EINVAL);
  993. }
  994. }
  995. /// @brief posix套接字类型的枚举(这些值与linux内核中的值一致)
  996. #[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
  997. pub enum PosixSocketType {
  998. Stream = 1,
  999. Datagram = 2,
  1000. Raw = 3,
  1001. Rdm = 4,
  1002. SeqPacket = 5,
  1003. Dccp = 6,
  1004. Packet = 10,
  1005. }
  1006. impl TryFrom<u8> for PosixSocketType {
  1007. type Error = SystemError;
  1008. fn try_from(x: u8) -> Result<Self, Self::Error> {
  1009. use num_traits::FromPrimitive;
  1010. return <Self as FromPrimitive>::from_u8(x).ok_or_else(|| SystemError::EINVAL);
  1011. }
  1012. }
  1013. /// @brief Socket在文件系统中的inode封装
  1014. #[derive(Debug)]
  1015. pub struct SocketInode(SpinLock<Box<dyn Socket>>);
  1016. impl SocketInode {
  1017. pub fn new(socket: Box<dyn Socket>) -> Arc<Self> {
  1018. return Arc::new(Self(SpinLock::new(socket)));
  1019. }
  1020. #[inline]
  1021. pub fn inner(&self) -> SpinLockGuard<Box<dyn Socket>> {
  1022. return self.0.lock();
  1023. }
  1024. pub unsafe fn inner_no_preempt(&self) -> SpinLockGuard<Box<dyn Socket>> {
  1025. return self.0.lock_no_preempt();
  1026. }
  1027. }
  1028. impl IndexNode for SocketInode {
  1029. fn open(
  1030. &self,
  1031. _data: &mut crate::filesystem::vfs::FilePrivateData,
  1032. _mode: &crate::filesystem::vfs::file::FileMode,
  1033. ) -> Result<(), SystemError> {
  1034. return Ok(());
  1035. }
  1036. fn close(
  1037. &self,
  1038. _data: &mut crate::filesystem::vfs::FilePrivateData,
  1039. ) -> Result<(), SystemError> {
  1040. let socket = self.0.lock();
  1041. if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() {
  1042. PORT_MANAGER.unbind_port(socket.metadata().unwrap().socket_type, ip.port)?;
  1043. }
  1044. return Ok(());
  1045. }
  1046. fn read_at(
  1047. &self,
  1048. _offset: usize,
  1049. len: usize,
  1050. buf: &mut [u8],
  1051. _data: &mut crate::filesystem::vfs::FilePrivateData,
  1052. ) -> Result<usize, SystemError> {
  1053. return self.0.lock_no_preempt().read(&mut buf[0..len]).0;
  1054. }
  1055. fn write_at(
  1056. &self,
  1057. _offset: usize,
  1058. len: usize,
  1059. buf: &[u8],
  1060. _data: &mut crate::filesystem::vfs::FilePrivateData,
  1061. ) -> Result<usize, SystemError> {
  1062. return self.0.lock_no_preempt().write(&buf[0..len], None);
  1063. }
  1064. fn poll(&self) -> Result<crate::filesystem::vfs::PollStatus, SystemError> {
  1065. let (read, write, error) = self.0.lock().poll();
  1066. let mut result = PollStatus::empty();
  1067. if read {
  1068. result.insert(PollStatus::READ);
  1069. }
  1070. if write {
  1071. result.insert(PollStatus::WRITE);
  1072. }
  1073. if error {
  1074. result.insert(PollStatus::ERROR);
  1075. }
  1076. return Ok(result);
  1077. }
  1078. fn fs(&self) -> alloc::sync::Arc<dyn crate::filesystem::vfs::FileSystem> {
  1079. todo!()
  1080. }
  1081. fn as_any_ref(&self) -> &dyn core::any::Any {
  1082. self
  1083. }
  1084. fn list(&self) -> Result<Vec<alloc::string::String>, SystemError> {
  1085. return Err(SystemError::ENOTDIR);
  1086. }
  1087. fn metadata(&self) -> Result<crate::filesystem::vfs::Metadata, SystemError> {
  1088. let meta = Metadata {
  1089. mode: ModeType::from_bits_truncate(0o755),
  1090. file_type: FileType::Socket,
  1091. ..Default::default()
  1092. };
  1093. return Ok(meta);
  1094. }
  1095. fn resize(&self, _len: usize) -> Result<(), SystemError> {
  1096. return Ok(());
  1097. }
  1098. }