lib.rs 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. extern crate alloc;
  2. extern crate quote;
  3. use proc_macro::TokenStream;
  4. use quote::quote;
  5. use syn::{
  6. __private::ToTokens,
  7. parse::{self, Parse, ParseStream},
  8. spanned::Spanned,
  9. ItemFn, Path,
  10. };
  11. use uuid::Uuid;
  12. /// 统一初始化宏,
  13. /// 用于将函数注册到统一初始化列表中
  14. ///
  15. /// ## 用法
  16. ///
  17. /// ```rust
  18. /// use system_error::SystemError;
  19. /// use unified_init::define_unified_initializer_slice;
  20. /// use unified_init_macros::unified_init;
  21. ///
  22. /// /// 初始化函数都将会被放到这个列表中
  23. /// define_unified_initializer_slice!(INITIALIZER_LIST);
  24. ///
  25. /// #[unified_init(INITIALIZER_LIST)]
  26. /// fn init1() -> Result<(), SystemError> {
  27. /// Ok(())
  28. /// }
  29. ///
  30. /// #[unified_init(INITIALIZER_LIST)]
  31. /// fn init2() -> Result<(), SystemError> {
  32. /// Ok(())
  33. /// }
  34. ///
  35. /// fn main() {
  36. /// assert_eq!(INITIALIZER_LIST.len(), 2);
  37. /// }
  38. ///
  39. /// ```
  40. #[proc_macro_attribute]
  41. pub fn unified_init(args: TokenStream, input: TokenStream) -> TokenStream {
  42. do_unified_init(args, input)
  43. .unwrap_or_else(|e| e.to_compile_error())
  44. .into()
  45. }
  46. fn do_unified_init(args: TokenStream, input: TokenStream) -> syn::Result<proc_macro2::TokenStream> {
  47. // 解析属性数
  48. let attr_arg = syn::parse::<UnifiedInitArg>(args)?;
  49. // 获取当前函数
  50. let function = syn::parse::<ItemFn>(input)?;
  51. // 检查函数签名
  52. check_function_signature(&function)?;
  53. // 添加#[::linkme::distributed_slice(attr_args.initializer_instance)]属性
  54. let target_slice = attr_arg.initializer_instance.get_ident().unwrap();
  55. // 在旁边添加一个UnifiedInitializer
  56. let initializer =
  57. generate_unified_initializer(&function, target_slice, function.sig.ident.to_string())?;
  58. // 拼接
  59. let mut output = proc_macro2::TokenStream::new();
  60. output.extend(function.into_token_stream());
  61. output.extend(initializer);
  62. Ok(output)
  63. }
  64. /// 检查函数签名是否满足要求
  65. /// 函数签名应该为
  66. ///
  67. /// ```rust
  68. /// use system_error::SystemError;
  69. /// fn xxx() -> Result<(), SystemError> {
  70. /// Ok(())
  71. /// }
  72. /// ```
  73. fn check_function_signature(function: &ItemFn) -> syn::Result<()> {
  74. // 检查函数签名
  75. if !function.sig.inputs.is_empty() {
  76. return Err(syn::Error::new(
  77. function.sig.inputs.span(),
  78. "Expected no arguments",
  79. ));
  80. }
  81. if let syn::ReturnType::Type(_, ty) = &function.sig.output {
  82. // 确认返回类型为 Result<(), SystemError>
  83. // 解析类型
  84. let output_type: syn::Type = syn::parse2(ty.clone().into_token_stream())?;
  85. // 检查类型是否为 Result<(), SystemError>
  86. if let syn::Type::Path(type_path) = output_type {
  87. if type_path.path.segments.last().unwrap().ident == "Result" {
  88. // 检查泛型参数,看看是否满足 Result<(), SystemError>
  89. if let syn::PathArguments::AngleBracketed(generic_args) =
  90. type_path.path.segments.last().unwrap().arguments.clone()
  91. {
  92. if generic_args.args.len() != 2 {
  93. return Err(syn::Error::new(
  94. generic_args.span(),
  95. "Expected two generic arguments",
  96. ));
  97. }
  98. // 检查第一个泛型参数是否为()
  99. if let syn::GenericArgument::Type(type_arg) = generic_args.args.first().unwrap()
  100. {
  101. if let syn::Type::Tuple(tuple) = type_arg {
  102. if !tuple.elems.is_empty() {
  103. return Err(syn::Error::new(tuple.span(), "Expected empty tuple"));
  104. }
  105. } else {
  106. return Err(syn::Error::new(type_arg.span(), "Expected empty tuple"));
  107. }
  108. } else {
  109. return Err(syn::Error::new(
  110. generic_args.span(),
  111. "Expected first generic argument to be a type",
  112. ));
  113. }
  114. // 检查第二个泛型参数是否为SystemError
  115. if let syn::GenericArgument::Type(type_arg) = generic_args.args.last().unwrap()
  116. {
  117. if let syn::Type::Path(type_path) = type_arg {
  118. if type_path.path.segments.last().unwrap().ident == "SystemError" {
  119. // 类型匹配,返回 Ok
  120. return Ok(());
  121. }
  122. }
  123. } else {
  124. return Err(syn::Error::new(
  125. generic_args.span(),
  126. "Expected second generic argument to be a type",
  127. ));
  128. }
  129. return Err(syn::Error::new(
  130. generic_args.span(),
  131. "Expected second generic argument to be SystemError",
  132. ));
  133. }
  134. return Ok(());
  135. }
  136. }
  137. }
  138. Err(syn::Error::new(
  139. function.sig.output.span(),
  140. "Expected -> Result<(), SystemError>",
  141. ))
  142. }
  143. /// 生成UnifiedInitializer全局变量
  144. fn generate_unified_initializer(
  145. function: &ItemFn,
  146. target_slice: &syn::Ident,
  147. raw_initializer_name: String,
  148. ) -> syn::Result<proc_macro2::TokenStream> {
  149. let initializer_name = format!(
  150. "unified_initializer_{}_{}",
  151. raw_initializer_name,
  152. &Uuid::new_v4().to_simple().to_string().to_ascii_uppercase()[..8]
  153. )
  154. .to_ascii_uppercase();
  155. // 获取函数的全名
  156. let initializer_name_ident = syn::Ident::new(&initializer_name, function.sig.ident.span());
  157. let function_ident = &function.sig.ident;
  158. // 生成UnifiedInitializer
  159. let initializer = quote! {
  160. #[::linkme::distributed_slice(#target_slice)]
  161. static #initializer_name_ident: unified_init::UnifiedInitializer = ::unified_init::UnifiedInitializer::new(#raw_initializer_name, &(#function_ident as ::unified_init::UnifiedInitFunction));
  162. };
  163. Ok(initializer)
  164. }
  165. struct UnifiedInitArg {
  166. initializer_instance: Path,
  167. }
  168. impl Parse for UnifiedInitArg {
  169. fn parse(input: ParseStream) -> parse::Result<Self> {
  170. let mut initializer_instance = None;
  171. while !input.is_empty() {
  172. if initializer_instance.is_some() {
  173. return Err(parse::Error::new(
  174. input.span(),
  175. "Expected exactly one initializer instance",
  176. ));
  177. }
  178. // 解析Ident
  179. let ident = input.parse::<syn::Ident>()?;
  180. // 将Ident转换为Path
  181. let initializer = syn::Path::from(ident);
  182. initializer_instance = Some(initializer);
  183. }
  184. if initializer_instance.is_none() {
  185. return Err(parse::Error::new(
  186. input.span(),
  187. "Expected exactly one initializer instance",
  188. ));
  189. }
  190. // 判断是否为标识符
  191. if initializer_instance.as_ref().unwrap().get_ident().is_none() {
  192. return Err(parse::Error::new(
  193. initializer_instance.span(),
  194. "Expected identifier",
  195. ));
  196. }
  197. Ok(UnifiedInitArg {
  198. initializer_instance: initializer_instance.unwrap(),
  199. })
  200. }
  201. }