فهرست منبع

feat: 实现ppoll系统调用并优化poll相关功能 (#1127)

- 新增ppoll系统调用,支持信号屏蔽和精确超时控制
- 优化poll系统调用,修复超时处理逻辑
- 新增ProcessControlBlock::has_pending_not_masked_signal方法,优化信号检测
- 添加Instant::saturating_sub方法,改进时间计算
- 新增rt_sigpending系统调用,支持获取待处理信号
- 添加ppoll测试程序,验证ppoll功能

Signed-off-by: longjin <longjin@DragonOS.org>
LoGin 1 روز پیش
والد
کامیت
2d06264d79

+ 142 - 14
kernel/src/filesystem/poll.rs

@@ -1,12 +1,18 @@
 use core::ffi::c_int;
 
 use crate::{
-    ipc::signal::{RestartBlock, RestartBlockData, RestartFn},
+    arch::ipc::signal::SigSet,
+    ipc::signal::{
+        restore_saved_sigmask_unless, set_user_sigmask, RestartBlock, RestartBlockData, RestartFn,
+    },
     mm::VirtAddr,
     net::event_poll::{EPollCtlOption, EPollEvent, EPollEventType, EventPoll},
     process::ProcessManager,
-    syscall::{user_access::UserBufferWriter, Syscall},
-    time::{Duration, Instant},
+    syscall::{
+        user_access::{UserBufferReader, UserBufferWriter},
+        Syscall,
+    },
+    time::{Duration, Instant, PosixTimeSpec},
 };
 
 use super::vfs::file::{File, FileMode};
@@ -32,11 +38,15 @@ impl<'a> PollAdapter<'a> {
     }
 
     fn add_pollfds(&self) -> Result<(), SystemError> {
-        for pollfd in self.poll_fds.iter() {
+        for (i, pollfd) in self.poll_fds.iter().enumerate() {
+            if pollfd.fd < 0 {
+                continue;
+            }
             let mut epoll_event = EPollEvent::default();
             let poll_flags = PollFlags::from_bits_truncate(pollfd.events);
             let ep_events: EPollEventType = poll_flags.into();
             epoll_event.set_events(ep_events.bits());
+            epoll_event.set_data(i as u64);
 
             EventPoll::epoll_ctl_with_epfile(
                 self.ep_file.clone(),
@@ -64,8 +74,13 @@ impl<'a> PollAdapter<'a> {
             remain_timeout,
         )?;
 
-        for (i, event) in epoll_events.iter().enumerate() {
-            self.poll_fds[i].revents = (event.events() & 0xffff) as u16;
+        for event in epoll_events.iter() {
+            let index = event.data() as usize;
+            if index >= self.poll_fds.len() {
+                log::warn!("poll_all_fds: Invalid index in epoll event: {}", index);
+                continue;
+            }
+            self.poll_fds[index].revents = (event.events() & 0xffff) as u16;
         }
 
         Ok(events)
@@ -74,13 +89,14 @@ impl<'a> PollAdapter<'a> {
 
 impl Syscall {
     /// https://code.dragonos.org.cn/xref/linux-6.6.21/fs/select.c#1068
+    #[inline(never)]
     pub fn poll(pollfd_ptr: usize, nfds: u32, timeout_ms: i32) -> Result<usize, SystemError> {
         let pollfd_ptr = VirtAddr::new(pollfd_ptr);
         let len = nfds as usize * core::mem::size_of::<PollFd>();
 
         let mut timeout: Option<Instant> = None;
         if timeout_ms >= 0 {
-            timeout = poll_select_set_timeout(timeout_ms);
+            timeout = poll_select_set_timeout(timeout_ms as u64);
         }
         let mut poll_fds_writer = UserBufferWriter::new(pollfd_ptr.as_ptr::<PollFd>(), len, true)?;
         let mut r = do_sys_poll(poll_fds_writer.buffer(0)?, timeout);
@@ -92,15 +108,58 @@ impl Syscall {
 
         return r;
     }
-}
 
-/// 计算超时的时刻
-fn poll_select_set_timeout(timeout_ms: i32) -> Option<Instant> {
-    if timeout_ms == 0 {
-        return None;
-    }
+    /// 参考 https://code.dragonos.org.cn/xref/linux-6.1.9/fs/select.c#1101
+    #[inline(never)]
+    pub fn ppoll(
+        pollfd_ptr: usize,
+        nfds: u32,
+        timespec_ptr: usize,
+        sigmask_ptr: usize,
+    ) -> Result<usize, SystemError> {
+        let mut timeout_ts: Option<Instant> = None;
+        let mut sigmask: Option<SigSet> = None;
+        let pollfd_ptr = VirtAddr::new(pollfd_ptr);
+        let pollfds_len = nfds as usize * core::mem::size_of::<PollFd>();
+        let mut poll_fds_writer =
+            UserBufferWriter::new(pollfd_ptr.as_ptr::<PollFd>(), pollfds_len, true)?;
+        let poll_fds = poll_fds_writer.buffer(0)?;
+        if sigmask_ptr != 0 {
+            let sigmask_reader =
+                UserBufferReader::new(sigmask_ptr as *const SigSet, size_of::<SigSet>(), true)?;
+            sigmask = Some(*sigmask_reader.read_one_from_user(0)?);
+        }
+
+        if timespec_ptr != 0 {
+            let tsreader = UserBufferReader::new(
+                timespec_ptr as *const PosixTimeSpec,
+                size_of::<PosixTimeSpec>(),
+                true,
+            )?;
+            let ts: PosixTimeSpec = *tsreader.read_one_from_user(0)?;
+            let timeout_ms = ts.tv_sec * 1000 + ts.tv_nsec / 1_000_000;
+
+            if timeout_ms >= 0 {
+                timeout_ts =
+                    Some(poll_select_set_timeout(timeout_ms as u64).ok_or(SystemError::EINVAL)?);
+            }
+        }
+
+        if let Some(mut sigmask) = sigmask {
+            set_user_sigmask(&mut sigmask);
+        }
+        // log::debug!(
+        //     "ppoll: poll_fds: {:?}, nfds: {}, timeout_ts: {:?},sigmask: {:?}",
+        //     poll_fds,
+        //     nfds,
+        //     timeout_ts,
+        //     sigmask
+        // );
+
+        let r: Result<usize, SystemError> = do_sys_poll(poll_fds, timeout_ts);
 
-    Some(Instant::now() + Duration::from_millis(timeout_ms as u64))
+        return poll_select_finish(timeout_ts, timespec_ptr, PollTimeType::TimeSpec, r);
+    }
 }
 
 fn do_sys_poll(poll_fds: &mut [PollFd], timeout: Option<Instant>) -> Result<usize, SystemError> {
@@ -115,6 +174,75 @@ fn do_sys_poll(poll_fds: &mut [PollFd], timeout: Option<Instant>) -> Result<usiz
     Ok(nevents)
 }
 
+/// 计算超时的时刻
+fn poll_select_set_timeout(timeout_ms: u64) -> Option<Instant> {
+    if timeout_ms == 0 {
+        return None;
+    }
+
+    Some(Instant::now() + Duration::from_millis(timeout_ms))
+}
+
+/// 参考 https://code.dragonos.org.cn/xref/linux-6.1.9/fs/select.c#298
+fn poll_select_finish(
+    end_time: Option<Instant>,
+    user_time_ptr: usize,
+    poll_time_type: PollTimeType,
+    mut result: Result<usize, SystemError>,
+) -> Result<usize, SystemError> {
+    restore_saved_sigmask_unless(result == Err(SystemError::ERESTARTNOHAND));
+
+    if user_time_ptr == 0 {
+        return result;
+    }
+
+    // todo: 处理sticky timeouts
+
+    if end_time.is_none() {
+        return result;
+    }
+
+    let end_time = end_time.unwrap();
+
+    // no update for zero timeout
+    if end_time.total_millis() <= 0 {
+        return result;
+    }
+
+    let ts = Instant::now();
+    let duration = end_time.saturating_sub(ts);
+    let rts: PosixTimeSpec = duration.into();
+
+    match poll_time_type {
+        PollTimeType::TimeSpec => {
+            let mut tswriter = UserBufferWriter::new(
+                user_time_ptr as *mut PosixTimeSpec,
+                size_of::<PosixTimeSpec>(),
+                true,
+            )?;
+            if tswriter.copy_one_to_user(&rts, 0).is_err() {
+                return result;
+            }
+        }
+        _ => todo!(),
+    }
+
+    if result == Err(SystemError::ERESTARTNOHAND) {
+        result = result.map_err(|_| SystemError::EINTR);
+    }
+
+    return result;
+}
+
+#[allow(unused)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum PollTimeType {
+    TimeVal,
+    OldTimeVal,
+    TimeSpec,
+    OldTimeSpec,
+}
+
 bitflags! {
     pub struct PollFlags: u16 {
         const POLLIN = 0x0001;

+ 10 - 0
kernel/src/ipc/signal.rs

@@ -419,6 +419,16 @@ pub fn restore_saved_sigmask() {
     }
 }
 
+pub fn restore_saved_sigmask_unless(interrupted: bool) {
+    if interrupted {
+        if !ProcessManager::current_pcb().has_pending_signal_fast() {
+            log::warn!("restore_saved_sigmask_unless: interrupted, but has NO pending signal");
+        }
+    } else {
+        restore_saved_sigmask();
+    }
+}
+
 /// 刷新指定进程的sighand的sigaction,将满足条件的sigaction恢复为默认状态。
 /// 除非某个信号被设置为忽略且 `force_default` 为 `false`,否则都不会将其恢复。
 ///

+ 5 - 1
kernel/src/ipc/signal_types.rs

@@ -78,7 +78,11 @@ impl SignalStruct {
         let mut r = Self {
             inner: Box::<InnerSignalStruct>::default(),
         };
-        let sig_ign = Sigaction::default();
+        let mut sig_ign = Sigaction::default();
+        // 收到忽略的信号,重启系统调用
+        // todo: 看看linux哪些
+        sig_ign.flags_mut().insert(SigFlags::SA_RESTART);
+
         r.inner.handlers[Signal::SIGCHLD as usize - 1] = sig_ign;
         r.inner.handlers[Signal::SIGURG as usize - 1] = sig_ign;
         r.inner.handlers[Signal::SIGWINCH as usize - 1] = sig_ign;

+ 24 - 0
kernel/src/ipc/syscall.rs

@@ -562,4 +562,28 @@ impl Syscall {
             return Ok(0);
         }
     }
+
+    #[inline(never)]
+    pub fn rt_sigpending(user_sigset_ptr: usize, sigsetsize: usize) -> Result<usize, SystemError> {
+        if sigsetsize != size_of::<SigSet>() {
+            return Err(SystemError::EINVAL);
+        }
+
+        let mut user_buffer_writer =
+            UserBufferWriter::new(user_sigset_ptr as *mut SigSet, size_of::<SigSet>(), true)?;
+
+        let pcb = ProcessManager::current_pcb();
+        let siginfo_guard = pcb.sig_info_irqsave();
+        let pending_set = siginfo_guard.sig_pending().signal();
+        let shared_pending_set = siginfo_guard.sig_shared_pending().signal();
+        let blocked_set = *siginfo_guard.sig_blocked();
+        drop(siginfo_guard);
+
+        let mut result = pending_set.union(shared_pending_set);
+        result = result.difference(blocked_set);
+
+        user_buffer_writer.copy_one_to_user(&result, 0)?;
+
+        Ok(0)
+    }
 }

+ 12 - 2
kernel/src/net/event_poll/mod.rs

@@ -531,8 +531,10 @@ impl EventPoll {
                     continue;
                 }
 
-                // 如果有未处理的信号则返回错误
-                if current_pcb.has_pending_signal_fast() {
+                // 如果有未处理且未被屏蔽的信号则返回错误
+                if current_pcb.has_pending_signal_fast()
+                    && current_pcb.has_pending_not_masked_signal()
+                {
                     return Err(SystemError::ERESTARTSYS);
                 }
 
@@ -858,6 +860,14 @@ impl EPollEvent {
     pub fn events(&self) -> u32 {
         self.events
     }
+
+    pub fn set_data(&mut self, data: u64) {
+        self.data = data;
+    }
+
+    pub fn data(&self) -> u64 {
+        self.data
+    }
 }
 
 /// ## epoll_ctl函数的参数

+ 18 - 0
kernel/src/process/mod.rs

@@ -1071,6 +1071,24 @@ impl ProcessControlBlock {
         self.flags.get().contains(ProcessFlags::HAS_PENDING_SIGNAL)
     }
 
+    /// 检查当前进程是否有未被阻塞的待处理信号。
+    ///
+    /// 注:该函数较慢,因此需要与 has_pending_signal_fast 一起使用。
+    pub fn has_pending_not_masked_signal(&self) -> bool {
+        let sig_info = self.sig_info_irqsave();
+        let blocked: SigSet = *sig_info.sig_blocked();
+        let mut pending: SigSet = sig_info.sig_pending().signal();
+        drop(sig_info);
+        pending.remove(blocked);
+        // log::debug!(
+        //     "pending and not masked:{:?}, masked: {:?}",
+        //     pending,
+        //     blocked
+        // );
+        let has_not_masked = !pending.is_empty();
+        return has_not_masked;
+    }
+
     pub fn sig_struct(&self) -> SpinLockGuard<SignalStruct> {
         self.sig_struct.lock_irqsave()
     }

+ 2 - 4
kernel/src/syscall/mod.rs

@@ -883,10 +883,7 @@ impl Syscall {
                 Self::poll(fds, nfds, timeout)
             }
 
-            SYS_PPOLL => {
-                log::warn!("SYS_PPOLL has not yet been implemented");
-                Ok(0)
-            }
+            SYS_PPOLL => Self::ppoll(args[0], args[1] as u32, args[2], args[3]),
 
             SYS_SETPGID => {
                 warn!("SYS_SETPGID has not yet been implemented");
@@ -1233,6 +1230,7 @@ impl Syscall {
             }
             SYS_SETRLIMIT => Ok(0),
             SYS_RESTART_SYSCALL => Self::restart_syscall(),
+            SYS_RT_SIGPENDING => Self::rt_sigpending(args[0], args[1]),
             _ => panic!("Unsupported syscall ID: {}", syscall_num),
         };
 

+ 17 - 0
kernel/src/time/mod.rs

@@ -288,6 +288,23 @@ impl Instant {
         let micros_diff = self.micros - earlier.micros;
         Some(Duration::from_micros(micros_diff as u64))
     }
+
+    /// Saturating subtraction. Computes `self - other`, returning [`Instant::ZERO`] if the result would be negative.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - The `Instant` to subtract from `self`.
+    ///
+    /// # Returns
+    ///
+    /// The duration between `self` and `other`, or [`Instant::ZERO`] if `other` is later than `self`.
+    pub fn saturating_sub(self, other: Instant) -> Duration {
+        if self.micros >= other.micros {
+            Duration::from_micros((self.micros - other.micros) as u64)
+        } else {
+            Duration::ZERO
+        }
+    }
 }
 
 impl fmt::Display for Instant {

+ 2 - 0
user/apps/test_poll/.gitignore

@@ -1 +1,3 @@
 test_poll
+test_ppoll
+*.o

+ 4 - 1
user/apps/test_poll/Makefile

@@ -8,14 +8,17 @@ BIN_NAME=test_poll
 CC=$(CROSS_COMPILE)gcc
 
 .PHONY: all
-all: main.c
+all: main.c ppoll.c
 	$(CC) -static -o $(BIN_NAME) main.c
+	$(CC) -static -o test_ppoll ppoll.c
 
 .PHONY: install clean
 install: all
 	mv $(BIN_NAME) $(DADK_CURRENT_BUILD_DIR)/$(BIN_NAME)
+	mv test_ppoll $(DADK_CURRENT_BUILD_DIR)/test_ppoll
 
 clean:
 	rm $(BIN_NAME) *.o
+	rm test_ppoll
 
 fmt:

+ 148 - 0
user/apps/test_poll/ppoll.c

@@ -0,0 +1,148 @@
+#include <errno.h>
+#define _GNU_SOURCE
+
+#include <fcntl.h>
+#include <poll.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/poll.h>
+#include <sys/signalfd.h>
+#include <sys/syscall.h>
+#include <sys/wait.h>
+#include <time.h>
+#include <unistd.h>
+
+#define RED "\x1B[31m"
+#define GREEN "\x1B[32m"
+#define RESET "\x1B[0m"
+
+// 测试用例1:基本功能测试(管道I/O)
+void test_basic_functionality() {
+  int pipefd[2];
+  struct pollfd fds[1];
+  struct timespec timeout = {5, 0}; // 5秒超时
+
+  printf("=== Test 1: Basic functionality test ===\n");
+
+  // 创建管道
+  if (pipe(pipefd) == -1) {
+    perror("pipe creation failed");
+    exit(EXIT_FAILURE);
+  }
+
+  // 设置监听读端管道
+  fds[0].fd = pipefd[0];
+  fds[0].events = POLLIN;
+
+  printf("Test scenario 1: Wait with no data (should timeout)\n");
+  int ret = ppoll(fds, 1, &timeout, NULL);
+  if (ret == 0) {
+    printf(GREEN "Test passed: Correct timeout\n" RESET);
+  } else {
+    printf(RED "Test failed: Return value %d\n" RESET, ret);
+  }
+
+  // 向管道写入数据
+  const char *msg = "test data";
+  write(pipefd[1], msg, strlen(msg));
+
+  printf(
+      "\nTest scenario 2: Should return immediately when data is available\n");
+  timeout.tv_sec = 5;
+  ret = ppoll(fds, 1, &timeout, NULL);
+  if (ret > 0 && (fds[0].revents & POLLIN)) {
+    printf(GREEN "Test passed: Data detected\n" RESET);
+  } else {
+    printf(RED "Test failed: Return value %d, revents %d\n" RESET, ret,
+           fds[0].revents);
+  }
+
+  close(pipefd[0]);
+  close(pipefd[1]);
+}
+
+// 测试用例2:信号屏蔽测试
+void test_signal_handling() {
+  printf("\n=== Test 2: Signal handling test ===\n");
+  sigset_t mask, orig_mask;
+  struct timespec timeout = {5, 0};
+  struct pollfd fds[1];
+
+  fds[0].fd = -1;
+  fds[0].events = 0;
+
+  // 设置信号屏蔽
+  sigemptyset(&mask);
+  sigaddset(&mask, SIGUSR1);
+  // 阻塞SIGUSR1,并保存原来的信号掩码
+  if (sigprocmask(SIG_BLOCK, &mask, &orig_mask)) {
+    perror("sigprocmask");
+    exit(EXIT_FAILURE);
+  }
+
+  printf("Test scenario: Signal should not interrupt when masked\n");
+  pid_t pid = fork();
+  if (pid == 0) { // 子进程
+    sleep(2);     // 等待父进程进入ppoll
+    kill(getppid(), SIGUSR1);
+    exit(0);
+  }
+
+  int ret = ppoll(fds, 1, &timeout, &mask);
+
+  if (ret == 0) {
+    printf(GREEN "Test passed: Completed full 5 second wait\n" RESET);
+  } else {
+    printf(RED "Test failed: Premature return %d\n" RESET, errno);
+  }
+
+  waitpid(pid, NULL, 0);
+
+  // 检查并消费挂起的SIGUSR1信号
+  sigset_t pending;
+  sigpending(&pending);
+  if (sigismember(&pending, SIGUSR1)) {
+    int sig;
+    sigwait(&mask, &sig); // 主动消费信号
+
+    printf("Consumed pending SIGUSR1 signal\n");
+  }
+  // 恢复原来的信号掩码
+  sigprocmask(SIG_SETMASK, &orig_mask, NULL);
+}
+
+// 测试用例3:精确超时测试
+void test_timeout_accuracy() {
+  printf("\n=== Test 3: Timeout accuracy test ===\n");
+  struct timespec start, end, timeout = {0, 500000000};
+  struct pollfd fds[1];
+  fds[0].fd = -1;
+  fds[0].events = 0;
+
+  clock_gettime(CLOCK_MONOTONIC, &start);
+  int ret = ppoll(fds, 1, &timeout, NULL);
+  clock_gettime(CLOCK_MONOTONIC, &end);
+
+  long elapsed = (end.tv_sec - start.tv_sec) * 1000000 +
+                 (end.tv_nsec - start.tv_nsec) / 1000;
+
+  printf("Expected timeout: 500ms, Actual elapsed: %.3fms\n", elapsed / 1000.0);
+  if (labs(elapsed - 500000) < 50000) { // 允许±50ms误差
+    printf(GREEN "Test passed: Timeout within acceptable range\n" RESET);
+  } else {
+    printf(RED "Test failed: Timeout deviation too large\n" RESET);
+  }
+}
+
+int main() {
+  // 设置非阻塞标准输入
+  fcntl(STDIN_FILENO, F_SETFL, O_NONBLOCK);
+
+  test_basic_functionality();
+  test_signal_handling();
+  test_timeout_accuracy();
+
+  return 0;
+}