Browse Source

stdio, string, platform: fix a bug in printf() involving chars

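Previously, printf() converted each byte of the format string (and
each %c argument) into Rust's char type and printed it through the
fmt machinery (write_char).  Any byte that is not a valid single-byte
UTF-8 character (i.e. any value above 0x7F) was therefore re-encoded
as a multi-byte UTF-8 sequence instead of being emitted unchanged, so
it failed to print correctly.  This commit adds a write_u8() method to
the writers so that such bytes are written verbatim.  I found this
while trying to implement qsort(), because the output of my test
program was mysteriously incorrect despite it working when I used glibc.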
Alex Lyon 7 years ago
parent
commit
0cabecd5b5
5 changed files with 93 additions and 60 deletions
  1. 65 20
      src/platform/src/lib.rs
  2. 13 10
      src/stdio/src/lib.rs
  3. 12 16
      src/stdio/src/printf.rs
  4. 1 13
      src/string/src/lib.rs
  5. 2 1
      tests/printf.c

+ 65 - 20
src/platform/src/lib.rs

@@ -58,6 +58,30 @@ pub unsafe fn cstr_from_bytes_with_nul_unchecked(bytes: &[u8]) -> *const c_char
     &*(bytes as *const [u8] as *const c_char)
 }
 
+// NOTE: defined here rather than in string because memcpy() is useful in multiple crates
+pub unsafe fn memcpy(s1: *mut c_void, s2: *const c_void, n: usize) -> *mut c_void {
+    let mut i = 0;
+    while i + 7 < n {
+        *(s1.offset(i as isize) as *mut u64) = *(s2.offset(i as isize) as *const u64);
+        i += 8;
+    }
+    while i < n {
+        *(s1 as *mut u8).offset(i as isize) = *(s2 as *const u8).offset(i as isize);
+        i += 1;
+    }
+    s1
+}
+
+pub trait Write: fmt::Write {
+    fn write_u8(&mut self, byte: u8) -> fmt::Result;
+}
+
+impl<'a, W: Write> Write for &'a mut W {
+    fn write_u8(&mut self, byte: u8) -> fmt::Result {
+        (**self).write_u8(byte)
+    }
+}
+
 pub struct FileWriter(pub c_int);
 
 impl FileWriter {
@@ -73,6 +97,13 @@ impl fmt::Write for FileWriter {
     }
 }
 
+impl Write for FileWriter {
+    fn write_u8(&mut self, byte: u8) -> fmt::Result {
+        self.write(&[byte]);
+        Ok(())
+    }
+}
+
 pub struct FileReader(pub c_int);
 
 impl FileReader {
@@ -85,21 +116,19 @@ pub struct StringWriter(pub *mut u8, pub usize);
 
 impl StringWriter {
     pub unsafe fn write(&mut self, buf: &[u8]) {
-        for &b in buf.iter() {
-            if self.1 == 0 {
-                break;
-            } else if self.1 == 1 {
-                *self.0 = b'\0';
-            } else {
-                *self.0 = b;
-            }
-
-            self.0 = self.0.offset(1);
-            self.1 -= 1;
-
-            if self.1 > 0 {
-                *self.0 = b'\0';
-            }
+        if self.1 > 0 {
+            let copy_size = buf.len().min(self.1 - 1);
+            memcpy(
+                self.0 as *mut c_void,
+                buf.as_ptr() as *const c_void,
+                copy_size,
+            );
+            *self.0.offset(copy_size as isize) = b'\0';
+
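+            // XXX: this is meant to mimic the behavior from before, but it seems incorrect:
+            //      the pointer is advanced past the NUL, so the next write will land after
+            //      the terminator (the removed loop left the pointer on the NUL it had written)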
+            self.0 = self.0.offset(copy_size as isize + 1);
+            self.1 -= copy_size + 1;
         }
     }
 }
@@ -111,15 +140,24 @@ impl fmt::Write for StringWriter {
     }
 }
 
+impl Write for StringWriter {
+    fn write_u8(&mut self, byte: u8) -> fmt::Result {
+        unsafe { self.write(&[byte]) };
+        Ok(())
+    }
+}
+
 pub struct UnsafeStringWriter(pub *mut u8);
 
 impl UnsafeStringWriter {
     pub unsafe fn write(&mut self, buf: &[u8]) {
-        for &b in buf.iter() {
-            *self.0 = b;
-            self.0 = self.0.offset(1);
-            *self.0 = b'\0';
-        }
+        memcpy(
+            self.0 as *mut c_void,
+            buf.as_ptr() as *const c_void,
+            buf.len(),
+        );
+        *self.0.offset(buf.len() as isize) = b'\0';
+        self.0 = self.0.offset(buf.len() as isize);
     }
 }
 
@@ -129,3 +167,10 @@ impl fmt::Write for UnsafeStringWriter {
         Ok(())
     }
 }
+
+impl Write for UnsafeStringWriter {
+    fn write_u8(&mut self, byte: u8) -> fmt::Result {
+        unsafe { self.write(&[byte]) };
+        Ok(())
+    }
+}

+ 13 - 10
src/stdio/src/lib.rs

@@ -11,11 +11,12 @@ extern crate va_list as vl;
 
 use core::str;
 use core::ptr;
-use core::fmt::{Error, Result, Write};
+use core::fmt::{self, Error, Result};
+use core::fmt::Write as WriteFmt;
 use core::sync::atomic::{AtomicBool, Ordering};
 
 use platform::types::*;
-use platform::{c_str, errno};
+use platform::{c_str, errno, Write};
 use errno::STR_ERROR;
 use vl::VaList as va_list;
 
@@ -164,7 +165,7 @@ impl FILE {
         unsafe { platform::lseek(self.fd, off, whence) }
     }
 }
-impl Write for FILE {
+impl fmt::Write for FILE {
     fn write_str(&mut self, s: &str) -> Result {
         let s = s.as_bytes();
         if self.write(s) != s.len() {
@@ -174,6 +175,15 @@ impl Write for FILE {
         }
     }
 }
+impl Write for FILE {
+    fn write_u8(&mut self, byte: u8) -> Result {
+        if self.write(&[byte]) != 1 {
+            Err(Error)
+        } else {
+            Ok(())
+        }
+    }
+}
 
 /// Clears EOF and ERR indicators on a stream
 #[no_mangle]
@@ -872,10 +882,3 @@ pub unsafe extern "C" fn vsnprintf(
 pub unsafe extern "C" fn vsprintf(s: *mut c_char, format: *const c_char, ap: va_list) -> c_int {
     printf::printf(&mut platform::UnsafeStringWriter(s as *mut u8), format, ap)
 }
-
-/*
-#[no_mangle]
-pub extern "C" fn func(args) -> c_int {
-    unimplemented!();
-}
-*/

+ 12 - 16
src/stdio/src/printf.rs

@@ -1,19 +1,18 @@
 use core::{fmt, slice, str};
 
+use platform::{self, Write};
 use platform::types::*;
 use vl::VaList;
 
-pub unsafe fn printf<W: fmt::Write>(mut w: W, format: *const c_char, mut ap: VaList) -> c_int {
-    extern "C" {
-        fn strlen(s: *const c_char) -> size_t;
-    }
-
-    let format = slice::from_raw_parts(format as *const u8, strlen(format));
+pub unsafe fn printf<W: Write>(mut w: W, format: *const c_char, mut ap: VaList) -> c_int {
+    let format = unsafe { slice::from_raw_parts(format as *const u8, usize::max_value()) };
 
-    let mut i = 0;
     let mut found_percent = false;
-    while i < format.len() {
-        let b = format[i];
+    for &b in format.iter() {
+        // check for NUL
+        if b == 0 {
+            break;
+        }
 
         if found_percent {
             match b as char {
@@ -24,7 +23,7 @@ pub unsafe fn printf<W: fmt::Write>(mut w: W, format: *const c_char, mut ap: VaL
                 'c' => {
                     let a = ap.get::<u32>();
 
-                    w.write_char(a as u8 as char);
+                    w.write_u8(a as u8);
 
                     found_percent = false;
                 }
@@ -57,9 +56,8 @@ pub unsafe fn printf<W: fmt::Write>(mut w: W, format: *const c_char, mut ap: VaL
                 's' => {
                     let a = ap.get::<usize>();
 
-                    w.write_str(str::from_utf8_unchecked(slice::from_raw_parts(
-                        a as *const u8,
-                        strlen(a as *const c_char),
+                    w.write_str(str::from_utf8_unchecked(platform::c_str(
+                        a as *const c_char,
                     )));
 
                     found_percent = false;
@@ -102,10 +100,8 @@ pub unsafe fn printf<W: fmt::Write>(mut w: W, format: *const c_char, mut ap: VaL
         } else if b == b'%' {
             found_percent = true;
         } else {
-            w.write_char(b as char);
+            w.write_u8(b);
         }
-
-        i += 1;
     }
 
     0

+ 1 - 13
src/string/src/lib.rs

@@ -58,12 +58,7 @@ pub unsafe extern "C" fn memcmp(s1: *const c_void, s2: *const c_void, n: usize)
 
 #[no_mangle]
 pub unsafe extern "C" fn memcpy(s1: *mut c_void, s2: *const c_void, n: usize) -> *mut c_void {
-    let mut i = 0;
-    while i < n {
-        *(s1 as *mut u8).offset(i as isize) = *(s2 as *const u8).offset(i as isize);
-        i += 1;
-    }
-    s1
+    platform::memcpy(s1, s2, n)
 }
 
 #[no_mangle]
@@ -385,10 +380,3 @@ pub extern "C" fn strtok_r(
 pub extern "C" fn strxfrm(s1: *mut c_char, s2: *const c_char, n: usize) -> size_t {
     unimplemented!();
 }
-
-/*
-#[no_mangle]
-pub extern "C" fn func(args) -> c_int {
-    unimplemented!();
-}
-*/

+ 2 - 1
tests/printf.c

@@ -2,9 +2,10 @@
 
 int main(int argc, char ** argv) {
     printf(
-        "percent: %%\nstring: %s\nchar: %c\nint: %d\nuint: %u\nhex: %x\nHEX: %X\nstring: %s\n",
+        "percent: %%\nstring: %s\nchar: %c\nchar: %c\nint: %d\nuint: %u\nhex: %x\nHEX: %X\nstring: %s\n",
         "String",
         'c',
+        254,
         -16,
         32,
         0xbeef,