hart_mask.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. use super::{
  2. mask_commons::{MaskError, has_bit, valid_bit},
  3. sbi_ret::SbiRegister,
  4. };
  5. /// Hart mask structure in SBI function calls.
  6. #[repr(C)]
  7. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
  8. pub struct HartMask<T = usize> {
  9. hart_mask: T,
  10. hart_mask_base: T,
  11. }
  12. impl<T: SbiRegister> HartMask<T> {
  13. /// Special value to ignore the `mask`, and consider all `bit`s as set.
  14. pub const IGNORE_MASK: T = T::FULL_MASK;
  15. /// Construct a [HartMask] from mask value and base hart id.
  16. #[inline]
  17. pub const fn from_mask_base(hart_mask: T, hart_mask_base: T) -> Self {
  18. Self {
  19. hart_mask,
  20. hart_mask_base,
  21. }
  22. }
  23. /// Construct a [HartMask] that selects all available harts on the current environment.
  24. ///
  25. /// According to the RISC-V SBI Specification, `hart_mask_base` can be set to `-1` (i.e. `usize::MAX`)
  26. /// to indicate that `hart_mask` shall be ignored and all available harts must be considered.
  27. /// In case of this function in the `sbi-spec` crate, we fill in `usize::MAX` in `hart_mask_base`
  28. /// parameter to match the RISC-V SBI standard, while choosing 0 as the ignored `hart_mask` value.
  29. #[inline]
  30. pub const fn all() -> Self {
  31. Self {
  32. hart_mask: T::ZERO,
  33. hart_mask_base: T::FULL_MASK,
  34. }
  35. }
  36. /// Gets the special value for ignoring the `mask` parameter.
  37. #[inline]
  38. pub const fn ignore_mask(&self) -> T {
  39. Self::IGNORE_MASK
  40. }
  41. /// Returns `mask` and `base` parameters from the [HartMask].
  42. #[inline]
  43. pub const fn into_inner(self) -> (T, T) {
  44. (self.hart_mask, self.hart_mask_base)
  45. }
  46. }
  47. // FIXME: implement for T: SbiRegister once we can implement this using const traits.
  48. // Ref: https://rust-lang.github.io/rust-project-goals/2024h2/const-traits.html
  49. impl HartMask<usize> {
  50. /// Returns whether the [HartMask] contains the provided `hart_id`.
  51. #[inline]
  52. pub const fn has_bit(self, hart_id: usize) -> bool {
  53. has_bit(
  54. self.hart_mask,
  55. self.hart_mask_base,
  56. Self::IGNORE_MASK,
  57. hart_id,
  58. )
  59. }
  60. /// Insert a hart id into this [HartMask].
  61. ///
  62. /// Returns error when `hart_id` is invalid.
  63. #[inline]
  64. pub const fn insert(&mut self, hart_id: usize) -> Result<(), MaskError> {
  65. if self.hart_mask_base == Self::IGNORE_MASK {
  66. Ok(())
  67. } else if valid_bit(self.hart_mask_base, hart_id) {
  68. self.hart_mask |= 1usize << (hart_id - self.hart_mask_base);
  69. Ok(())
  70. } else {
  71. Err(MaskError::InvalidBit)
  72. }
  73. }
  74. /// Remove a hart id from this [HartMask].
  75. ///
  76. /// Returns error when `hart_id` is invalid, or it has been ignored.
  77. #[inline]
  78. pub const fn remove(&mut self, hart_id: usize) -> Result<(), MaskError> {
  79. if self.hart_mask_base == Self::IGNORE_MASK {
  80. Err(MaskError::Ignored)
  81. } else if valid_bit(self.hart_mask_base, hart_id) {
  82. self.hart_mask &= !(1usize << (hart_id - self.hart_mask_base));
  83. Ok(())
  84. } else {
  85. Err(MaskError::InvalidBit)
  86. }
  87. }
  88. /// Returns [HartIds] of self.
  89. #[inline]
  90. pub const fn iter(&self) -> HartIds {
  91. HartIds {
  92. inner: match self.hart_mask_base {
  93. Self::IGNORE_MASK => UnvisitedMask::Range(0, usize::MAX),
  94. _ => UnvisitedMask::MaskBase(self.hart_mask, self.hart_mask_base),
  95. },
  96. }
  97. }
  98. }
  99. impl IntoIterator for HartMask {
  100. type Item = usize;
  101. type IntoIter = HartIds;
  102. #[inline]
  103. fn into_iter(self) -> Self::IntoIter {
  104. self.iter()
  105. }
  106. }
  107. /// Iterator structure for `HartMask`.
  108. ///
  109. /// It will iterate hart id from low to high.
  110. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
  111. pub struct HartIds {
  112. inner: UnvisitedMask,
  113. }
  114. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
  115. enum UnvisitedMask {
  116. MaskBase(usize, usize),
  117. Range(usize, usize),
  118. }
  119. impl Iterator for HartIds {
  120. type Item = usize;
  121. #[inline]
  122. fn next(&mut self) -> Option<Self::Item> {
  123. match &mut self.inner {
  124. UnvisitedMask::MaskBase(0, _base) => None,
  125. UnvisitedMask::MaskBase(unvisited_mask, base) => {
  126. let low_bit = unvisited_mask.trailing_zeros();
  127. let hart_id = usize::try_from(low_bit).unwrap() + *base;
  128. *unvisited_mask &= !(1usize << low_bit);
  129. Some(hart_id)
  130. }
  131. UnvisitedMask::Range(start, end) => {
  132. assert!(start <= end);
  133. if *start < *end {
  134. let ans = *start;
  135. *start += 1;
  136. Some(ans)
  137. } else {
  138. None
  139. }
  140. }
  141. }
  142. }
  143. #[inline]
  144. fn size_hint(&self) -> (usize, Option<usize>) {
  145. match self.inner {
  146. UnvisitedMask::MaskBase(unvisited_mask, _base) => {
  147. let exact_popcnt = usize::try_from(unvisited_mask.count_ones()).unwrap();
  148. (exact_popcnt, Some(exact_popcnt))
  149. }
  150. UnvisitedMask::Range(start, end) => {
  151. assert!(start <= end);
  152. let exact_num_harts = end - start;
  153. (exact_num_harts, Some(exact_num_harts))
  154. }
  155. }
  156. }
  157. #[inline]
  158. fn count(self) -> usize {
  159. self.size_hint().0
  160. }
  161. #[inline]
  162. fn last(mut self) -> Option<Self::Item> {
  163. self.next_back()
  164. }
  165. #[inline]
  166. fn min(mut self) -> Option<Self::Item> {
  167. self.next()
  168. }
  169. #[inline]
  170. fn max(mut self) -> Option<Self::Item> {
  171. self.next_back()
  172. }
  173. #[inline]
  174. fn is_sorted(self) -> bool {
  175. true
  176. }
  177. // TODO: implement fn advance_by once it's stablized: https://github.com/rust-lang/rust/issues/77404
  178. // #[inline]
  179. // fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> { ... }
  180. }
  181. impl DoubleEndedIterator for HartIds {
  182. #[inline]
  183. fn next_back(&mut self) -> Option<Self::Item> {
  184. match &mut self.inner {
  185. UnvisitedMask::MaskBase(0, _base) => None,
  186. UnvisitedMask::MaskBase(unvisited_mask, base) => {
  187. let high_bit = unvisited_mask.leading_zeros();
  188. let hart_id = usize::try_from(usize::BITS - high_bit - 1).unwrap() + *base;
  189. *unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
  190. Some(hart_id)
  191. }
  192. UnvisitedMask::Range(start, end) => {
  193. assert!(start <= end);
  194. if *start < *end {
  195. let ans = *end;
  196. *end -= 1;
  197. Some(ans)
  198. } else {
  199. None
  200. }
  201. }
  202. }
  203. }
  204. // TODO: implement advance_back_by once stablized.
  205. // #[inline]
  206. // fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> { ... }
  207. }
  208. impl ExactSizeIterator for HartIds {}
  209. impl core::iter::FusedIterator for HartIds {}
  210. #[cfg(test)]
  211. mod tests {
  212. use super::*;
  213. #[test]
  214. fn rustsbi_hart_mask() {
  215. let mask = HartMask::from_mask_base(0b1, 400);
  216. assert!(!mask.has_bit(0));
  217. assert!(mask.has_bit(400));
  218. assert!(!mask.has_bit(401));
  219. let mask = HartMask::from_mask_base(0b110, 500);
  220. assert!(!mask.has_bit(0));
  221. assert!(!mask.has_bit(500));
  222. assert!(mask.has_bit(501));
  223. assert!(mask.has_bit(502));
  224. assert!(!mask.has_bit(500 + (usize::BITS as usize)));
  225. let max_bit = 1 << (usize::BITS - 1);
  226. let mask = HartMask::from_mask_base(max_bit, 600);
  227. assert!(mask.has_bit(600 + (usize::BITS as usize) - 1));
  228. assert!(!mask.has_bit(600 + (usize::BITS as usize)));
  229. let mask = HartMask::from_mask_base(0b11, usize::MAX - 1);
  230. assert!(!mask.has_bit(usize::MAX - 2));
  231. assert!(mask.has_bit(usize::MAX - 1));
  232. assert!(mask.has_bit(usize::MAX));
  233. assert!(!mask.has_bit(0));
  234. // hart_mask_base == usize::MAX is special, it means hart_mask should be ignored
  235. // and this hart mask contains all harts available
  236. let mask = HartMask::from_mask_base(0, usize::MAX);
  237. for i in 0..5 {
  238. assert!(mask.has_bit(i));
  239. }
  240. assert!(mask.has_bit(usize::MAX));
  241. let mut mask = HartMask::from_mask_base(0, 1);
  242. assert!(!mask.has_bit(1));
  243. assert!(mask.insert(1).is_ok());
  244. assert!(mask.has_bit(1));
  245. assert!(mask.remove(1).is_ok());
  246. assert!(!mask.has_bit(1));
  247. }
  248. #[test]
  249. fn rustsbi_hart_ids_iterator() {
  250. let mask = HartMask::from_mask_base(0b101011, 1);
  251. // Test the `next` method of `HartIds` structure.
  252. let mut hart_ids = mask.iter();
  253. assert_eq!(hart_ids.next(), Some(1));
  254. assert_eq!(hart_ids.next(), Some(2));
  255. assert_eq!(hart_ids.next(), Some(4));
  256. assert_eq!(hart_ids.next(), Some(6));
  257. assert_eq!(hart_ids.next(), None);
  258. // `HartIds` structures are fused, meaning they return `None` forever once iteration finished.
  259. assert_eq!(hart_ids.next(), None);
  260. // Test `for` loop on mask (`HartMask`) as `IntoIterator`.
  261. let mut ans = [0; 4];
  262. let mut idx = 0;
  263. for hart_id in mask {
  264. ans[idx] = hart_id;
  265. idx += 1;
  266. }
  267. assert_eq!(ans, [1, 2, 4, 6]);
  268. // Test `Iterator` methods on `HartIds`.
  269. let mut hart_ids = mask.iter();
  270. assert_eq!(hart_ids.size_hint(), (4, Some(4)));
  271. let _ = hart_ids.next();
  272. assert_eq!(hart_ids.size_hint(), (3, Some(3)));
  273. let _ = hart_ids.next();
  274. let _ = hart_ids.next();
  275. assert_eq!(hart_ids.size_hint(), (1, Some(1)));
  276. let _ = hart_ids.next();
  277. assert_eq!(hart_ids.size_hint(), (0, Some(0)));
  278. let _ = hart_ids.next();
  279. assert_eq!(hart_ids.size_hint(), (0, Some(0)));
  280. let mut hart_ids = mask.iter();
  281. assert_eq!(hart_ids.count(), 4);
  282. let _ = hart_ids.next();
  283. assert_eq!(hart_ids.count(), 3);
  284. let _ = hart_ids.next();
  285. let _ = hart_ids.next();
  286. let _ = hart_ids.next();
  287. assert_eq!(hart_ids.count(), 0);
  288. let _ = hart_ids.next();
  289. assert_eq!(hart_ids.count(), 0);
  290. let hart_ids = mask.iter();
  291. assert_eq!(hart_ids.last(), Some(6));
  292. let mut hart_ids = mask.iter();
  293. assert_eq!(hart_ids.nth(2), Some(4));
  294. let mut hart_ids = mask.iter();
  295. assert_eq!(hart_ids.nth(0), Some(1));
  296. let mut iter = mask.iter().step_by(2);
  297. assert_eq!(iter.next(), Some(1));
  298. assert_eq!(iter.next(), Some(4));
  299. assert_eq!(iter.next(), None);
  300. let mask_2 = HartMask::from_mask_base(0b1001101, 64);
  301. let mut iter = mask.iter().chain(mask_2);
  302. assert_eq!(iter.next(), Some(1));
  303. assert_eq!(iter.next(), Some(2));
  304. assert_eq!(iter.next(), Some(4));
  305. assert_eq!(iter.next(), Some(6));
  306. assert_eq!(iter.next(), Some(64));
  307. assert_eq!(iter.next(), Some(66));
  308. assert_eq!(iter.next(), Some(67));
  309. assert_eq!(iter.next(), Some(70));
  310. assert_eq!(iter.next(), None);
  311. let mut iter = mask.iter().zip(mask_2);
  312. assert_eq!(iter.next(), Some((1, 64)));
  313. assert_eq!(iter.next(), Some((2, 66)));
  314. assert_eq!(iter.next(), Some((4, 67)));
  315. assert_eq!(iter.next(), Some((6, 70)));
  316. assert_eq!(iter.next(), None);
  317. fn to_plic_context_id(hart_id_machine: usize) -> usize {
  318. hart_id_machine * 2
  319. }
  320. let mut iter = mask.iter().map(to_plic_context_id);
  321. assert_eq!(iter.next(), Some(2));
  322. assert_eq!(iter.next(), Some(4));
  323. assert_eq!(iter.next(), Some(8));
  324. assert_eq!(iter.next(), Some(12));
  325. assert_eq!(iter.next(), None);
  326. let mut channel_received = [0; 4];
  327. let mut idx = 0;
  328. let mut channel_send = |hart_id| {
  329. channel_received[idx] = hart_id;
  330. idx += 1;
  331. };
  332. mask.iter().for_each(|value| channel_send(value));
  333. assert_eq!(channel_received, [1, 2, 4, 6]);
  334. let is_in_cluster_1 = |hart_id: &usize| *hart_id >= 4 && *hart_id < 7;
  335. let mut iter = mask.iter().filter(is_in_cluster_1);
  336. assert_eq!(iter.next(), Some(4));
  337. assert_eq!(iter.next(), Some(6));
  338. assert_eq!(iter.next(), None);
  339. let if_in_cluster_1_get_plic_context_id = |hart_id: usize| {
  340. if hart_id >= 4 && hart_id < 7 {
  341. Some(hart_id * 2)
  342. } else {
  343. None
  344. }
  345. };
  346. let mut iter = mask.iter().filter_map(if_in_cluster_1_get_plic_context_id);
  347. assert_eq!(iter.next(), Some(8));
  348. assert_eq!(iter.next(), Some(12));
  349. assert_eq!(iter.next(), None);
  350. let mut iter = mask.iter().enumerate();
  351. assert_eq!(iter.next(), Some((0, 1)));
  352. assert_eq!(iter.next(), Some((1, 2)));
  353. assert_eq!(iter.next(), Some((2, 4)));
  354. assert_eq!(iter.next(), Some((3, 6)));
  355. assert_eq!(iter.next(), None);
  356. let mut ans = [(0, 0); 4];
  357. let mut idx = 0;
  358. for (i, hart_id) in mask.iter().enumerate() {
  359. ans[idx] = (i, hart_id);
  360. idx += 1;
  361. }
  362. assert_eq!(ans, [(0, 1), (1, 2), (2, 4), (3, 6)]);
  363. let mut iter = mask.iter().peekable();
  364. assert_eq!(iter.peek(), Some(&1));
  365. assert_eq!(iter.next(), Some(1));
  366. assert_eq!(iter.peek(), Some(&2));
  367. assert_eq!(iter.next(), Some(2));
  368. assert_eq!(iter.peek(), Some(&4));
  369. assert_eq!(iter.next(), Some(4));
  370. assert_eq!(iter.peek(), Some(&6));
  371. assert_eq!(iter.next(), Some(6));
  372. assert_eq!(iter.peek(), None);
  373. assert_eq!(iter.next(), None);
  374. // TODO: other iterator tests.
  375. assert!(mask.iter().is_sorted());
  376. assert!(mask.iter().is_sorted_by(|a, b| a <= b));
  377. // Reverse iterator as `DoubleEndedIterator`.
  378. let mut iter = mask.iter().rev();
  379. assert_eq!(iter.next(), Some(6));
  380. assert_eq!(iter.next(), Some(4));
  381. assert_eq!(iter.next(), Some(2));
  382. assert_eq!(iter.next(), Some(1));
  383. assert_eq!(iter.next(), None);
  384. // Special iterator values.
  385. let nothing = HartMask::from_mask_base(0, 1000);
  386. assert!(nothing.iter().eq([]));
  387. let all_mask_bits_set = HartMask::from_mask_base(usize::MAX, 1000);
  388. let range = 1000..(1000 + usize::BITS as usize);
  389. assert!(all_mask_bits_set.iter().eq(range));
  390. let all_harts = HartMask::all();
  391. let mut iter = all_harts.iter();
  392. assert_eq!(iter.size_hint(), (usize::MAX, Some(usize::MAX)));
  393. // Don't use `Iterator::eq` here; it would literally run `Iterator::try_for_each` from 0 to usize::MAX
  394. // which will cost us forever to run the test.
  395. assert_eq!(iter.next(), Some(0));
  396. assert_eq!(iter.size_hint(), (usize::MAX - 1, Some(usize::MAX - 1)));
  397. assert_eq!(iter.next(), Some(1));
  398. assert_eq!(iter.next(), Some(2));
  399. // skip 500 elements
  400. let _ = iter.nth(500 - 1);
  401. assert_eq!(iter.next(), Some(503));
  402. assert_eq!(iter.size_hint(), (usize::MAX - 504, Some(usize::MAX - 504)));
  403. assert_eq!(iter.next_back(), Some(usize::MAX));
  404. assert_eq!(iter.next_back(), Some(usize::MAX - 1));
  405. assert_eq!(iter.size_hint(), (usize::MAX - 506, Some(usize::MAX - 506)));
  406. // A common usage of `HartMask::all`, we assume that this platform filters out hart 0..=3.
  407. let environment_available_hart_ids = 4..128;
  408. // `hart_mask_iter` contains 64..=usize::MAX.
  409. let hart_mask_iter = all_harts.iter().skip(64);
  410. let filtered_iter = environment_available_hart_ids.filter(|&x| {
  411. hart_mask_iter
  412. .clone()
  413. .find(|&y| y >= x)
  414. .map_or(false, |y| y == x)
  415. });
  416. assert!(filtered_iter.eq(64..128));
  417. // The following operations should have O(1) complexity.
  418. let all_harts = HartMask::all();
  419. assert_eq!(all_harts.iter().count(), usize::MAX);
  420. assert_eq!(all_harts.iter().last(), Some(usize::MAX));
  421. assert_eq!(all_harts.iter().min(), Some(0));
  422. assert_eq!(all_harts.iter().max(), Some(usize::MAX));
  423. assert!(all_harts.iter().is_sorted());
  424. let partial_all_harts = {
  425. let mut ans = HartMask::all().iter();
  426. let _ = ans.nth(65536 - 1);
  427. let _ = ans.nth_back(4096 - 1);
  428. ans
  429. };
  430. assert_eq!(partial_all_harts.clone().count(), usize::MAX - 65536 - 4096);
  431. assert_eq!(partial_all_harts.clone().last(), Some(usize::MAX - 4096));
  432. assert_eq!(partial_all_harts.clone().min(), Some(65536));
  433. assert_eq!(partial_all_harts.clone().max(), Some(usize::MAX - 4096));
  434. assert!(partial_all_harts.is_sorted());
  435. let nothing = HartMask::from_mask_base(0, 1000);
  436. assert_eq!(nothing.iter().count(), 0);
  437. assert_eq!(nothing.iter().last(), None);
  438. assert_eq!(nothing.iter().min(), None);
  439. assert_eq!(nothing.iter().max(), None);
  440. assert!(nothing.iter().is_sorted());
  441. let mask = HartMask::from_mask_base(0b101011, 1);
  442. assert_eq!(mask.iter().count(), 4);
  443. assert_eq!(mask.iter().last(), Some(6));
  444. assert_eq!(mask.iter().min(), Some(1));
  445. assert_eq!(mask.iter().max(), Some(6));
  446. assert!(mask.iter().is_sorted());
  447. let all_mask_bits_set = HartMask::from_mask_base(usize::MAX, 1000);
  448. let last = 1000 + usize::BITS as usize - 1;
  449. assert_eq!(all_mask_bits_set.iter().count(), usize::BITS as usize);
  450. assert_eq!(all_mask_bits_set.iter().last(), Some(last));
  451. assert_eq!(all_mask_bits_set.iter().min(), Some(1000));
  452. assert_eq!(all_mask_bits_set.iter().max(), Some(last));
  453. assert!(all_mask_bits_set.iter().is_sorted());
  454. }
  455. #[test]
  456. fn rustsbi_hart_mask_non_usize() {
  457. assert_eq!(HartMask::<i32>::IGNORE_MASK, -1);
  458. assert_eq!(HartMask::<i64>::IGNORE_MASK, -1);
  459. assert_eq!(HartMask::<i128>::IGNORE_MASK, -1);
  460. assert_eq!(HartMask::<u32>::IGNORE_MASK, u32::MAX);
  461. assert_eq!(HartMask::<u64>::IGNORE_MASK, u64::MAX);
  462. assert_eq!(HartMask::<u128>::IGNORE_MASK, u128::MAX);
  463. assert_eq!(HartMask::<i32>::all(), HartMask::from_mask_base(0, -1));
  464. }
  465. }