mod.rs 19 KB


  1. use crate::{
  2. net::{
  3. posix::MsgHdr,
  4. socket::{
  5. common::shutdown::{Shutdown, ShutdownTemp},
  6. endpoint::Endpoint,
  7. },
  8. },
  9. sched::SchedMode,
  10. };
  11. use alloc::{
  12. string::String,
  13. sync::{Arc, Weak},
  14. };
  15. use inner::{Connected, Init, Inner, Listener};
  16. use log::debug;
  17. use system_error::SystemError;
  18. use unix::{
  19. ns::abs::{remove_abs_addr, ABS_INODE_MAP},
  20. INODE_MAP,
  21. };
  22. use crate::{
  23. libs::rwlock::RwLock,
  24. net::socket::{self, *},
  25. };
/// Shorthand for the epoll event flag set used when computing poll masks.
type EP = crate::filesystem::epoll::EPollEventType;
/// Connection-state types (`Inner`, `Init`, `Listener`, `Connected`) used by `StreamSocket`.
pub mod inner;
/// A Unix-domain stream socket.
#[derive(Debug)]
pub struct StreamSocket {
    /// Connection state: `Init` (unconnected, possibly bound), `Listener` or `Connected`.
    inner: RwLock<Inner>,
    /// Shutdown flags for the receive and send directions.
    shutdown: Shutdown,
    /// Epoll registrations for this socket (not read in this file).
    _epitems: EPollItems,
    /// Queue that blocking callers (`accept`, `recv`, `recv_from`) sleep on; `connect`
    /// wakes the remote listener's queue and `close` wakes this one.
    wait_queue: WaitQueue,
    /// Weak self-reference filled in by `Arc::new_cyclic`, used to wrap `self` in a
    /// `SocketInode` when `connect` has to bind implicitly.
    self_ref: Weak<Self>,
}
  36. impl StreamSocket {
  37. /// 默认的元数据缓冲区大小
  38. #[allow(dead_code)]
  39. pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
  40. /// 默认的缓冲区大小
  41. pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
  42. pub fn new() -> Arc<Self> {
  43. Arc::new_cyclic(|me| Self {
  44. inner: RwLock::new(Inner::Init(Init::new())),
  45. shutdown: Shutdown::new(),
  46. _epitems: EPollItems::default(),
  47. wait_queue: WaitQueue::default(),
  48. self_ref: me.clone(),
  49. })
  50. }
  51. pub fn new_pairs() -> Result<(Arc<SocketInode>, Arc<SocketInode>), SystemError> {
  52. let socket0 = StreamSocket::new();
  53. let socket1 = StreamSocket::new();
  54. let inode0 = SocketInode::new(socket0.clone());
  55. let inode1 = SocketInode::new(socket1.clone());
  56. let (conn_0, conn_1) = Connected::new_pair(
  57. Some(Endpoint::Inode((inode0.clone(), String::from("")))),
  58. Some(Endpoint::Inode((inode1.clone(), String::from("")))),
  59. );
  60. *socket0.inner.write() = Inner::Connected(conn_0);
  61. *socket1.inner.write() = Inner::Connected(conn_1);
  62. return Ok((inode0, inode1));
  63. }
  64. #[allow(dead_code)]
  65. pub fn new_connected(connected: Connected) -> Arc<Self> {
  66. Arc::new_cyclic(|me| Self {
  67. inner: RwLock::new(Inner::Connected(connected)),
  68. shutdown: Shutdown::new(),
  69. _epitems: EPollItems::default(),
  70. wait_queue: WaitQueue::default(),
  71. self_ref: me.clone(),
  72. })
  73. }
  74. pub fn new_inode() -> Result<Arc<SocketInode>, SystemError> {
  75. let socket = StreamSocket::new();
  76. let inode = SocketInode::new(socket.clone());
  77. let _ = match &mut *socket.inner.write() {
  78. Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))),
  79. _ => return Err(SystemError::EINVAL),
  80. };
  81. return Ok(inode);
  82. }
  83. fn is_acceptable(&self) -> bool {
  84. match &*self.inner.read() {
  85. Inner::Listener(listener) => listener.is_acceptable(),
  86. _ => {
  87. panic!("the socket is not listening");
  88. }
  89. }
  90. }
  91. pub fn try_accept(&self) -> Result<(Arc<SocketInode>, Endpoint), SystemError> {
  92. match &*self.inner.read() {
  93. Inner::Listener(listener) => listener.try_accept() as _,
  94. _ => {
  95. log::error!("the socket is not listening");
  96. return Err(SystemError::EINVAL);
  97. }
  98. }
  99. }
  100. fn is_peer_shutdown(&self) -> Result<bool, SystemError> {
  101. let peer_shutdown = match self.get_peer_name()? {
  102. Endpoint::Inode((inode, _)) => Arc::downcast::<StreamSocket>(inode.inner())
  103. .map_err(|_| SystemError::EINVAL)?
  104. .shutdown
  105. .get()
  106. .is_both_shutdown(),
  107. _ => return Err(SystemError::EINVAL),
  108. };
  109. Ok(peer_shutdown)
  110. }
  111. fn can_recv(&self) -> Result<bool, SystemError> {
  112. let can = match &*self.inner.read() {
  113. Inner::Connected(connected) => connected.can_recv(),
  114. _ => return Err(SystemError::ENOTCONN),
  115. };
  116. Ok(can)
  117. }
  118. }
  119. impl Socket for StreamSocket {
  120. fn connect(&self, server_endpoint: Endpoint) -> Result<(), SystemError> {
  121. //获取客户端地址
  122. let client_endpoint = match &mut *self.inner.write() {
  123. Inner::Init(init) => match init.endpoint().cloned() {
  124. Some(endpoint) => {
  125. debug!("bind when connected");
  126. Some(endpoint)
  127. }
  128. None => {
  129. debug!("not bind when connected");
  130. let inode = SocketInode::new(self.self_ref.upgrade().unwrap().clone());
  131. let epoint = Endpoint::Inode((inode.clone(), String::from("")));
  132. let _ = init.bind(epoint.clone());
  133. Some(epoint)
  134. }
  135. },
  136. Inner::Connected(_) => return Err(SystemError::EISCONN),
  137. Inner::Listener(_) => return Err(SystemError::EINVAL),
  138. };
  139. //获取服务端地址
  140. // let peer_inode = match server_endpoint.clone() {
  141. // Endpoint::Inode(socket) => socket,
  142. // _ => return Err(SystemError::EINVAL),
  143. // };
  144. //找到对端socket
  145. let (peer_inode, sun_path) = match server_endpoint {
  146. Endpoint::Inode((inode, path)) => (inode, path),
  147. Endpoint::Unixpath((inode_id, path)) => {
  148. match INODE_MAP.read_irqsave().get(&inode_id) {
  149. Some(Endpoint::Inode((inode, _))) => (inode.clone(), path),
  150. _ => return Err(SystemError::EINVAL),
  151. }
  152. }
  153. Endpoint::Abspath((abs_addr, path)) => {
  154. match ABS_INODE_MAP.lock_irqsave().get(&abs_addr.name()) {
  155. Some(Endpoint::Inode((inode, _))) => (inode.clone(), path),
  156. _ => {
  157. log::debug!("can not find inode from absInodeMap");
  158. return Err(SystemError::EINVAL);
  159. }
  160. }
  161. }
  162. _ => return Err(SystemError::EINVAL),
  163. };
  164. let remote_socket: Arc<StreamSocket> =
  165. Arc::downcast::<StreamSocket>(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?;
  166. //创建新的对端socket
  167. let new_server_socket = StreamSocket::new();
  168. let new_server_inode = SocketInode::new(new_server_socket.clone());
  169. let new_server_endpoint = Some(Endpoint::Inode((new_server_inode.clone(), sun_path)));
  170. //获取connect pair
  171. let (client_conn, server_conn) =
  172. Connected::new_pair(client_endpoint, new_server_endpoint.clone());
  173. *new_server_socket.inner.write() = Inner::Connected(server_conn);
  174. //查看remote_socket是否处于监听状态
  175. let remote_listener = remote_socket.inner.write();
  176. match &*remote_listener {
  177. Inner::Listener(listener) => {
  178. //往服务端socket的连接队列中添加connected
  179. listener.push_incoming(new_server_inode)?;
  180. *self.inner.write() = Inner::Connected(client_conn);
  181. remote_socket.wait_queue.wakeup(None);
  182. }
  183. _ => return Err(SystemError::EINVAL),
  184. }
  185. return Ok(());
  186. }
  187. fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
  188. match endpoint {
  189. Endpoint::Unixpath((inodeid, path)) => {
  190. let inode = match &mut *self.inner.write() {
  191. Inner::Init(init) => init.bind_path(path)?,
  192. _ => {
  193. log::error!("socket has listen or connected");
  194. return Err(SystemError::EINVAL);
  195. }
  196. };
  197. INODE_MAP.write_irqsave().insert(inodeid, inode);
  198. Ok(())
  199. }
  200. Endpoint::Abspath((abshandle, path)) => {
  201. let inode = match &mut *self.inner.write() {
  202. Inner::Init(init) => init.bind_path(path)?,
  203. _ => {
  204. log::error!("socket has listen or connected");
  205. return Err(SystemError::EINVAL);
  206. }
  207. };
  208. ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode);
  209. Ok(())
  210. }
  211. _ => return Err(SystemError::EINVAL),
  212. }
  213. }
  214. fn shutdown(&self, _stype: ShutdownTemp) -> Result<(), SystemError> {
  215. todo!();
  216. }
  217. fn listen(&self, backlog: usize) -> Result<(), SystemError> {
  218. let mut inner = self.inner.write();
  219. let epoint = match &*inner {
  220. Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(),
  221. Inner::Connected(_) => {
  222. return Err(SystemError::EINVAL);
  223. }
  224. Inner::Listener(listener) => {
  225. return listener.listen(backlog);
  226. }
  227. };
  228. let listener = Listener::new(Some(epoint), backlog);
  229. *inner = Inner::Listener(listener);
  230. return Ok(());
  231. }
  232. fn accept(&self) -> Result<(Arc<socket::SocketInode>, Endpoint), SystemError> {
  233. debug!("stream server begin accept");
  234. //目前只实现了阻塞式实现
  235. loop {
  236. wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
  237. match self.try_accept() {
  238. Ok((socket, endpoint)) => {
  239. debug!("server accept!:{:?}", endpoint);
  240. return Ok((socket, endpoint));
  241. }
  242. Err(_) => continue,
  243. }
  244. }
  245. }
  246. fn set_option(&self, _level: PSOL, _optname: usize, _optval: &[u8]) -> Result<(), SystemError> {
  247. log::warn!("setsockopt is not implemented");
  248. Ok(())
  249. }
  250. fn wait_queue(&self) -> &WaitQueue {
  251. return &self.wait_queue;
  252. }
  253. fn poll(&self) -> usize {
  254. let mut mask = EP::empty();
  255. let shutdown = self.shutdown.get();
  256. // 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152
  257. // 用关闭读写端表示连接断开
  258. if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() {
  259. mask |= EP::EPOLLHUP;
  260. }
  261. if shutdown.is_recv_shutdown() {
  262. mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM;
  263. }
  264. match &*self.inner.read() {
  265. Inner::Connected(connected) => {
  266. if connected.can_recv() {
  267. mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
  268. }
  269. // if (sk_is_readable(sk))
  270. // mask |= EPOLLIN | EPOLLRDNORM;
  271. // TODO:处理紧急情况 EPOLLPRI
  272. // TODO:处理连接是否关闭 EPOLLHUP
  273. if !shutdown.is_send_shutdown() {
  274. if connected.can_send().unwrap() {
  275. mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
  276. } else {
  277. todo!("poll: buffer space not enough");
  278. }
  279. }
  280. }
  281. Inner::Listener(_) => mask |= EP::EPOLLIN,
  282. Inner::Init(_) => mask |= EP::EPOLLOUT,
  283. }
  284. mask.bits() as usize
  285. }
  286. fn close(&self) -> Result<(), SystemError> {
  287. self.shutdown.recv_shutdown();
  288. self.shutdown.send_shutdown();
  289. let endpoint = self.get_name()?;
  290. let path = match &endpoint {
  291. Endpoint::Inode((_, path)) => path,
  292. Endpoint::Unixpath((_, path)) => path,
  293. Endpoint::Abspath((_, path)) => path,
  294. _ => return Err(SystemError::EINVAL),
  295. };
  296. if path.is_empty() {
  297. return Ok(());
  298. }
  299. match &endpoint {
  300. Endpoint::Unixpath((inode_id, _)) => {
  301. let mut inode_guard = INODE_MAP.write_irqsave();
  302. inode_guard.remove(inode_id);
  303. }
  304. Endpoint::Inode((current_inode, current_path)) => {
  305. let mut inode_guard = INODE_MAP.write_irqsave();
  306. // 遍历查找匹配的条目
  307. let target_entry = inode_guard
  308. .iter()
  309. .find(|(_, ep)| {
  310. if let Endpoint::Inode((map_inode, map_path)) = ep {
  311. // 通过指针相等性比较确保是同一对象
  312. Arc::ptr_eq(map_inode, current_inode) && map_path == current_path
  313. } else {
  314. log::debug!("not match");
  315. false
  316. }
  317. })
  318. .map(|(id, _)| *id);
  319. if let Some(id) = target_entry {
  320. inode_guard.remove(&id).ok_or(SystemError::EINVAL)?;
  321. }
  322. }
  323. Endpoint::Abspath((abshandle, _)) => {
  324. let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave();
  325. abs_inode_map.remove(&abshandle.name());
  326. }
  327. _ => {
  328. log::error!("invalid endpoint type");
  329. return Err(SystemError::EINVAL);
  330. }
  331. }
  332. *self.inner.write() = Inner::Init(Init::new());
  333. self.wait_queue.wakeup(None);
  334. let _ = remove_abs_addr(path);
  335. Ok(())
  336. }
  337. fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
  338. //获取对端地址
  339. let endpoint = match &*self.inner.read() {
  340. Inner::Connected(connected) => connected.peer_endpoint().cloned(),
  341. _ => return Err(SystemError::ENOTCONN),
  342. };
  343. if let Some(endpoint) = endpoint {
  344. return Ok(endpoint);
  345. } else {
  346. return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
  347. }
  348. }
  349. fn get_name(&self) -> Result<Endpoint, SystemError> {
  350. //获取本端地址
  351. let endpoint = match &*self.inner.read() {
  352. Inner::Init(init) => init.endpoint().cloned(),
  353. Inner::Connected(connected) => connected.endpoint().cloned(),
  354. Inner::Listener(listener) => listener.endpoint().cloned(),
  355. };
  356. if let Some(endpoint) = endpoint {
  357. return Ok(endpoint);
  358. } else {
  359. return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
  360. }
  361. }
  362. fn get_option(
  363. &self,
  364. _level: PSOL,
  365. _name: usize,
  366. _value: &mut [u8],
  367. ) -> Result<usize, SystemError> {
  368. log::warn!("getsockopt is not implemented");
  369. Ok(0)
  370. }
  371. fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
  372. self.recv(buffer, socket::PMSG::empty())
  373. }
  374. fn recv(&self, buffer: &mut [u8], flags: socket::PMSG) -> Result<usize, SystemError> {
  375. if !flags.contains(PMSG::DONTWAIT) {
  376. loop {
  377. log::debug!("socket try recv");
  378. wq_wait_event_interruptible!(
  379. self.wait_queue,
  380. self.can_recv()? || self.is_peer_shutdown()?,
  381. {}
  382. )?;
  383. // connect锁和flag判断顺序不正确,应该先判断在
  384. match &*self.inner.write() {
  385. Inner::Connected(connected) => match connected.try_recv(buffer) {
  386. Ok(usize) => {
  387. log::debug!("recv successfully");
  388. return Ok(usize);
  389. }
  390. Err(_) => continue,
  391. },
  392. _ => {
  393. log::error!("the socket is not connected");
  394. return Err(SystemError::ENOTCONN);
  395. }
  396. }
  397. }
  398. } else {
  399. unimplemented!("unimplemented non_block")
  400. }
  401. }
  402. fn recv_from(
  403. &self,
  404. buffer: &mut [u8],
  405. flags: socket::PMSG,
  406. _address: Option<Endpoint>,
  407. ) -> Result<(usize, Endpoint), SystemError> {
  408. if flags.contains(PMSG::OOB) {
  409. return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
  410. }
  411. if !flags.contains(PMSG::DONTWAIT) {
  412. loop {
  413. log::debug!("socket try recv from");
  414. wq_wait_event_interruptible!(
  415. self.wait_queue,
  416. self.can_recv()? || self.is_peer_shutdown()?,
  417. {}
  418. )?;
  419. // connect锁和flag判断顺序不正确,应该先判断在
  420. log::debug!("try recv");
  421. match &*self.inner.write() {
  422. Inner::Connected(connected) => match connected.try_recv(buffer) {
  423. Ok(usize) => {
  424. log::debug!("recvs from successfully");
  425. return Ok((usize, connected.peer_endpoint().unwrap().clone()));
  426. }
  427. Err(_) => continue,
  428. },
  429. _ => {
  430. log::error!("the socket is not connected");
  431. return Err(SystemError::ENOTCONN);
  432. }
  433. }
  434. }
  435. } else {
  436. unimplemented!("unimplemented non_block")
  437. }
  438. }
  439. fn recv_msg(&self, _msg: &mut MsgHdr, _flags: socket::PMSG) -> Result<usize, SystemError> {
  440. Err(SystemError::ENOSYS)
  441. }
  442. fn send(&self, buffer: &[u8], flags: socket::PMSG) -> Result<usize, SystemError> {
  443. if self.is_peer_shutdown()? {
  444. return Err(SystemError::EPIPE);
  445. }
  446. if !flags.contains(PMSG::DONTWAIT) {
  447. loop {
  448. match &*self.inner.write() {
  449. Inner::Connected(connected) => match connected.try_send(buffer) {
  450. Ok(usize) => {
  451. log::debug!("send successfully");
  452. return Ok(usize);
  453. }
  454. Err(_) => continue,
  455. },
  456. _ => {
  457. log::error!("the socket is not connected");
  458. return Err(SystemError::ENOTCONN);
  459. }
  460. }
  461. }
  462. } else {
  463. unimplemented!("unimplemented non_block")
  464. }
  465. }
  466. fn send_msg(&self, _msg: &MsgHdr, _flags: socket::PMSG) -> Result<usize, SystemError> {
  467. todo!()
  468. }
  469. fn send_to(
  470. &self,
  471. _buffer: &[u8],
  472. _flags: socket::PMSG,
  473. _address: Endpoint,
  474. ) -> Result<usize, SystemError> {
  475. Err(SystemError::ENOSYS)
  476. }
  477. fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
  478. self.send(buffer, socket::PMSG::empty())
  479. }
  480. fn send_buffer_size(&self) -> usize {
  481. log::warn!("using default buffer size");
  482. StreamSocket::DEFAULT_BUF_SIZE
  483. }
  484. fn recv_buffer_size(&self) -> usize {
  485. log::warn!("using default buffer size");
  486. StreamSocket::DEFAULT_BUF_SIZE
  487. }
  488. }