mul_add.rs 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
/// The fused multiply-add operation.
///
/// Computes `(self * a) + b`. For the floating-point implementations in this
/// file (compiled only with the `std` feature) the call is forwarded to the
/// inherent `f32::mul_add` / `f64::mul_add`, which round once for the whole
/// expression instead of once per operation. That can give a more accurate
/// result, and may be faster, than a separate multiplication followed by an
/// addition.
///
/// The primitive integer implementations in this file have no fused form:
/// they evaluate an ordinary multiplication followed by an ordinary addition.
///
/// Note that `A` and `B` are `Self` by default, but this is not mandatory.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // 100.0
/// // Note: this calls the inherent `f32::mul_add`; the trait implementation
/// // for `f32` forwards to that same method.
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// The resulting type after applying the fused multiply-add.
    type Output;

    /// Performs the fused multiply-add operation.
    ///
    /// `self` is the multiplicand, `a` the multiplier and `b` the addend, so
    /// the result corresponds to `(self * a) + b`.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
/// The fused multiply-add assignment operation.
///
/// In-place counterpart of `MulAdd`: the implementations in this file
/// overwrite `*self` with `(*self * a) + b`.
///
/// As with `MulAdd`, the type parameters `A` and `B` default to `Self`, but
/// an implementation may choose other types.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Performs the fused multiply-add operation, storing the result in `self`.
    ///
    /// `a` is the multiplier and `b` the addend.
    fn mul_add_assign(&mut self, a: A, b: B);
}
  33. #[cfg(feature = "std")]
  34. impl MulAdd<f32, f32> for f32 {
  35. type Output = Self;
  36. #[inline]
  37. fn mul_add(self, a: Self, b: Self) -> Self::Output {
  38. f32::mul_add(self, a, b)
  39. }
  40. }
  41. #[cfg(feature = "std")]
  42. impl MulAdd<f64, f64> for f64 {
  43. type Output = Self;
  44. #[inline]
  45. fn mul_add(self, a: Self, b: Self) -> Self::Output {
  46. f64::mul_add(self, a, b)
  47. }
  48. }
  49. macro_rules! mul_add_impl {
  50. ($trait_name:ident for $($t:ty)*) => {$(
  51. impl $trait_name for $t {
  52. type Output = Self;
  53. #[inline]
  54. fn mul_add(self, a: Self, b: Self) -> Self::Output {
  55. (self * a) + b
  56. }
  57. }
  58. )*}
  59. }
  60. mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
  61. #[cfg(feature = "std")]
  62. impl MulAddAssign<f32, f32> for f32 {
  63. #[inline]
  64. fn mul_add_assign(&mut self, a: Self, b: Self) {
  65. *self = f32::mul_add(*self, a, b)
  66. }
  67. }
  68. #[cfg(feature = "std")]
  69. impl MulAddAssign<f64, f64> for f64 {
  70. #[inline]
  71. fn mul_add_assign(&mut self, a: Self, b: Self) {
  72. *self = f64::mul_add(*self, a, b)
  73. }
  74. }
  75. macro_rules! mul_add_assign_impl {
  76. ($trait_name:ident for $($t:ty)*) => {$(
  77. impl $trait_name for $t {
  78. #[inline]
  79. fn mul_add_assign(&mut self, a: Self, b: Self) {
  80. *self = (*self * a) + b
  81. }
  82. }
  83. )*}
  84. }
  85. mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
  86. #[cfg(test)]
  87. mod tests {
  88. use super::*;
  89. #[test]
  90. fn mul_add_integer() {
  91. macro_rules! test_mul_add {
  92. ($($t:ident)+) => {
  93. $(
  94. {
  95. let m: $t = 2;
  96. let x: $t = 3;
  97. let b: $t = 4;
  98. assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
  99. }
  100. )+
  101. };
  102. }
  103. test_mul_add!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
  104. }
  105. #[test]
  106. #[cfg(feature = "std")]
  107. fn mul_add_float() {
  108. macro_rules! test_mul_add {
  109. ($($t:ident)+) => {
  110. $(
  111. {
  112. use core::$t;
  113. let m: $t = 12.0;
  114. let x: $t = 3.4;
  115. let b: $t = 5.6;
  116. let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();
  117. assert!(abs_difference <= $t::EPSILON);
  118. }
  119. )+
  120. };
  121. }
  122. test_mul_add!(f32 f64);
  123. }
  124. }