mod.rs 16 KB


  1. pub mod inner;
  2. use alloc::{
  3. string::String,
  4. sync::{Arc, Weak},
  5. };
  6. use core::sync::atomic::{AtomicBool, Ordering};
  7. use crate::sched::SchedMode;
  8. use crate::{libs::rwlock::RwLock, net::socket::*};
  9. use inner::*;
  10. use system_error::SystemError;
  11. use super::INODE_MAP;
/// Shorthand for the epoll event flag type used when computing poll masks.
type EP = EPollEventType;

/// A sequenced-packet socket whose addresses are inode-backed endpoints
/// (`Endpoint::Inode`) or bound paths (`Endpoint::Unixpath`).
#[derive(Debug)]
pub struct SeqpacketSocket {
    /// Connection state machine: `Init` (neither listening nor connected),
    /// `Listen`, or `Connected`.
    inner: RwLock<Inner>,
    /// Local shutdown flags; `close` marks both the receive and the send side
    /// as shut down.
    shutdown: Shutdown,
    /// Non-blocking mode flag chosen at construction; accessed with relaxed
    /// atomics because it is a standalone flag.
    is_nonblocking: AtomicBool,
    /// Queue that blocking `accept`/`recv` callers sleep on; `connect` wakes the
    /// listener's queue after queuing a new connection.
    wait_queue: WaitQueue,
    /// Weak back-reference to the owning `Arc` (filled in by `Arc::new_cyclic`),
    /// letting `&self` methods obtain a strong reference to this socket.
    self_ref: Weak<Self>,
}
  21. impl SeqpacketSocket {
  22. /// 默认的元数据缓冲区大小
  23. pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
  24. /// 默认的缓冲区大小
  25. pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
  26. pub fn new(is_nonblocking: bool) -> Arc<Self> {
  27. Arc::new_cyclic(|me| Self {
  28. inner: RwLock::new(Inner::Init(Init::new())),
  29. shutdown: Shutdown::new(),
  30. is_nonblocking: AtomicBool::new(is_nonblocking),
  31. wait_queue: WaitQueue::default(),
  32. self_ref: me.clone(),
  33. })
  34. }
  35. pub fn new_inode(is_nonblocking: bool) -> Result<Arc<Inode>, SystemError> {
  36. let socket = SeqpacketSocket::new(is_nonblocking);
  37. let inode = Inode::new(socket.clone());
  38. // 建立时绑定自身为后续能正常获取本端地址
  39. let _ = match &mut *socket.inner.write() {
  40. Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))),
  41. _ => return Err(SystemError::EINVAL),
  42. };
  43. return Ok(inode);
  44. }
  45. pub fn new_connected(connected: Connected, is_nonblocking: bool) -> Arc<Self> {
  46. Arc::new_cyclic(|me| Self {
  47. inner: RwLock::new(Inner::Connected(connected)),
  48. shutdown: Shutdown::new(),
  49. is_nonblocking: AtomicBool::new(is_nonblocking),
  50. wait_queue: WaitQueue::default(),
  51. self_ref: me.clone(),
  52. })
  53. }
  54. pub fn new_pairs() -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
  55. let socket0 = SeqpacketSocket::new(false);
  56. let socket1 = SeqpacketSocket::new(false);
  57. let inode0 = Inode::new(socket0.clone());
  58. let inode1 = Inode::new(socket1.clone());
  59. let (conn_0, conn_1) = Connected::new_pair(
  60. Some(Endpoint::Inode((inode0.clone(), String::from("")))),
  61. Some(Endpoint::Inode((inode1.clone(), String::from("")))),
  62. );
  63. *socket0.inner.write() = Inner::Connected(conn_0);
  64. *socket1.inner.write() = Inner::Connected(conn_1);
  65. return Ok((inode0, inode1));
  66. }
  67. fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
  68. match &*self.inner.read() {
  69. Inner::Listen(listen) => listen.try_accept() as _,
  70. _ => {
  71. log::error!("the socket is not listening");
  72. return Err(SystemError::EINVAL);
  73. }
  74. }
  75. }
  76. fn is_acceptable(&self) -> bool {
  77. match &*self.inner.read() {
  78. Inner::Listen(listen) => listen.is_acceptable(),
  79. _ => {
  80. panic!("the socket is not listening");
  81. }
  82. }
  83. }
  84. fn is_peer_shutdown(&self) -> Result<bool, SystemError> {
  85. let peer_shutdown = match self.get_peer_name()? {
  86. Endpoint::Inode((inode, _)) => Arc::downcast::<SeqpacketSocket>(inode.inner())
  87. .map_err(|_| SystemError::EINVAL)?
  88. .shutdown
  89. .get()
  90. .is_both_shutdown(),
  91. _ => return Err(SystemError::EINVAL),
  92. };
  93. Ok(peer_shutdown)
  94. }
  95. fn can_recv(&self) -> Result<bool, SystemError> {
  96. let can = match &*self.inner.read() {
  97. Inner::Connected(connected) => connected.can_recv(),
  98. _ => return Err(SystemError::ENOTCONN),
  99. };
  100. Ok(can)
  101. }
  102. fn is_nonblocking(&self) -> bool {
  103. self.is_nonblocking.load(Ordering::Relaxed)
  104. }
  105. fn set_nonblocking(&self, nonblocking: bool) {
  106. self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
  107. }
  108. }
  109. impl Socket for SeqpacketSocket {
  110. fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
  111. let peer_inode = match endpoint {
  112. Endpoint::Inode((inode, _)) => inode,
  113. Endpoint::Unixpath((inode_id, _)) => {
  114. let inode_guard = INODE_MAP.read_irqsave();
  115. let inode = inode_guard.get(&inode_id).unwrap();
  116. match inode {
  117. Endpoint::Inode((inode, _)) => inode.clone(),
  118. _ => return Err(SystemError::EINVAL),
  119. }
  120. }
  121. _ => return Err(SystemError::EINVAL),
  122. };
  123. // 远端为服务端
  124. let remote_socket = Arc::downcast::<SeqpacketSocket>(peer_inode.inner())
  125. .map_err(|_| SystemError::EINVAL)?;
  126. let client_epoint = match &mut *self.inner.write() {
  127. Inner::Init(init) => match init.endpoint().cloned() {
  128. Some(end) => {
  129. log::debug!("bind when connect");
  130. Some(end)
  131. }
  132. None => {
  133. log::debug!("not bind when connect");
  134. let inode = Inode::new(self.self_ref.upgrade().unwrap().clone());
  135. let epoint = Endpoint::Inode((inode.clone(), String::from("")));
  136. let _ = init.bind(epoint.clone());
  137. Some(epoint)
  138. }
  139. },
  140. Inner::Listen(_) => return Err(SystemError::EINVAL),
  141. Inner::Connected(_) => return Err(SystemError::EISCONN),
  142. };
  143. // ***阻塞与非阻塞处理还未实现
  144. // 客户端与服务端建立连接将服务端inode推入到自身的listen_incom队列中,
  145. // accept时从中获取推出对应的socket
  146. match &*remote_socket.inner.read() {
  147. Inner::Listen(listener) => match listener.push_incoming(client_epoint) {
  148. Ok(connected) => {
  149. *self.inner.write() = Inner::Connected(connected);
  150. log::debug!("try to wake up");
  151. remote_socket.wait_queue.wakeup(None);
  152. return Ok(());
  153. }
  154. // ***错误处理
  155. Err(_) => todo!(),
  156. },
  157. Inner::Init(_) => {
  158. log::debug!("init einval");
  159. return Err(SystemError::EINVAL);
  160. }
  161. Inner::Connected(_) => return Err(SystemError::EISCONN),
  162. };
  163. }
  164. fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
  165. // 将自身socket的inode与用户端提供路径的文件indoe_id进行绑定
  166. match endpoint {
  167. Endpoint::Unixpath((inodeid, path)) => {
  168. let inode = match &mut *self.inner.write() {
  169. Inner::Init(init) => init.bind_path(path)?,
  170. _ => {
  171. log::error!("socket has listen or connected");
  172. return Err(SystemError::EINVAL);
  173. }
  174. };
  175. INODE_MAP.write_irqsave().insert(inodeid, inode);
  176. Ok(())
  177. }
  178. _ => return Err(SystemError::EINVAL),
  179. }
  180. }
  181. fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
  182. log::debug!("seqpacket shutdown");
  183. match &*self.inner.write() {
  184. Inner::Connected(connected) => connected.shutdown(how),
  185. _ => Err(SystemError::EINVAL),
  186. }
  187. }
  188. fn listen(&self, backlog: usize) -> Result<(), SystemError> {
  189. let mut state = self.inner.write();
  190. log::debug!("listen into socket");
  191. let epoint = match &*state {
  192. Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(),
  193. Inner::Listen(listener) => return listener.listen(backlog),
  194. Inner::Connected(_) => {
  195. log::error!("the socket is connected");
  196. return Err(SystemError::EINVAL);
  197. }
  198. };
  199. let listener = Listener::new(epoint, backlog);
  200. *state = Inner::Listen(listener);
  201. Ok(())
  202. }
  203. fn accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
  204. if !self.is_nonblocking() {
  205. loop {
  206. wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
  207. match self.try_accept() {
  208. Ok((socket, epoint)) => return Ok((socket, epoint)),
  209. Err(_) => continue,
  210. }
  211. }
  212. } else {
  213. // ***非阻塞状态
  214. todo!()
  215. }
  216. }
  217. fn set_option(
  218. &self,
  219. _level: crate::net::socket::OptionsLevel,
  220. _optname: usize,
  221. _optval: &[u8],
  222. ) -> Result<(), SystemError> {
  223. log::warn!("setsockopt is not implemented");
  224. Ok(())
  225. }
  226. fn wait_queue(&self) -> &WaitQueue {
  227. return &self.wait_queue;
  228. }
  229. fn close(&self) -> Result<(), SystemError> {
  230. // log::debug!("seqpacket close");
  231. self.shutdown.recv_shutdown();
  232. self.shutdown.send_shutdown();
  233. Ok(())
  234. }
  235. fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
  236. // 获取对端地址
  237. let endpoint = match &*self.inner.read() {
  238. Inner::Connected(connected) => connected.peer_endpoint().cloned(),
  239. _ => return Err(SystemError::ENOTCONN),
  240. };
  241. if let Some(endpoint) = endpoint {
  242. return Ok(endpoint);
  243. } else {
  244. return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
  245. }
  246. }
  247. fn get_name(&self) -> Result<Endpoint, SystemError> {
  248. // 获取本端地址
  249. let endpoint = match &*self.inner.read() {
  250. Inner::Init(init) => init.endpoint().cloned(),
  251. Inner::Listen(listener) => Some(listener.endpoint().clone()),
  252. Inner::Connected(connected) => connected.endpoint().cloned(),
  253. };
  254. if let Some(endpoint) = endpoint {
  255. return Ok(endpoint);
  256. } else {
  257. return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
  258. }
  259. }
  260. fn get_option(
  261. &self,
  262. _level: crate::net::socket::OptionsLevel,
  263. _name: usize,
  264. _value: &mut [u8],
  265. ) -> Result<usize, SystemError> {
  266. log::warn!("getsockopt is not implemented");
  267. Ok(0)
  268. }
  269. fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
  270. self.recv(buffer, crate::net::socket::MessageFlag::empty())
  271. }
  272. fn recv(
  273. &self,
  274. buffer: &mut [u8],
  275. flags: crate::net::socket::MessageFlag,
  276. ) -> Result<usize, SystemError> {
  277. if flags.contains(MessageFlag::OOB) {
  278. return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
  279. }
  280. if !flags.contains(MessageFlag::DONTWAIT) {
  281. loop {
  282. wq_wait_event_interruptible!(
  283. self.wait_queue,
  284. self.can_recv()? || self.is_peer_shutdown()?,
  285. {}
  286. )?;
  287. // connect锁和flag判断顺序不正确,应该先判断在
  288. match &*self.inner.write() {
  289. Inner::Connected(connected) => match connected.try_read(buffer) {
  290. Ok(usize) => {
  291. log::debug!("recv from successfully");
  292. return Ok(usize);
  293. }
  294. Err(_) => continue,
  295. },
  296. _ => {
  297. log::error!("the socket is not connected");
  298. return Err(SystemError::ENOTCONN);
  299. }
  300. }
  301. }
  302. } else {
  303. unimplemented!("unimplemented non_block")
  304. }
  305. }
  306. fn recv_msg(
  307. &self,
  308. _msg: &mut crate::net::syscall::MsgHdr,
  309. _flags: crate::net::socket::MessageFlag,
  310. ) -> Result<usize, SystemError> {
  311. Err(SystemError::ENOSYS)
  312. }
  313. fn send(
  314. &self,
  315. buffer: &[u8],
  316. flags: crate::net::socket::MessageFlag,
  317. ) -> Result<usize, SystemError> {
  318. if flags.contains(MessageFlag::OOB) {
  319. return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
  320. }
  321. if self.is_peer_shutdown()? {
  322. return Err(SystemError::EPIPE);
  323. }
  324. if !flags.contains(MessageFlag::DONTWAIT) {
  325. loop {
  326. match &*self.inner.write() {
  327. Inner::Connected(connected) => match connected.try_write(buffer) {
  328. Ok(usize) => {
  329. log::debug!("send successfully");
  330. return Ok(usize);
  331. }
  332. Err(_) => continue,
  333. },
  334. _ => {
  335. log::error!("the socket is not connected");
  336. return Err(SystemError::ENOTCONN);
  337. }
  338. }
  339. }
  340. } else {
  341. unimplemented!("unimplemented non_block")
  342. }
  343. }
  344. fn send_msg(
  345. &self,
  346. _msg: &crate::net::syscall::MsgHdr,
  347. _flags: crate::net::socket::MessageFlag,
  348. ) -> Result<usize, SystemError> {
  349. Err(SystemError::ENOSYS)
  350. }
  351. fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
  352. self.send(buffer, crate::net::socket::MessageFlag::empty())
  353. }
  354. fn recv_from(
  355. &self,
  356. buffer: &mut [u8],
  357. flags: MessageFlag,
  358. _address: Option<Endpoint>,
  359. ) -> Result<(usize, Endpoint), SystemError> {
  360. // log::debug!("recvfrom flags {:?}", flags);
  361. if flags.contains(MessageFlag::OOB) {
  362. return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
  363. }
  364. if !flags.contains(MessageFlag::DONTWAIT) {
  365. loop {
  366. wq_wait_event_interruptible!(
  367. self.wait_queue,
  368. self.can_recv()? || self.is_peer_shutdown()?,
  369. {}
  370. )?;
  371. // connect锁和flag判断顺序不正确,应该先判断在
  372. match &*self.inner.write() {
  373. Inner::Connected(connected) => match connected.recv_slice(buffer) {
  374. Ok(usize) => {
  375. // log::debug!("recvs from successfully");
  376. return Ok((usize, connected.peer_endpoint().unwrap().clone()));
  377. }
  378. Err(_) => continue,
  379. },
  380. _ => {
  381. log::error!("the socket is not connected");
  382. return Err(SystemError::ENOTCONN);
  383. }
  384. }
  385. }
  386. } else {
  387. unimplemented!("unimplemented non_block")
  388. }
  389. //Err(SystemError::ENOSYS)
  390. }
  391. fn send_buffer_size(&self) -> usize {
  392. log::warn!("using default buffer size");
  393. SeqpacketSocket::DEFAULT_BUF_SIZE
  394. }
  395. fn recv_buffer_size(&self) -> usize {
  396. log::warn!("using default buffer size");
  397. SeqpacketSocket::DEFAULT_BUF_SIZE
  398. }
  399. fn poll(&self) -> usize {
  400. let mut mask = EP::empty();
  401. let shutdown = self.shutdown.get();
  402. // 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152
  403. // 用关闭读写端表示连接断开
  404. if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() {
  405. mask |= EP::EPOLLHUP;
  406. }
  407. if shutdown.is_recv_shutdown() {
  408. mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM;
  409. }
  410. match &*self.inner.read() {
  411. Inner::Connected(connected) => {
  412. if connected.can_recv() {
  413. mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
  414. }
  415. // if (sk_is_readable(sk))
  416. // mask |= EPOLLIN | EPOLLRDNORM;
  417. // TODO:处理紧急情况 EPOLLPRI
  418. // TODO:处理连接是否关闭 EPOLLHUP
  419. if !shutdown.is_send_shutdown() {
  420. if connected.can_send().unwrap() {
  421. mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
  422. } else {
  423. todo!("poll: buffer space not enough");
  424. }
  425. }
  426. }
  427. Inner::Listen(_) => mask |= EP::EPOLLIN,
  428. Inner::Init(_) => mask |= EP::EPOLLOUT,
  429. }
  430. mask.bits() as usize
  431. }
  432. }