Browse Source

Vast refactor of pwd.h, add getpwent/setpwent/endpwent

jD91mZM2 5 years ago
parent
commit
4f2a93ea90
5 changed files with 255 additions and 186 deletions
  1. 14 6
      src/header/pwd/linux.rs
  2. 202 167
      src/header/pwd/mod.rs
  3. 13 9
      src/header/pwd/redox.rs
  4. 2 1
      src/lib.rs
  5. 24 3
      tests/pwd.c

+ 14 - 6
src/header/pwd/linux.rs

@@ -1,7 +1,15 @@
-pub fn split(line: &[u8]) -> [&[u8]; 7] {
-    let mut parts: [&[u8]; 7] = [&[]; 7];
-    for (i, part) in line.splitn(7, |b| *b == b':').enumerate() {
-        parts[i] = part;
-    }
-    parts
+use super::{parsed, passwd};
+use crate::platform::types::*;
+
+pub fn split(line: &mut [u8]) -> Option<passwd> {
+    let mut parts = line.split_mut(|&c| c == b'\0');
+    Some(passwd {
+        pw_name: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_passwd: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_uid: parsed(parts.next())?,
+        pw_gid: parsed(parts.next())?,
+        pw_gecos: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_dir: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_shell: parts.next()?.as_mut_ptr() as *mut c_char,
+    })
 }

+ 202 - 167
src/header/pwd/mod.rs

@@ -1,11 +1,23 @@
 //! pwd implementation for relibc
 
-use core::ptr;
+use alloc::{
+    boxed::Box,
+    vec::Vec,
+};
+use core::{
+    ops::{Deref, DerefMut},
+    pin::Pin,
+    ptr,
+};
 
 use crate::{
     fs::File,
-    header::{errno, fcntl},
-    io::{BufRead, BufReader},
+    header::{
+        errno,
+        fcntl,
+        string::strcmp,
+    },
+    io::{prelude::*, BufReader, SeekFrom},
     platform::{self, types::*},
 };
 
@@ -20,6 +32,7 @@ use self::linux as sys;
 use self::redox as sys;
 
 #[repr(C)]
+#[derive(Debug)]
 pub struct passwd {
     pw_name: *mut c_char,
     pw_passwd: *mut c_char,
@@ -30,7 +43,7 @@ pub struct passwd {
     pw_shell: *mut c_char,
 }
 
-static mut PASSWD_BUF: *mut c_char = ptr::null_mut();
+static mut PASSWD_BUF: Option<MaybeAllocated> = None;
 static mut PASSWD: passwd = passwd {
     pw_name: ptr::null_mut(),
     pw_passwd: ptr::null_mut(),
@@ -41,110 +54,155 @@ static mut PASSWD: passwd = passwd {
     pw_shell: ptr::null_mut(),
 };
 
-enum OptionPasswd {
-    Error,
-    NotFound,
-    Found(*mut c_char),
+#[derive(Clone, Copy, Debug)]
+struct DestBuffer {
+    ptr: *mut u8,
+    len: usize,
 }
 
-fn pwd_lookup<F>(
-    out: *mut passwd,
-    alloc: Option<(*mut c_char, size_t)>,
-    mut callback: F,
-) -> OptionPasswd
-where
-    // TODO F: FnMut(impl Iterator<Item = &[u8]>) -> bool
-    F: FnMut(&[&[u8]]) -> bool,
-{
-    let file = match File::open(c_str!("/etc/passwd"), fcntl::O_RDONLY) {
-        Ok(file) => file,
-        Err(_) => return OptionPasswd::Error,
-    };
-
-    let file = BufReader::new(file);
+#[derive(Debug)]
+enum MaybeAllocated {
+    Owned(Pin<Box<[u8]>>),
+    Borrowed(DestBuffer),
+}
+impl Deref for MaybeAllocated {
+    type Target = [u8];
 
-    for line in file.split(b'\n') {
-        let line = match line {
-            Ok(line) => line,
-            Err(err) => unsafe {
-                platform::errno = errno::EIO;
-                return OptionPasswd::Error;
+    fn deref(&self) -> &Self::Target {
+        match self {
+            MaybeAllocated::Owned(boxed) => boxed,
+            MaybeAllocated::Borrowed(dst) => unsafe {
+                core::slice::from_raw_parts(dst.ptr, dst.len)
             },
-        };
-
-        // Parse into passwd
-        let parts: [&[u8]; 7] = sys::split(&line);
-
-        if !callback(&parts) {
-            continue;
         }
+    }
+}
+impl DerefMut for MaybeAllocated {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        match self {
+            MaybeAllocated::Owned(boxed) => boxed,
+            MaybeAllocated::Borrowed(dst) => unsafe {
+                core::slice::from_raw_parts_mut(dst.ptr, dst.len)
+            },
+        }
+    }
+}
 
-        let len = parts
-            .iter()
-            .enumerate()
-            .filter(|(i, _)| *i != 2 && *i != 3)
-            .map(|(_, part)| part.len() + 1)
-            .sum();
+#[derive(Debug)]
+struct OwnedPwd {
+    buffer: MaybeAllocated,
+    reference: passwd,
+}
 
-        if alloc.map(|(_, s)| len > s as usize).unwrap_or(false) {
-            unsafe {
-                platform::errno = errno::ERANGE;
-            }
-            return OptionPasswd::Error;
+impl OwnedPwd {
+    fn into_global(self) -> *mut passwd {
+        unsafe {
+            PASSWD_BUF = Some(self.buffer);
+            PASSWD = self.reference;
+            &mut PASSWD
         }
+    }
+}
 
-        let alloc = match alloc {
-            Some((alloc, _)) => alloc,
-            None => unsafe { platform::alloc(len) as *mut c_char },
-        };
-        // _ prefix so it won't complain about the trailing
-        // _off += <thing>
-        // in the macro that is never read
-        let mut _off = 0;
+#[derive(Clone, Copy, Debug)]
+enum Cause {
+    Eof,
+    Other,
+}
 
-        let mut parts = parts.iter();
+static mut READER: Option<BufReader<File>> = None;
 
-        macro_rules! copy_into {
-            ($entry:expr) => {
-                debug_assert!(_off as usize <= len);
+fn parsed<I, O>(buf: Option<I>) -> Option<O>
+where
+    I: core::borrow::Borrow<[u8]>,
+    O: core::str::FromStr,
+{
+    let buf = buf?;
+    let string = core::str::from_utf8(buf.borrow()).ok()?;
+    string.parse().ok()
+}
 
-                let src = parts.next().unwrap_or(&(&[] as &[u8])); // this is madness
-                let dst = unsafe { alloc.offset(_off) };
+fn getpwent_r(reader: &mut BufReader<File>, destination: Option<DestBuffer>) -> Result<OwnedPwd, Cause> {
+    let mut buf = Vec::new();
+    if reader.read_until(b'\n', &mut buf).map_err(|_| Cause::Other)? == 0 {
+        return Err(Cause::Eof);
+    }
 
-                for (i, c) in src.iter().enumerate() {
-                    unsafe {
-                        *dst.add(i) = *c as c_char;
-                    }
-                }
-                unsafe {
-                    *dst.add(src.len()) = 0;
+    // Replace all occurences of ':' with terminating NUL byte
+    let mut start = 0;
+    while let Some(i) = memchr::memchr(b':', &buf[start..]) {
+        buf[start + i] = 0;
+        start += i + 1;
+    }
 
-                    $entry = dst;
-                }
-                _off += src.len() as isize + 1;
-            };
-            ($entry:expr,parse) => {
+    // Place terminating NUL byte at the end, replace newline
+    let last = buf.last_mut();
+    if last == Some(&mut b'\n') {
+        *last.unwrap() = 0;
+    } else {
+        buf.push(0);
+    }
+
+    let mut buf = match destination {
+        None => MaybeAllocated::Owned(Box::into_pin(buf.into_boxed_slice())),
+        Some(dst) => {
+            let mut new = MaybeAllocated::Borrowed(dst);
+            if new.len() < buf.len() {
                 unsafe {
-                    $entry = parts
-                        .next()
-                        .and_then(|part| core::str::from_utf8(part).ok())
-                        .and_then(|part| part.parse().ok())
-                        .unwrap_or(0);
+                    platform::errno = errno::ERANGE;
                 }
-            };
-        }
+                return Err(Cause::Other);
+            }
+            new[..buf.len()].copy_from_slice(&buf);
+            new
+        },
+    };
+
+    // Chop up the result into a valid structure
+    let passwd = sys::split(&mut buf).ok_or(Cause::Other)?;
 
-        copy_into!((*out).pw_name);
-        copy_into!((*out).pw_passwd);
-        copy_into!((*out).pw_uid, parse);
-        copy_into!((*out).pw_gid, parse);
-        copy_into!((*out).pw_gecos);
-        copy_into!((*out).pw_dir);
-        copy_into!((*out).pw_shell);
+    Ok(OwnedPwd {
+        buffer: buf,
+        reference: passwd,
+    })
+}
 
-        return OptionPasswd::Found(alloc);
+fn pwd_lookup<F>(mut matches: F, destination: Option<DestBuffer>) -> Result<OwnedPwd, Cause>
+where
+    F: FnMut(&passwd) -> bool,
+{
+    let file = match File::open(c_str!("/etc/passwd"), fcntl::O_RDONLY) {
+        Ok(file) => file,
+        Err(_) => return Err(Cause::Other),
+    };
+
+    let mut reader = BufReader::new(file);
+
+    loop {
+        let entry = getpwent_r(&mut reader, destination)?;
+
+        if matches(&entry.reference) {
+            return Ok(entry);
+        }
+    }
+}
+
+unsafe fn mux(status: Result<OwnedPwd, Cause>, out: *mut passwd, result: *mut *mut passwd) -> c_int {
+    match status {
+        Ok(owned) => {
+            *out = owned.reference;
+            *result = out;
+            0
+        },
+        Err(Cause::Eof) => {
+            *result = ptr::null_mut();
+            0
+        },
+        Err(Cause::Other) => {
+            *result = ptr::null_mut();
+            -1
+        },
     }
-    OptionPasswd::NotFound
 }
 
 #[no_mangle]
@@ -155,31 +213,14 @@ pub unsafe extern "C" fn getpwnam_r(
     size: size_t,
     result: *mut *mut passwd,
 ) -> c_int {
-    match pwd_lookup(out, Some((buf, size)), |parts| {
-        let part = parts.get(0).unwrap_or(&(&[] as &[u8]));
-        for (i, c) in part.iter().enumerate() {
-            // /etc/passwd should not contain any NUL bytes in the middle
-            // of entries, but if this happens, it can't possibly match the
-            // search query since it's NUL terminated.
-            if *c == 0 || *name.add(i) != *c as c_char {
-                return false;
-            }
-        }
-        true
-    }) {
-        OptionPasswd::Error => {
-            *result = ptr::null_mut();
-            -1
-        }
-        OptionPasswd::NotFound => {
-            *result = ptr::null_mut();
-            0
-        }
-        OptionPasswd::Found(_) => {
-            *result = out;
-            0
-        }
-    }
+    mux(
+        pwd_lookup(|parts| strcmp(parts.pw_name, name) == 0, Some(DestBuffer {
+            ptr: buf as *mut u8,
+            len: size,
+        })),
+        out,
+        result,
+    )
 }
 
 #[no_mangle]
@@ -190,68 +231,62 @@ pub unsafe extern "C" fn getpwuid_r(
     size: size_t,
     result: *mut *mut passwd,
 ) -> c_int {
-    match pwd_lookup(out, Some((buf, size)), |parts| {
-        let part = parts
-            .get(2)
-            .and_then(|part| core::str::from_utf8(part).ok())
-            .and_then(|part| part.parse().ok());
-        part == Some(uid)
-    }) {
-        OptionPasswd::Error => {
-            *result = ptr::null_mut();
-            -1
-        }
-        OptionPasswd::NotFound => {
-            *result = ptr::null_mut();
-            0
-        }
-        OptionPasswd::Found(_) => {
-            *result = out;
-            0
-        }
-    }
+    let slice = core::slice::from_raw_parts_mut(buf as *mut u8, size);
+    mux(
+        pwd_lookup(|part| part.pw_uid == uid, Some(DestBuffer {
+            ptr: buf as *mut u8,
+            len: size,
+        })),
+        out,
+        result,
+    )
 }
 
 #[no_mangle]
 pub extern "C" fn getpwnam(name: *const c_char) -> *mut passwd {
-    match pwd_lookup(unsafe { &mut PASSWD }, None, |parts| {
-        let part = parts.get(0).unwrap_or(&(&[] as &[u8]));
-        for (i, c) in part.iter().enumerate() {
-            // /etc/passwd should not contain any NUL bytes in the middle
-            // of entries, but if this happens, it can't possibly match the
-            // search query since it's NUL terminated.
-            if *c == 0 || unsafe { *name.add(i) } != *c as c_char {
-                return false;
+    pwd_lookup(|parts| unsafe { strcmp(parts.pw_name, name) } == 0, None)
+        .map(|res| res.into_global())
+        .unwrap_or(ptr::null_mut())
+}
+
+#[no_mangle]
+pub extern "C" fn getpwuid(uid: uid_t) -> *mut passwd {
+    pwd_lookup(|parts| parts.pw_uid == uid, None)
+        .map(|res| res.into_global())
+        .unwrap_or(ptr::null_mut())
+}
+
+#[no_mangle]
+pub extern "C" fn getpwent() -> *mut passwd {
+    let reader = match unsafe { &mut READER } {
+        Some(reader) => reader,
+        None => {
+            let file = match File::open(c_str!("/etc/passwd"), fcntl::O_RDONLY) {
+                Ok(file) => file,
+                Err(_) => return ptr::null_mut(),
+            };
+            let reader = BufReader::new(file);
+            unsafe {
+                READER = Some(reader);
+                READER.as_mut().unwrap()
             }
         }
-        true
-    }) {
-        OptionPasswd::Error => ptr::null_mut(),
-        OptionPasswd::NotFound => ptr::null_mut(),
-        OptionPasswd::Found(buf) => unsafe {
-            PASSWD_BUF = buf;
-            &mut PASSWD
-        },
+    };
+    getpwent_r(reader, None)
+        .map(|res| res.into_global())
+        .unwrap_or(ptr::null_mut())
+}
+
+#[no_mangle]
+pub extern "C" fn setpwent() {
+    if let Some(reader) = unsafe { &mut READER } {
+        let _ = reader.seek(SeekFrom::Start(0));
     }
 }
 
 #[no_mangle]
-pub extern "C" fn getpwuid(uid: uid_t) -> *mut passwd {
-    match pwd_lookup(unsafe { &mut PASSWD }, None, |parts| {
-        let part = parts
-            .get(2)
-            .and_then(|part| core::str::from_utf8(part).ok())
-            .and_then(|part| part.parse().ok());
-        part == Some(uid)
-    }) {
-        OptionPasswd::Error => ptr::null_mut(),
-        OptionPasswd::NotFound => ptr::null_mut(),
-        OptionPasswd::Found(buf) => unsafe {
-            if !PASSWD_BUF.is_null() {
-                platform::free(PASSWD_BUF as *mut c_void);
-            }
-            PASSWD_BUF = buf;
-            &mut PASSWD
-        },
+pub extern "C" fn endpwent() {
+    unsafe {
+        READER = None;
     }
 }

+ 13 - 9
src/header/pwd/redox.rs

@@ -1,11 +1,15 @@
-pub fn split(line: &[u8]) -> [&[u8]; 7] {
-    let mut parts: [&[u8]; 7] = [&[]; 7];
-    let mut iter = line.split(|b| *b == b';');
+use super::{parsed, passwd};
+use crate::platform::types::*;
 
-    parts[0] = iter.next().unwrap_or(&[]);
-    // Skip passwd
-    for i in 2..7 {
-        parts[i] = iter.next().unwrap_or(&[]);
-    }
-    parts
+pub fn split(line: &mut [u8]) -> Option<passwd> {
+    let mut parts = line.split_mut(|&c| c == b'\0');
+    Some(passwd {
+        pw_name: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_passwd: "x\0".as_ptr() as *const c_char as *mut c_char,
+        pw_uid: parsed(parts.next())?,
+        pw_gid: parsed(parts.next())?,
+        pw_gecos: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_dir: parts.next()?.as_mut_ptr() as *mut c_char,
+        pw_shell: parts.next()?.as_mut_ptr() as *mut c_char,
+    })
 }

+ 2 - 1
src/lib.rs

@@ -4,6 +4,7 @@
 #![allow(unused_variables)]
 #![feature(allocator_api)]
 #![feature(asm)]
+#![feature(box_into_pin)]
 #![feature(c_variadic)]
 #![feature(const_fn)]
 #![feature(const_raw_ptr_deref)]
@@ -11,9 +12,9 @@
 #![feature(const_vec_new)]
 #![feature(core_intrinsics)]
 #![feature(global_asm)]
-#![feature(maybe_uninit_extra)]
 #![feature(lang_items)]
 #![feature(linkage)]
+#![feature(maybe_uninit_extra)]
 #![feature(stmt_expr_attributes)]
 #![feature(str_internals)]
 #![feature(thread_local)]

+ 24 - 3
tests/pwd.c

@@ -41,7 +41,7 @@ int main(void) {
     puts("--- Checking getpwuid_r ---");
     struct passwd pwd2;
     struct passwd* result;
-    char* buf = malloc(100);
+    char* buf = malloc(300);
     if (getpwuid_r(0, &pwd2, buf, 100, &result) < 0) {
         perror("getpwuid_r");
         free(buf);
@@ -56,7 +56,7 @@ int main(void) {
     }
 
     puts("--- Checking getpwnam_r ---");
-    if (getpwnam_r("nobody", &pwd2, buf, 100, &result) < 0) {
+    if (getpwnam_r("nobody", &pwd2, buf, 300, &result) < 0) {
         perror("getpwuid_r");
         free(buf);
         exit(EXIT_FAILURE);
@@ -73,7 +73,7 @@ int main(void) {
     puts("--- Checking getpwuid_r error handling ---");
     char buf2[1];
     if (getpwuid_r(0, &pwd2, buf2, 1, &result) == 0) {
-        puts("This shouldn't have succeeded, but did!");
+        puts("This shouldn't have succeeded, but it did!");
         exit(EXIT_FAILURE);
     }
     if (errno != ERANGE) {
@@ -81,4 +81,25 @@ int main(void) {
         exit(EXIT_FAILURE);
     }
     puts("Returned ERANGE because the buffer was too small 👍");
+
+    errno = 0;
+
+    struct passwd *entry = NULL;
+    for (int i = 1; entry = getpwent(); ++i) {
+        int backup = errno;
+        printf("--- getpwent #%d ---\n", i);
+        if (backup != 0) {
+            errno = backup;
+            perror("getpwent");
+            exit(EXIT_FAILURE);
+        }
+        print(entry);
+    }
+    puts("--- getpwent #1 (rewind) ---");
+    setpwent();
+    entry = getpwent();
+    perror("getpwent");
+    print(entry);
+
+    endpwent();
 }