123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556 |
- use crate::sched::SchedMode;
- use alloc::{
- string::String,
- sync::{Arc, Weak},
- };
- use inner::{Connected, Init, Inner, Listener};
- use log::debug;
- use system_error::SystemError;
- use unix::{
- ns::abs::{remove_abs_addr, ABS_INODE_MAP},
- INODE_MAP,
- };
- use crate::{
- libs::rwlock::RwLock,
- net::socket::{self, *},
- };
- type EP = EPollEventType;
- pub mod inner;
- #[derive(Debug)]
- pub struct StreamSocket {
- inner: RwLock<Inner>,
- shutdown: Shutdown,
- _epitems: EPollItems,
- wait_queue: WaitQueue,
- self_ref: Weak<Self>,
- }
- impl StreamSocket {
/// Default size, in bytes, of the metadata buffer.
// Kept for completeness; nothing in this file reads it yet.
#[allow(dead_code)]
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;

/// Default size, in bytes, of the data buffer (64 KiB).
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
- pub fn new() -> Arc<Self> {
- Arc::new_cyclic(|me| Self {
- inner: RwLock::new(Inner::Init(Init::new())),
- shutdown: Shutdown::new(),
- _epitems: EPollItems::default(),
- wait_queue: WaitQueue::default(),
- self_ref: me.clone(),
- })
- }
- pub fn new_pairs() -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
- let socket0 = StreamSocket::new();
- let socket1 = StreamSocket::new();
- let inode0 = Inode::new(socket0.clone());
- let inode1 = Inode::new(socket1.clone());
- let (conn_0, conn_1) = Connected::new_pair(
- Some(Endpoint::Inode((inode0.clone(), String::from("")))),
- Some(Endpoint::Inode((inode1.clone(), String::from("")))),
- );
- *socket0.inner.write() = Inner::Connected(conn_0);
- *socket1.inner.write() = Inner::Connected(conn_1);
- return Ok((inode0, inode1));
- }
- #[allow(dead_code)]
- pub fn new_connected(connected: Connected) -> Arc<Self> {
- Arc::new_cyclic(|me| Self {
- inner: RwLock::new(Inner::Connected(connected)),
- shutdown: Shutdown::new(),
- _epitems: EPollItems::default(),
- wait_queue: WaitQueue::default(),
- self_ref: me.clone(),
- })
- }
- pub fn new_inode() -> Result<Arc<Inode>, SystemError> {
- let socket = StreamSocket::new();
- let inode = Inode::new(socket.clone());
- let _ = match &mut *socket.inner.write() {
- Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))),
- _ => return Err(SystemError::EINVAL),
- };
- return Ok(inode);
- }
- fn is_acceptable(&self) -> bool {
- match &*self.inner.read() {
- Inner::Listener(listener) => listener.is_acceptable(),
- _ => {
- panic!("the socket is not listening");
- }
- }
- }
- pub fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
- match &*self.inner.read() {
- Inner::Listener(listener) => listener.try_accept() as _,
- _ => {
- log::error!("the socket is not listening");
- return Err(SystemError::EINVAL);
- }
- }
- }
- fn is_peer_shutdown(&self) -> Result<bool, SystemError> {
- let peer_shutdown = match self.get_peer_name()? {
- Endpoint::Inode((inode, _)) => Arc::downcast::<StreamSocket>(inode.inner())
- .map_err(|_| SystemError::EINVAL)?
- .shutdown
- .get()
- .is_both_shutdown(),
- _ => return Err(SystemError::EINVAL),
- };
- Ok(peer_shutdown)
- }
- fn can_recv(&self) -> Result<bool, SystemError> {
- let can = match &*self.inner.read() {
- Inner::Connected(connected) => connected.can_recv(),
- _ => return Err(SystemError::ENOTCONN),
- };
- Ok(can)
- }
- }
- impl Socket for StreamSocket {
- fn connect(&self, server_endpoint: Endpoint) -> Result<(), SystemError> {
- //获取客户端地址
- let client_endpoint = match &mut *self.inner.write() {
- Inner::Init(init) => match init.endpoint().cloned() {
- Some(endpoint) => {
- debug!("bind when connected");
- Some(endpoint)
- }
- None => {
- debug!("not bind when connected");
- let inode = Inode::new(self.self_ref.upgrade().unwrap().clone());
- let epoint = Endpoint::Inode((inode.clone(), String::from("")));
- let _ = init.bind(epoint.clone());
- Some(epoint)
- }
- },
- Inner::Connected(_) => return Err(SystemError::EISCONN),
- Inner::Listener(_) => return Err(SystemError::EINVAL),
- };
- //获取服务端地址
- // let peer_inode = match server_endpoint.clone() {
- // Endpoint::Inode(socket) => socket,
- // _ => return Err(SystemError::EINVAL),
- // };
- //找到对端socket
- let (peer_inode, sun_path) = match server_endpoint {
- Endpoint::Inode((inode, path)) => (inode, path),
- Endpoint::Unixpath((inode_id, path)) => {
- let inode_guard = INODE_MAP.read_irqsave();
- let inode = inode_guard.get(&inode_id).unwrap();
- match inode {
- Endpoint::Inode((inode, _)) => (inode.clone(), path),
- _ => return Err(SystemError::EINVAL),
- }
- }
- Endpoint::Abspath((abs_addr, path)) => {
- let inode_guard = ABS_INODE_MAP.lock_irqsave();
- let inode = match inode_guard.get(&abs_addr.name()) {
- Some(inode) => inode,
- None => {
- log::debug!("can not find inode from absInodeMap");
- return Err(SystemError::EINVAL);
- }
- };
- match inode {
- Endpoint::Inode((inode, _)) => (inode.clone(), path),
- _ => {
- debug!("when connect, find inode failed!");
- return Err(SystemError::EINVAL);
- }
- }
- }
- _ => return Err(SystemError::EINVAL),
- };
- let remote_socket: Arc<StreamSocket> =
- Arc::downcast::<StreamSocket>(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?;
- //创建新的对端socket
- let new_server_socket = StreamSocket::new();
- let new_server_inode = Inode::new(new_server_socket.clone());
- let new_server_endpoint = Some(Endpoint::Inode((new_server_inode.clone(), sun_path)));
- //获取connect pair
- let (client_conn, server_conn) =
- Connected::new_pair(client_endpoint, new_server_endpoint.clone());
- *new_server_socket.inner.write() = Inner::Connected(server_conn);
- //查看remote_socket是否处于监听状态
- let remote_listener = remote_socket.inner.write();
- match &*remote_listener {
- Inner::Listener(listener) => {
- //往服务端socket的连接队列中添加connected
- listener.push_incoming(new_server_inode)?;
- *self.inner.write() = Inner::Connected(client_conn);
- remote_socket.wait_queue.wakeup(None);
- }
- _ => return Err(SystemError::EINVAL),
- }
- return Ok(());
- }
- fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
- match endpoint {
- Endpoint::Unixpath((inodeid, path)) => {
- let inode = match &mut *self.inner.write() {
- Inner::Init(init) => init.bind_path(path)?,
- _ => {
- log::error!("socket has listen or connected");
- return Err(SystemError::EINVAL);
- }
- };
- INODE_MAP.write_irqsave().insert(inodeid, inode);
- Ok(())
- }
- Endpoint::Abspath((abshandle, path)) => {
- let inode = match &mut *self.inner.write() {
- Inner::Init(init) => init.bind_path(path)?,
- _ => {
- log::error!("socket has listen or connected");
- return Err(SystemError::EINVAL);
- }
- };
- ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode);
- Ok(())
- }
- _ => return Err(SystemError::EINVAL),
- }
- }
- fn shutdown(&self, _stype: ShutdownTemp) -> Result<(), SystemError> {
- todo!();
- }
- fn listen(&self, backlog: usize) -> Result<(), SystemError> {
- let mut inner = self.inner.write();
- let epoint = match &*inner {
- Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(),
- Inner::Connected(_) => {
- return Err(SystemError::EINVAL);
- }
- Inner::Listener(listener) => {
- return listener.listen(backlog);
- }
- };
- let listener = Listener::new(Some(epoint), backlog);
- *inner = Inner::Listener(listener);
- return Ok(());
- }
- fn accept(&self) -> Result<(Arc<socket::Inode>, Endpoint), SystemError> {
- debug!("stream server begin accept");
- //目前只实现了阻塞式实现
- loop {
- wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
- match self.try_accept() {
- Ok((socket, endpoint)) => {
- debug!("server accept!:{:?}", endpoint);
- return Ok((socket, endpoint));
- }
- Err(_) => continue,
- }
- }
- }
- fn set_option(&self, _level: PSOL, _optname: usize, _optval: &[u8]) -> Result<(), SystemError> {
- log::warn!("setsockopt is not implemented");
- Ok(())
- }
- fn wait_queue(&self) -> &WaitQueue {
- return &self.wait_queue;
- }
- fn poll(&self) -> usize {
- let mut mask = EP::empty();
- let shutdown = self.shutdown.get();
- // 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152
- // 用关闭读写端表示连接断开
- if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() {
- mask |= EP::EPOLLHUP;
- }
- if shutdown.is_recv_shutdown() {
- mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM;
- }
- match &*self.inner.read() {
- Inner::Connected(connected) => {
- if connected.can_recv() {
- mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
- }
- // if (sk_is_readable(sk))
- // mask |= EPOLLIN | EPOLLRDNORM;
- // TODO:处理紧急情况 EPOLLPRI
- // TODO:处理连接是否关闭 EPOLLHUP
- if !shutdown.is_send_shutdown() {
- if connected.can_send().unwrap() {
- mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
- } else {
- todo!("poll: buffer space not enough");
- }
- }
- }
- Inner::Listener(_) => mask |= EP::EPOLLIN,
- Inner::Init(_) => mask |= EP::EPOLLOUT,
- }
- mask.bits() as usize
- }
- fn close(&self) -> Result<(), SystemError> {
- self.shutdown.recv_shutdown();
- self.shutdown.send_shutdown();
- let endpoint = self.get_name()?;
- let path = match &endpoint {
- Endpoint::Inode((_, path)) => path,
- Endpoint::Unixpath((_, path)) => path,
- Endpoint::Abspath((_, path)) => path,
- _ => return Err(SystemError::EINVAL),
- };
- if path.is_empty() {
- return Ok(());
- }
- match &endpoint {
- Endpoint::Unixpath((inode_id, _)) => {
- let mut inode_guard = INODE_MAP.write_irqsave();
- inode_guard.remove(inode_id);
- }
- Endpoint::Inode((current_inode, current_path)) => {
- let mut inode_guard = INODE_MAP.write_irqsave();
- // 遍历查找匹配的条目
- let target_entry = inode_guard
- .iter()
- .find(|(_, ep)| {
- if let Endpoint::Inode((map_inode, map_path)) = ep {
- // 通过指针相等性比较确保是同一对象
- Arc::ptr_eq(map_inode, current_inode) && map_path == current_path
- } else {
- log::debug!("not match");
- false
- }
- })
- .map(|(id, _)| *id);
- if let Some(id) = target_entry {
- inode_guard.remove(&id).ok_or(SystemError::EINVAL)?;
- }
- }
- Endpoint::Abspath((abshandle, _)) => {
- let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave();
- abs_inode_map.remove(&abshandle.name());
- }
- _ => {
- log::error!("invalid endpoint type");
- return Err(SystemError::EINVAL);
- }
- }
- *self.inner.write() = Inner::Init(Init::new());
- self.wait_queue.wakeup(None);
- let _ = remove_abs_addr(path);
- Ok(())
- }
- fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
- //获取对端地址
- let endpoint = match &*self.inner.read() {
- Inner::Connected(connected) => connected.peer_endpoint().cloned(),
- _ => return Err(SystemError::ENOTCONN),
- };
- if let Some(endpoint) = endpoint {
- return Ok(endpoint);
- } else {
- return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
- }
- }
- fn get_name(&self) -> Result<Endpoint, SystemError> {
- //获取本端地址
- let endpoint = match &*self.inner.read() {
- Inner::Init(init) => init.endpoint().cloned(),
- Inner::Connected(connected) => connected.endpoint().cloned(),
- Inner::Listener(listener) => listener.endpoint().cloned(),
- };
- if let Some(endpoint) = endpoint {
- return Ok(endpoint);
- } else {
- return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
- }
- }
- fn get_option(
- &self,
- _level: PSOL,
- _name: usize,
- _value: &mut [u8],
- ) -> Result<usize, SystemError> {
- log::warn!("getsockopt is not implemented");
- Ok(0)
- }
- fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
- self.recv(buffer, socket::PMSG::empty())
- }
- fn recv(&self, buffer: &mut [u8], flags: socket::PMSG) -> Result<usize, SystemError> {
- if !flags.contains(PMSG::DONTWAIT) {
- loop {
- log::debug!("socket try recv");
- wq_wait_event_interruptible!(
- self.wait_queue,
- self.can_recv()? || self.is_peer_shutdown()?,
- {}
- )?;
- // connect锁和flag判断顺序不正确,应该先判断在
- match &*self.inner.write() {
- Inner::Connected(connected) => match connected.try_recv(buffer) {
- Ok(usize) => {
- log::debug!("recv successfully");
- return Ok(usize);
- }
- Err(_) => continue,
- },
- _ => {
- log::error!("the socket is not connected");
- return Err(SystemError::ENOTCONN);
- }
- }
- }
- } else {
- unimplemented!("unimplemented non_block")
- }
- }
- fn recv_from(
- &self,
- buffer: &mut [u8],
- flags: socket::PMSG,
- _address: Option<Endpoint>,
- ) -> Result<(usize, Endpoint), SystemError> {
- if flags.contains(PMSG::OOB) {
- return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
- }
- if !flags.contains(PMSG::DONTWAIT) {
- loop {
- log::debug!("socket try recv from");
- wq_wait_event_interruptible!(
- self.wait_queue,
- self.can_recv()? || self.is_peer_shutdown()?,
- {}
- )?;
- // connect锁和flag判断顺序不正确,应该先判断在
- log::debug!("try recv");
- match &*self.inner.write() {
- Inner::Connected(connected) => match connected.try_recv(buffer) {
- Ok(usize) => {
- log::debug!("recvs from successfully");
- return Ok((usize, connected.peer_endpoint().unwrap().clone()));
- }
- Err(_) => continue,
- },
- _ => {
- log::error!("the socket is not connected");
- return Err(SystemError::ENOTCONN);
- }
- }
- }
- } else {
- unimplemented!("unimplemented non_block")
- }
- }
- fn recv_msg(
- &self,
- _msg: &mut crate::net::syscall::MsgHdr,
- _flags: socket::PMSG,
- ) -> Result<usize, SystemError> {
- Err(SystemError::ENOSYS)
- }
- fn send(&self, buffer: &[u8], flags: socket::PMSG) -> Result<usize, SystemError> {
- if self.is_peer_shutdown()? {
- return Err(SystemError::EPIPE);
- }
- if !flags.contains(PMSG::DONTWAIT) {
- loop {
- match &*self.inner.write() {
- Inner::Connected(connected) => match connected.try_send(buffer) {
- Ok(usize) => {
- log::debug!("send successfully");
- return Ok(usize);
- }
- Err(_) => continue,
- },
- _ => {
- log::error!("the socket is not connected");
- return Err(SystemError::ENOTCONN);
- }
- }
- }
- } else {
- unimplemented!("unimplemented non_block")
- }
- }
- fn send_msg(
- &self,
- _msg: &crate::net::syscall::MsgHdr,
- _flags: socket::PMSG,
- ) -> Result<usize, SystemError> {
- todo!()
- }
- fn send_to(
- &self,
- _buffer: &[u8],
- _flags: socket::PMSG,
- _address: Endpoint,
- ) -> Result<usize, SystemError> {
- Err(SystemError::ENOSYS)
- }
- fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
- self.send(buffer, socket::PMSG::empty())
- }
- fn send_buffer_size(&self) -> usize {
- log::warn!("using default buffer size");
- StreamSocket::DEFAULT_BUF_SIZE
- }
- fn recv_buffer_size(&self) -> usize {
- log::warn!("using default buffer size");
- StreamSocket::DEFAULT_BUF_SIZE
- }
- }
|