mul_add.rs 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /// The fused multiply-add operation.
  2. /// Computes (self * a) + b with only one rounding error.
  3. /// This produces a more accurate result with better performance
  4. /// than a separate multiplication operation followed by an add.
  5. ///
  6. /// Note that `A` and `B` are `Self` by default, but this is not mandatory.
  7. ///
  8. /// # Example
  9. ///
  10. /// ```
  11. /// use std::f32;
  12. ///
  13. /// let m = 10.0_f32;
  14. /// let x = 4.0_f32;
  15. /// let b = 60.0_f32;
  16. ///
  17. /// // 100.0
  18. /// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
  19. ///
  20. /// assert!(abs_difference <= f32::EPSILON);
  21. /// ```
  22. pub trait MulAdd<A = Self, B = Self> {
  23. /// The resulting type after applying the fused multiply-add.
  24. type Output;
  25. /// Performs the fused multiply-add operation.
  26. fn mul_add(self, a: A, b: B) -> Self::Output;
  27. }
  28. /// The fused multiply-add assignment operation.
  29. pub trait MulAddAssign<A = Self, B = Self> {
  30. /// Performs the fused multiply-add operation.
  31. fn mul_add_assign(&mut self, a: A, b: B);
  32. }
  33. #[cfg(feature = "std")]
  34. impl MulAdd<f32, f32> for f32 {
  35. type Output = Self;
  36. #[inline]
  37. fn mul_add(self, a: Self, b: Self) -> Self::Output {
  38. f32::mul_add(self, a, b)
  39. }
  40. }
  41. #[cfg(feature = "std")]
  42. impl MulAdd<f64, f64> for f64 {
  43. type Output = Self;
  44. #[inline]
  45. fn mul_add(self, a: Self, b: Self) -> Self::Output {
  46. f64::mul_add(self, a, b)
  47. }
  48. }
  49. macro_rules! mul_add_impl {
  50. ($trait_name:ident for $($t:ty)*) => {$(
  51. impl $trait_name for $t {
  52. type Output = Self;
  53. #[inline]
  54. fn mul_add(self, a: Self, b: Self) -> Self::Output {
  55. (self * a) + b
  56. }
  57. }
  58. )*}
  59. }
  60. mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
  61. #[cfg(has_i128)]
  62. mul_add_impl!(MulAdd for i128 u128);
  63. #[cfg(feature = "std")]
  64. impl MulAddAssign<f32, f32> for f32 {
  65. #[inline]
  66. fn mul_add_assign(&mut self, a: Self, b: Self) {
  67. *self = f32::mul_add(*self, a, b)
  68. }
  69. }
  70. #[cfg(feature = "std")]
  71. impl MulAddAssign<f64, f64> for f64 {
  72. #[inline]
  73. fn mul_add_assign(&mut self, a: Self, b: Self) {
  74. *self = f64::mul_add(*self, a, b)
  75. }
  76. }
  77. macro_rules! mul_add_assign_impl {
  78. ($trait_name:ident for $($t:ty)*) => {$(
  79. impl $trait_name for $t {
  80. #[inline]
  81. fn mul_add_assign(&mut self, a: Self, b: Self) {
  82. *self = (*self * a) + b
  83. }
  84. }
  85. )*}
  86. }
  87. mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
  88. #[cfg(has_i128)]
  89. mul_add_assign_impl!(MulAddAssign for i128 u128);
  90. #[cfg(test)]
  91. mod tests {
  92. use super::*;
  93. #[test]
  94. fn mul_add_integer() {
  95. macro_rules! test_mul_add {
  96. ($($t:ident)+) => {
  97. $(
  98. {
  99. let m: $t = 2;
  100. let x: $t = 3;
  101. let b: $t = 4;
  102. assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
  103. }
  104. )+
  105. };
  106. }
  107. test_mul_add!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
  108. }
  109. #[test]
  110. #[cfg(feature = "std")]
  111. fn mul_add_float() {
  112. macro_rules! test_mul_add {
  113. ($($t:ident)+) => {
  114. $(
  115. {
  116. use core::$t;
  117. let m: $t = 12.0;
  118. let x: $t = 3.4;
  119. let b: $t = 5.6;
  120. let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();
  121. assert!(abs_difference <= $t::EPSILON);
  122. }
  123. )+
  124. };
  125. }
  126. test_mul_add!(f32 f64);
  127. }
  128. }