
Merge #59

59: Added `MulAdd` and `MulAddAssign` traits r=cuviper a=regexident

Both `f32` and `f64` provide a fused multiply-add method, which computes `(self * a) + b` with only one rounding error. This yields a more accurate result, with better performance, than a separate multiplication followed by an addition:

```rust
fn mul_add(self, a: f32, b: f32) -> f32
```
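
As a quick standalone illustration of that single rounding (not part of this PR), the fused form can recover the rounding error that an ordinary multiply silently discards:

```rust
fn main() {
    let a = 0.1_f64;
    let b = 10.0_f64;

    // Ordinary multiply: the exact product is rounded to the nearest f64,
    // which here is exactly 1.0, so the tiny error is lost.
    let product = a * b;
    assert_eq!(product, 1.0);

    // Fused multiply-add: `a * b - product` is computed with a single rounding,
    // so it recovers the error discarded above (about 5.6e-17) instead of zero.
    let residual = a.mul_add(b, -product);
    assert!(residual != 0.0);
    println!("rounding error of a * b: {residual:e}");
}
```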

It is, however, not possible to make use of this in a generic context, because there is no trait to abstract over.
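
For illustration, this is the kind of generic helper the proposed trait makes possible. A minimal sketch, assuming the `num-traits` crate these paths point at; `fused_axpb` is a name invented here:

```rust
use num_traits::MulAdd;

/// Generic "x * a + b" helper. Without a trait like `MulAdd` there is no bound
/// that lets this defer to the fused operation where one exists.
fn fused_axpb<T>(x: T, a: T, b: T) -> T
where
    T: MulAdd<T, T, Output = T>,
{
    x.mul_add(a, b)
}

fn main() {
    assert_eq!(fused_axpb(3.0_f64, 2.0, 1.0), 7.0); // floats use the fused multiply-add
    assert_eq!(fused_axpb(3_u32, 2, 1), 7);         // integers fall back to `x * a + b`
}
```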

My concrete use case is machine learning, specifically [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent), where the core update step could make use of `mul_add` for both its `weights: Vector` and its `bias: f32`:

```rust
struct Perceptron {
  weights: Vector,
  bias: f32,
}

impl MulAdd<f32, Self> for Vector {
  // ...
}

impl Perceptron {
  fn learn(&mut self, example: Vector, expected: f32, learning_rate: f32) {
    let alpha = self.error(example, expected, learning_rate);
    self.weights = example.mul_add(alpha, self.weights);
    self.bias = self.bias.mul_add(alpha, self.bias)
  }
}
```

(The actual impl of `Vector` would be generic over its value type: `Vector<T>`, thus requiring the trait.)
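
A rough sketch of what that generic impl could look like with the new trait; the `Vector<T>` type and its `components` field are hypothetical, invented for this sketch:

```rust
use num_traits::MulAdd;

// Hypothetical generic vector type, for illustration only.
struct Vector<T> {
    components: Vec<T>,
}

// Scale a vector by a scalar and add another vector element-wise, using
// fused multiply-add wherever `T` provides it.
impl<T> MulAdd<T, Vector<T>> for Vector<T>
where
    T: MulAdd<T, T, Output = T> + Copy,
{
    type Output = Vector<T>;

    fn mul_add(self, a: T, b: Vector<T>) -> Self::Output {
        let components = self
            .components
            .into_iter()
            .zip(b.components)
            .map(|(x, y)| x.mul_add(a, y))
            .collect();
        Vector { components }
    }
}

fn main() {
    let v = Vector { components: vec![1.0_f32, 2.0, 3.0] };
    let w = Vector { components: vec![0.5_f32, 0.5, 0.5] };
    let out = v.mul_add(2.0, w); // element-wise: 2 * v + w
    assert_eq!(out.components, vec![2.5, 4.5, 6.5]);
}
```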

Co-authored-by: Vincent Esche <regexident@gmail.com>
Co-authored-by: Josh Stone <cuviper@gmail.com>
bors[bot] 7 years ago
parent commit a49013e338
3 changed files with 148 additions and 0 deletions
  1. src/lib.rs (+1 −0)
  2. src/ops/mod.rs (+1 −0)
  3. src/ops/mul_add.rs (+146 −0)

src/lib.rs (+1 −0)

@@ -37,6 +37,7 @@ pub use ops::inv::Inv;
 pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul, CheckedDiv,
                        CheckedRem, CheckedNeg, CheckedShl, CheckedShr};
 pub use ops::wrapping::{WrappingAdd, WrappingMul, WrappingSub};
+pub use ops::mul_add::{MulAdd, MulAddAssign};
 pub use ops::saturating::Saturating;
 pub use sign::{Signed, Unsigned, abs, abs_sub, signum};
 pub use cast::{AsPrimitive, FromPrimitive, ToPrimitive, NumCast, cast};

src/ops/mod.rs (+1 −0)

@@ -2,3 +2,4 @@ pub mod saturating;
 pub mod checked;
 pub mod wrapping;
 pub mod inv;
+pub mod mul_add;

src/ops/mul_add.rs (+146 −0)

@@ -0,0 +1,146 @@
+/// The fused multiply-add operation.
+/// Computes (self * a) + b with only one rounding error.
+/// This produces a more accurate result with better performance
+/// than a separate multiplication operation followed by an add.
+///
+/// Note that `A` and `B` are `Self` by default, but this is not mandatory.
+///
+/// # Example
+///
+/// ```
+/// use std::f32;
+///
+/// let m = 10.0_f32;
+/// let x = 4.0_f32;
+/// let b = 60.0_f32;
+///
+/// // 100.0
+/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
+///
+/// assert!(abs_difference <= f32::EPSILON);
+/// ```
+pub trait MulAdd<A = Self, B = Self> {
+    /// The resulting type after applying the fused multiply-add.
+    type Output;
+
+    /// Performs the fused multiply-add operation.
+    fn mul_add(self, a: A, b: B) -> Self::Output;
+}
+
+/// The fused multiply-add assignment operation.
+pub trait MulAddAssign<A = Self, B = Self> {
+    /// Performs the fused multiply-add operation.
+    fn mul_add_assign(&mut self, a: A, b: B);
+}
+
+#[cfg(feature = "std")]
+impl MulAdd<f32, f32> for f32 {
+    type Output = Self;
+
+    #[inline]
+    fn mul_add(self, a: Self, b: Self) -> Self::Output {
+        f32::mul_add(self, a, b)
+    }
+}
+
+#[cfg(feature = "std")]
+impl MulAdd<f64, f64> for f64 {
+    type Output = Self;
+
+    #[inline]
+    fn mul_add(self, a: Self, b: Self) -> Self::Output {
+        f64::mul_add(self, a, b)
+    }
+}
+
+macro_rules! mul_add_impl {
+    ($trait_name:ident for $($t:ty)*) => {$(
+        impl $trait_name for $t {
+            type Output = Self;
+
+            #[inline]
+            fn mul_add(self, a: Self, b: Self) -> Self::Output {
+                (self * a) + b
+            }
+        }
+    )*}
+}
+
+mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
+
+#[cfg(feature = "std")]
+impl MulAddAssign<f32, f32> for f32 {
+    #[inline]
+    fn mul_add_assign(&mut self, a: Self, b: Self) {
+        *self = f32::mul_add(*self, a, b)
+    }
+}
+
+#[cfg(feature = "std")]
+impl MulAddAssign<f64, f64> for f64 {
+    #[inline]
+    fn mul_add_assign(&mut self, a: Self, b: Self) {
+        *self = f64::mul_add(*self, a, b)
+    }
+}
+
+macro_rules! mul_add_assign_impl {
+    ($trait_name:ident for $($t:ty)*) => {$(
+        impl $trait_name for $t {
+            #[inline]
+            fn mul_add_assign(&mut self, a: Self, b: Self) {
+                *self = (*self * a) + b
+            }
+        }
+    )*}
+}
+
+mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn mul_add_integer() {
+        macro_rules! test_mul_add {
+            ($($t:ident)+) => {
+                $(
+                    {
+                        let m: $t = 2;
+                        let x: $t = 3;
+                        let b: $t = 4;
+
+                        assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
+                    }
+                )+
+            };
+        }
+
+        test_mul_add!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
+    }
+
+    #[test]
+    #[cfg(feature = "std")]
+    fn mul_add_float() {
+        macro_rules! test_mul_add {
+            ($($t:ident)+) => {
+                $(
+                    {
+                        use core::$t;
+
+                        let m: $t = 12.0;
+                        let x: $t = 3.4;
+                        let b: $t = 5.6;
+
+                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();
+
+                        assert!(abs_difference <= $t::EPSILON);
+                    }
+                )+
+            };
+        }
+
+        test_mul_add!(f32 f64);
+    }
+}
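
For completeness, a small usage sketch of the assign variant (not part of the diff; assumes the re-exports added above are available from the crate root):

```rust
use num_traits::MulAddAssign;

fn main() {
    // Float impls (behind the `std` feature) defer to the fused intrinsic.
    let mut acc = 1.5_f32;
    acc.mul_add_assign(2.0, 0.5); // acc = (1.5 * 2.0) + 0.5
    assert_eq!(acc, 3.5);

    // Integer impls update in place via plain `(*self * a) + b`.
    let mut n = 3_u64;
    n.mul_add_assign(2, 1);
    assert_eq!(n, 7);
}
```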