Browse Source

Add loop_asm macro

mini-ninja-64 1 year ago
parent
commit
d60e71d570
2 changed files with 40 additions and 12 deletions
  1. 31 1
      riscv-rt/macros/src/lib.rs
  2. 9 11
      riscv-rt/src/lib.rs

+ 31 - 1
riscv-rt/macros/src/lib.rs

@@ -9,7 +9,7 @@ extern crate proc_macro2;
 extern crate syn;
 
 use proc_macro2::Span;
-use syn::{parse, spanned::Spanned, FnArg, ItemFn, PathArguments, ReturnType, Type, Visibility};
+use syn::{parse::{self, Parse}, spanned::Spanned, FnArg, ItemFn, PathArguments, ReturnType, Type, Visibility, LitStr, LitInt};
 
 use proc_macro::TokenStream;
 
@@ -205,3 +205,33 @@ pub fn pre_init(args: TokenStream, input: TokenStream) -> TokenStream {
     )
     .into()
 }
+
+struct AsmLoopArgs {
+    asm_template: String,
+    count: usize,
+}
+
+impl Parse for AsmLoopArgs {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        let template: LitStr = input.parse().unwrap();
+        _ = input.parse::<Token![,]>().unwrap();
+        let count: LitInt = input.parse().unwrap();
+
+        Ok(Self {
+            asm_template: template.value(),
+            count: count.base10_parse().unwrap(),
+        })
+    }
+}
+
+#[proc_macro]
+pub fn loop_asm(input: TokenStream) -> TokenStream {
+    let args = parse_macro_input!(input as AsmLoopArgs);
+
+    let tokens = (0..args.count).map(|i| {
+        let i = i.to_string();
+        let asm = args.asm_template.replace("{}", &i);
+        format!("core::arch::asm!(\"{}\");", asm)
+    }).collect::<Vec<String>>().join("\n");
+    tokens.parse().unwrap()
+}

+ 9 - 11
riscv-rt/src/lib.rs

@@ -524,17 +524,15 @@ pub unsafe extern "C" fn start_rust(a0: usize, a1: usize, a2: usize) -> ! {
         core::arch::asm!("fscsr x0"); // Zero out fcsr register csrrw x0, fcsr, x0
 
         // Zero out floating point registers
-        for i in 0..32 { 
-            if cfg!(all(riscv32, riscvd)) {
-                // rv32 targets with double precision floating point can use fmvp.d.x 
-                // to combine 2 32 bit registers to fill the 64 bit floating point 
-                // register
-                core::arch::asm!("fmvp.d.x f{}, x0, x0", i);
-            } else if cfg!(riscvd) {
-                core::arch::asm!("fmv.d.x f{}, x0", i);
-            } else {
-                core::arch::asm!("fmv.w.x f{}, x0", i);
-            }
+        if cfg!(all(target_arch = "riscv32", riscvd)) {
+            // rv32 targets with double precision floating point can use fmvp.d.x 
+            // to combine 2 32 bit registers to fill the 64 bit floating point 
+            // register
+            riscv_rt_macros::loop_asm!("fmvp.d.x f{}, x0, x0", 32);
+        } else if cfg!(riscvd) {
+            riscv_rt_macros::loop_asm!("fmv.d.x f{}, x0", 32);
+        } else {
+            riscv_rt_macros::loop_asm!("fmv.w.x f{}, x0", 32);
         }
     }