Browse Source

Semaphore improvements

Jeremy Soller 2 years ago
parent
commit
df75b8d037
2 changed files with 53 additions and 23 deletions
  1. 19 8
      src/platform/pte.rs
  2. 34 15
      src/sync/semaphore.rs

+ 19 - 8
src/platform/pte.rs

@@ -8,7 +8,7 @@ use core::{
 };
 
 use crate::{
-    header::{sys_mman, time::timespec},
+    header::{sys_mman, time::{CLOCK_MONOTONIC, clock_gettime, timespec}},
     ld_so::{
         linker::Linker,
         tcb::{Master, Tcb},
@@ -237,7 +237,7 @@ pub unsafe extern "C" fn pte_osThreadWaitForEnd(handle: pte_osThreadHandle) -> p
 #[no_mangle]
 pub unsafe extern "C" fn pte_osThreadCancel(handle: pte_osThreadHandle) -> pte_osResult {
     //TODO: allow cancel of thread
-    println!("pte_osThreadCancel");
+    eprintln!("pte_osThreadCancel");
     PTE_OS_OK
 }
 
@@ -338,7 +338,7 @@ pub unsafe extern "C" fn pte_osSemaphorePost(
     handle: pte_osSemaphoreHandle,
     count: c_int,
 ) -> pte_osResult {
-    (*handle).post();
+    (*handle).post(count);
     PTE_OS_OK
 }
 
@@ -348,15 +348,26 @@ pub unsafe extern "C" fn pte_osSemaphorePend(
     pTimeout: *mut c_uint,
 ) -> pte_osResult {
     let timeout_opt = if !pTimeout.is_null() {
+        // Get current time
+        let mut time = timespec::default();
+        clock_gettime(CLOCK_MONOTONIC, &mut time);
+
+        // Add timeout to time
         let timeout = *pTimeout as i64;
-        let tv_sec = timeout / 1000;
-        let tv_nsec = (timeout % 1000) * 1000000;
-        Some(timespec { tv_sec, tv_nsec })
+        time.tv_sec += timeout / 1000;
+        time.tv_nsec += (timeout % 1000) * 1_000_000;
+        while time.tv_nsec >= 1_000_000_000 {
+            time.tv_sec += 1;
+            time.tv_nsec -= 1_000_000_000;
+        }
+        Some(time)
     } else {
         None
     };
-    (*handle).wait(timeout_opt.as_ref());
-    PTE_OS_OK
+    match (*handle).wait(timeout_opt.as_ref()) {
+        Ok(()) => PTE_OS_OK,
+        Err(()) => PTE_OS_TIMEOUT,
+    }
 }
 
 #[no_mangle]

+ 34 - 15
src/sync/semaphore.rs

@@ -21,27 +21,46 @@ impl Semaphore {
         }
     }
 
-    pub fn post(&self) {
-        self.lock.fetch_add(1, Ordering::Release);
+    pub fn post(&self, count: c_int) {
+        self.lock.fetch_add(count, Ordering::SeqCst);
     }
 
-    pub fn wait(&self, timeout_opt: Option<&timespec>) {
-        if let Some(timeout) = timeout_opt {
-            println!(
-                "semaphore wait tv_sec: {}, tv_nsec: {}",
-                timeout.tv_sec, timeout.tv_nsec
-            );
+    pub fn try_wait(&self) -> Result<(), ()> {
+        let mut value = self.lock.load(Ordering::SeqCst);
+        if value > 0 {
+            match self.lock.compare_exchange(
+                value,
+                value - 1,
+                Ordering::SeqCst,
+                Ordering::SeqCst
+            ) {
+                Ok(_) => Ok(()),
+                Err(_) => Err(())
+            }
+        } else {
+            Err(())
         }
+    }
+
+    pub fn wait(&self, timeout_opt: Option<&timespec>) -> Result<(), ()> {
+
         loop {
-            while self.lock.load(Ordering::Acquire) < 1 {
-                //spin_loop();
-                Sys::sched_yield();
+            match self.try_wait() {
+                Ok(()) => {
+                    return Ok(());
+                }
+                Err(()) => ()
             }
-            let tmp = self.lock.fetch_sub(1, Ordering::AcqRel);
-            if tmp >= 1 {
-                break;
+            if let Some(timeout) = timeout_opt {
+                let mut time = timespec::default();
+                clock_gettime(CLOCK_MONOTONIC, &mut time);
+                if (time.tv_sec > timeout.tv_sec) ||
+                   (time.tv_sec == timeout.tv_sec && time.tv_nsec >= timeout.tv_nsec)
+                {
+                    return Err(())
+                }
             }
-            self.lock.fetch_add(1, Ordering::Release);
+            Sys::sched_yield();
         }
     }
 }