浏览代码

Fix snprintf and make strftime use a counting writer

jD91mZM2 6 年之前
父节点
当前提交
40a7380a58

+ 29 - 6
src/platform/src/lib.rs

@@ -181,19 +181,17 @@ pub struct StringWriter(pub *mut u8, pub usize);
 
 impl StringWriter {
     pub unsafe fn write(&mut self, buf: &[u8]) {
-        if self.1 > 0 {
+        if self.1 > 1 {
             let copy_size = buf.len().min(self.1 - 1);
             memcpy(
                 self.0 as *mut c_void,
                 buf.as_ptr() as *const c_void,
                 copy_size,
             );
-            *self.0.offset(copy_size as isize) = b'\0';
+            self.1 -= copy_size;
 
-            // XXX: i believe this correctly mimics the behavior from before, but it seems
-            //      incorrect (the next write will write after the NUL)
-            self.0 = self.0.offset(copy_size as isize + 1);
-            self.1 -= copy_size + 1;
+            self.0 = self.0.offset(copy_size as isize);
+            *self.0 = 0;
         }
     }
 }
@@ -269,3 +267,28 @@ impl Read for UnsafeStringReader {
         }
     }
 }
+
+pub struct CountingWriter<T> {
+    pub inner: T,
+    pub written: usize
+}
+impl<T> CountingWriter<T> {
+    pub /* const */ fn new(writer: T) -> Self {
+        Self {
+            inner: writer,
+            written: 0
+        }
+    }
+}
+impl<T: fmt::Write> fmt::Write for CountingWriter<T> {
+    fn write_str(&mut self, s: &str) -> fmt::Result {
+        self.written += s.len();
+        self.inner.write_str(s)
+    }
+}
+impl<T: Write> Write for CountingWriter<T> {
+    fn write_u8(&mut self, byte: u8) -> fmt::Result {
+        self.written += 1;
+        self.inner.write_u8(byte)
+    }
+}

+ 1 - 1
src/stdio/src/lib.rs

@@ -18,7 +18,7 @@ extern crate va_list as vl;
 use core::fmt::Write as WriteFmt;
 use core::fmt::{self, Error, Result};
 use core::sync::atomic::{AtomicBool, Ordering};
-use core::{mem, ptr, str};
+use core::{ptr, str};
 
 use alloc::vec::Vec;
 use errno::STR_ERROR;

+ 5 - 2
src/stdio/src/printf.rs

@@ -1,10 +1,13 @@
+use core::fmt::Write as CoreWrite;
 use core::{slice, str};
 
 use platform::types::*;
 use platform::{self, Write};
 use vl::VaList;
 
-pub unsafe fn printf<W: Write>(mut w: W, format: *const c_char, mut ap: VaList) -> c_int {
+pub unsafe fn printf<W: Write>(w: W, format: *const c_char, mut ap: VaList) -> c_int {
+    let mut w = platform::CountingWriter::new(w);
+
     let format = slice::from_raw_parts(format as *const u8, usize::max_value());
 
     let mut found_percent = false;
@@ -108,5 +111,5 @@ pub unsafe fn printf<W: Write>(mut w: W, format: *const c_char, mut ap: VaList)
         }
     }
 
-    0
+    w.written as c_int
 }

+ 9 - 7
src/time/src/lib.rs

@@ -3,7 +3,6 @@
 #![no_std]
 #![feature(alloc, const_fn)]
 
-#[macro_use]
 extern crate alloc;
 extern crate errno;
 extern crate platform;
@@ -338,17 +337,20 @@ pub extern "C" fn nanosleep(rqtp: *const timespec, rmtp: *mut timespec) -> c_int
 #[no_mangle]
 pub unsafe extern "C" fn strftime(
     s: *mut c_char,
-    maxsize: usize,
+    maxsize: size_t,
     format: *const c_char,
     timeptr: *const tm,
 ) -> size_t {
-    strftime::strftime(
-        true,
-        &mut platform::UnsafeStringWriter(s as *mut u8),
-        maxsize,
+    let ret = strftime::strftime(
+        &mut platform::StringWriter(s as *mut u8, maxsize),
         format,
         timeptr,
-    )
+    );
+    if ret < maxsize {
+        return ret;
+    } else {
+        return 0;
+    }
 }
 
 // #[no_mangle]

+ 125 - 135
src/time/src/strftime.rs

@@ -1,154 +1,144 @@
 use alloc::string::String;
 use platform::types::*;
-use platform::Write;
+use platform::{self, Write};
 use tm;
 
 pub unsafe fn strftime<W: Write>(
-    toplevel: bool,
-    mut w: &mut W,
-    maxsize: usize,
-    mut format: *const c_char,
+    w: &mut W,
+    format: *const c_char,
     t: *const tm,
 ) -> size_t {
-    let mut written = 0;
-    if toplevel {
-        // Reserve nul byte
-        written += 1;
-    }
-    macro_rules! w {
-        (reserve $amount:expr) => {{
-            if written + $amount > maxsize {
-                return 0;
-            }
-            written += $amount;
-        }};
-        (byte $b:expr) => {{
-            w!(reserve 1);
-            if w.write_u8($b).is_err() {
-                return !0;
-            }
-        }};
-        (char $chr:expr) => {{
-            w!(reserve $chr.len_utf8());
-            if w.write_char($chr).is_err() {
-                return !0;
-            }
-        }};
-        (recurse $fmt:expr) => {{
-            let mut fmt = String::with_capacity($fmt.len() + 1);
-            fmt.push_str($fmt);
-            fmt.push('\0');
-
-            let count = strftime(false, w, maxsize - written, fmt.as_ptr() as *mut c_char, t);
-            if count == 0 {
-                return 0;
-            }
-            written += count;
-            assert!(written <= maxsize);
-        }};
-        ($str:expr) => {{
-            w!(reserve $str.len());
-            if w.write_str($str).is_err() {
-                return !0;
-            }
-        }};
-        ($fmt:expr, $($args:expr),+) => {{
-            // Would use write!() if I could get the length written
-            w!(&format!($fmt, $($args),+))
-        }};
-    }
-    const WDAYS: [&'static str; 7] = [
-        "Sunday",
-        "Monday",
-        "Tuesday",
-        "Wednesday",
-        "Thursday",
-        "Friday",
-        "Saturday",
-    ];
-    const MONTHS: [&'static str; 12] = [
-        "January",
-        "Febuary",
-        "March",
-        "April",
-        "May",
-        "June",
-        "July",
-        "August",
-        "September",
-        "October",
-        "November",
-        "December",
-    ];
+    pub unsafe fn inner_strftime<W: Write>(
+        mut w: &mut W,
+        mut format: *const c_char,
+        t: *const tm,
+    ) -> bool {
+        macro_rules! w {
+            (byte $b:expr) => {{
+                if w.write_u8($b).is_err() {
+                    return false;
+                }
+            }};
+            (char $chr:expr) => {{
+                if w.write_char($chr).is_err() {
+                    return false;
+                }
+            }};
+            (recurse $fmt:expr) => {{
+                let mut fmt = String::with_capacity($fmt.len() + 1);
+                fmt.push_str($fmt);
+                fmt.push('\0');
 
-    while *format != 0 {
-        if *format as u8 != b'%' {
-            w!(byte(*format as u8));
-            format = format.offset(1);
-            continue;
+                if !inner_strftime(w, fmt.as_ptr() as *mut c_char, t) {
+                    return false;
+                }
+            }};
+            ($str:expr) => {{
+                if w.write_str($str).is_err() {
+                    return false;
+                }
+            }};
+            ($fmt:expr, $($args:expr),+) => {{
+                // Would use write!() if I could get the length written
+                if write!(w, $fmt, $($args),+).is_err() {
+                    return false;
+                }
+            }};
         }
+        const WDAYS: [&'static str; 7] = [
+            "Sunday",
+            "Monday",
+            "Tuesday",
+            "Wednesday",
+            "Thursday",
+            "Friday",
+            "Saturday",
+        ];
+        const MONTHS: [&'static str; 12] = [
+            "January",
+            "Febuary",
+            "March",
+            "April",
+            "May",
+            "June",
+            "July",
+            "August",
+            "September",
+            "October",
+            "November",
+            "December",
+        ];
 
-        format = format.offset(1);
+        while *format != 0 {
+            if *format as u8 != b'%' {
+                w!(byte *format as u8);
+                format = format.offset(1);
+                continue;
+            }
 
-        if *format as u8 == b'E' || *format as u8 == b'O' {
-            // Ignore because these do nothing without locale
             format = format.offset(1);
-        }
 
-        match *format as u8 {
-            b'%' => w!(byte b'%'),
-            b'n' => w!(byte b'\n'),
-            b't' => w!(byte b'\t'),
-            b'a' => w!(&WDAYS[(*t).tm_wday as usize][..3]),
-            b'A' => w!(WDAYS[(*t).tm_wday as usize]),
-            b'b' | b'h' => w!(&MONTHS[(*t).tm_mon as usize][..3]),
-            b'B' => w!(MONTHS[(*t).tm_mon as usize]),
-            b'C' => {
-                let mut year = (*t).tm_year / 100;
-                // Round up
-                if (*t).tm_year % 100 != 0 {
-                    year += 1;
+            if *format as u8 == b'E' || *format as u8 == b'O' {
+                // Ignore because these do nothing without locale
+                format = format.offset(1);
+            }
+
+            match *format as u8 {
+                b'%' => w!(byte b'%'),
+                b'n' => w!(byte b'\n'),
+                b't' => w!(byte b'\t'),
+                b'a' => w!(&WDAYS[(*t).tm_wday as usize][..3]),
+                b'A' => w!(WDAYS[(*t).tm_wday as usize]),
+                b'b' | b'h' => w!(&MONTHS[(*t).tm_mon as usize][..3]),
+                b'B' => w!(MONTHS[(*t).tm_mon as usize]),
+                b'C' => {
+                    let mut year = (*t).tm_year / 100;
+                    // Round up
+                    if (*t).tm_year % 100 != 0 {
+                        year += 1;
+                    }
+                    w!("{:02}", year + 19);
                 }
-                w!("{:02}", year + 19);
+                b'd' => w!("{:02}", (*t).tm_mday),
+                b'D' => w!(recurse "%m/%d/%y"),
+                b'e' => w!("{:2}", (*t).tm_mday),
+                b'F' => w!(recurse "%Y-%m-%d"),
+                b'H' => w!("{:02}", (*t).tm_hour),
+                b'I' => w!("{:02}", ((*t).tm_hour + 12 - 1) % 12 + 1),
+                b'j' => w!("{:03}", (*t).tm_yday),
+                b'k' => w!("{:2}", (*t).tm_hour),
+                b'l' => w!("{:2}", ((*t).tm_hour + 12 - 1) % 12 + 1),
+                b'm' => w!("{:02}", (*t).tm_mon + 1),
+                b'M' => w!("{:02}", (*t).tm_min),
+                b'p' => w!(if (*t).tm_hour < 12 { "AM" } else { "PM" }),
+                b'P' => w!(if (*t).tm_hour < 12 { "am" } else { "pm" }),
+                b'r' => w!(recurse "%I:%M:%S %p"),
+                b'R' => w!(recurse "%H:%M"),
+                // Nothing is modified in mktime, but the C standard of course requires a mutable pointer ._.
+                b's' => w!("{}", ::mktime(t as *mut tm)),
+                b'S' => w!("{:02}", (*t).tm_sec),
+                b'T' => w!(recurse "%H:%M:%S"),
+                b'u' => w!("{}", ((*t).tm_wday + 7 - 1) % 7 + 1),
+                b'U' => w!("{}", ((*t).tm_yday + 7 - (*t).tm_wday) / 7),
+                b'w' => w!("{}", (*t).tm_wday),
+                b'W' => w!("{}", ((*t).tm_yday + 7 - ((*t).tm_wday + 6) % 7) / 7),
+                b'y' => w!("{:02}", (*t).tm_year % 100),
+                b'Y' => w!("{}", (*t).tm_year + 1900),
+                b'z' => w!("+0000"), // TODO
+                b'Z' => w!("UTC"),   // TODO
+                b'+' => w!(recurse "%a %b %d %T %Z %Y"),
+                _ => return false,
             }
-            b'd' => w!("{:02}", (*t).tm_mday),
-            b'D' => w!(recurse "%m/%d/%y"),
-            b'e' => w!("{:2}", (*t).tm_mday),
-            b'F' => w!(recurse "%Y-%m-%d"),
-            b'H' => w!("{:02}", (*t).tm_hour),
-            b'I' => w!("{:02}", ((*t).tm_hour + 12 - 1) % 12 + 1),
-            b'j' => w!("{:03}", (*t).tm_yday),
-            b'k' => w!("{:2}", (*t).tm_hour),
-            b'l' => w!("{:2}", ((*t).tm_hour + 12 - 1) % 12 + 1),
-            b'm' => w!("{:02}", (*t).tm_mon + 1),
-            b'M' => w!("{:02}", (*t).tm_min),
-            b'p' => w!(if (*t).tm_hour < 12 { "AM" } else { "PM" }),
-            b'P' => w!(if (*t).tm_hour < 12 { "am" } else { "pm" }),
-            b'r' => w!(recurse "%I:%M:%S %p"),
-            b'R' => w!(recurse "%H:%M"),
-            // Nothing is modified in mktime, but the C standard of course requires a mutable pointer ._.
-            b's' => w!("{}", ::mktime(t as *mut tm)),
-            b'S' => w!("{:02}", (*t).tm_sec),
-            b'T' => w!(recurse "%H:%M:%S"),
-            b'u' => w!("{}", ((*t).tm_wday + 7 - 1) % 7 + 1),
-            b'U' => w!("{}", ((*t).tm_yday + 7 - (*t).tm_wday) / 7),
-            b'w' => w!("{}", (*t).tm_wday),
-            b'W' => w!("{}", ((*t).tm_yday + 7 - ((*t).tm_wday + 6) % 7) / 7),
-            b'y' => w!("{:02}", (*t).tm_year % 100),
-            b'Y' => w!("{}", (*t).tm_year + 1900),
-            b'z' => w!("+0000"), // TODO
-            b'Z' => w!("UTC"),   // TODO
-            b'+' => w!(recurse "%a %b %d %T %Z %Y"),
-            _ => return 0,
-        }
 
-        format = format.offset(1);
-    }
-    if toplevel {
-        // nul byte is already counted in written
-        if w.write_u8(0).is_err() {
-            return !0;
+            format = format.offset(1);
         }
+        true
     }
-    written
+
+    let mut w = platform::CountingWriter::new(w);
+    if !inner_strftime(&mut w, format, t) {
+        return 0;
+    }
+
+    w.written
 }

+ 1 - 0
tests/expected/stdio/printf.stdout

@@ -7,3 +7,4 @@ uint: 32
 hex: beef
 HEX: C0FFEE
 string: end
+len of previous write: 94

+ 1 - 0
tests/expected/stdio/sprintf.stdout

@@ -0,0 +1 @@
+This stri

+ 9 - 8
tests/expected/time/strftime.stdout

@@ -1,8 +1,9 @@
-21: Tue Tuesday Jul July
-17: The 21st century
-12: 06:25:42 AM
-12: 03:00:00 PM
-6: 15:00
-16: 15 1531839600 2
-7: 197 28
-29: Tue Jul 17 15:00:00 UTC 2018
+20: Tue Tuesday Jul July
+16: The 21st century
+11: 06:25:42 AM
+11: 03:00:00 PM
+5: 15:00
+15: 15 1531839600 2
+6: 197 28
+28: Tue Jul 17 15:00:00 UTC 2018
+0: Tue Aug 07 19:17:11 UTC 2018Tue Aug 07 19:17:11 U

+ 2 - 1
tests/stdio/printf.c

@@ -1,7 +1,7 @@
 #include <stdio.h>
 
 int main(int argc, char ** argv) {
-    printf(
+    int len = printf(
         "percent: %%\nstring: %s\nchar: %c\nchar: %c\nint: %d\nuint: %u\nhex: %x\nHEX: %X\nstring: %s\n",
         "String",
         'c',
@@ -12,5 +12,6 @@ int main(int argc, char ** argv) {
         0xC0FFEE,
         "end"
     );
+    printf("len of previous write: %d\n", len);
     return 0;
 }

+ 12 - 11
tests/stdio/sprintf.c

@@ -1,30 +1,31 @@
-
 #include <stdio.h>
+#include <string.h>
 
 int main(int argc, char ** argv) {
     char buffer[72];
+
     int ret = sprintf(
         buffer,
         "This string fits in the buffer because it is only %d bytes in length",
         68
     );
-
-    if (ret) {
-        printf("Failed! %d\n", ret);
+    if (ret != 68) {
+        printf("Failed! Return value was %d\n", ret);
         return -1;
     }
 
+    memset(buffer, 0, sizeof(buffer));
+
     ret = snprintf(
         buffer,
-        72,
+        10,
         "This string is way longer and does not fit in the buffer because it %d bytes in length",
-        87
+        86
     );
-
-    if (!ret) {
-        return 0;
-    } else {
-        printf("Failed! %d", ret);
+    if (ret != 86) {
+        printf("Failed! Return value was %d\n", ret);
         return -1;
     }
+
+    puts(buffer);
 }

+ 1 - 0
tests/time/strftime.c

@@ -17,4 +17,5 @@ int main() {
     print(1531839600, "%H %s %u");
     print(1531839600, "%j %U");
     print(1531839600, "%+");
+    print(1533669431, "%+%+%+%+%+"); // will overflow 50 characters
 }