From 08b0022aab249d16f1b63fef51b89b2f5bfff931 Mon Sep 17 00:00:00 2001
From: Kent Overstreet <kent.overstreet@gmail.com>
Date: Thu, 10 Dec 2015 15:25:42 -0900
Subject: [PATCH] Improve multiply performance

The main idea here is to do as much as possible with slices, instead of
allocating new BigUints (= heap allocations).

Current performance:

multiply_0:	10,507 ns/iter (+/- 987)
multiply_1:	2,788,734 ns/iter (+/- 100,079)
multiply_2:	69,923,515 ns/iter (+/- 4,550,902)

After this patch, we get:

multiply_0:	364 ns/iter (+/- 62)
multiply_1:	34,085 ns/iter (+/- 1,179)
multiply_2:	3,753,883 ns/iter (+/- 46,876)
---
 src/bigint.rs | 289 ++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 219 insertions(+), 70 deletions(-)

diff --git a/src/bigint.rs b/src/bigint.rs
index ad0c2ed..c47f99c 100644
--- a/src/bigint.rs
+++ b/src/bigint.rs
@@ -148,6 +148,16 @@ fn sbb(a: BigDigit, b: BigDigit, borrow: &mut BigDigit) -> BigDigit {
     lo
 }
 
+#[inline]
+fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigit) -> BigDigit {
+    let (hi, lo) = big_digit::from_doublebigdigit(
+        (a as DoubleBigDigit) +
+        (b as DoubleBigDigit) * (c as DoubleBigDigit) +
+        (*carry as DoubleBigDigit));
+    *carry = hi;
+    lo
+}
+
 /// A big unsigned integer type.
 ///
 /// A `BigUint`-typed value `BigUint { data: vec!(a, b, c) }` represents a number
@@ -172,18 +182,25 @@ impl PartialOrd for BigUint {
     }
 }
 
+fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering {
+    debug_assert!(a.last() != Some(&0));
+    debug_assert!(b.last() != Some(&0));
+
+    let (a_len, b_len) = (a.len(), b.len());
+    if a_len < b_len { return Less; }
+    if a_len > b_len { return Greater;  }
+
+    for (&ai, &bi) in a.iter().rev().zip(b.iter().rev()) {
+        if ai < bi { return Less; }
+        if ai > bi { return Greater; }
+    }
+    return Equal;
+}
+
 impl Ord for BigUint {
     #[inline]
     fn cmp(&self, other: &BigUint) -> Ordering {
-        let (s_len, o_len) = (self.data.len(), other.data.len());
-        if s_len < o_len { return Less; }
-        if s_len > o_len { return Greater;  }
-
-        for (&self_i, &other_i) in self.data.iter().rev().zip(other.data.iter().rev()) {
-            if self_i < other_i { return Less; }
-            if self_i > other_i { return Greater; }
-        }
-        return Equal;
+        cmp_slice(&self.data[..], &other.data[..])
     }
 }
 
@@ -608,80 +625,202 @@ impl<'a> Sub<&'a BigUint> for BigUint {
     }
 }
 
+fn sub_sign(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
+    // Normalize:
+    let a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
+    let b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
 
-forward_all_binop_to_val_ref_commutative!(impl Mul for BigUint, mul);
+    match cmp_slice(a, b) {
+        Greater => {
+            let mut ret = BigUint::from_slice(a);
+            sub2(&mut ret.data[..], b);
+            BigInt::from_biguint(Plus, ret.normalize())
+        },
+        Less    => {
+            let mut ret = BigUint::from_slice(b);
+            sub2(&mut ret.data[..], a);
+            BigInt::from_biguint(Minus, ret.normalize())
+        },
+        _       => Zero::zero(),
+    }
+}
 
-impl<'a> Mul<&'a BigUint> for BigUint {
-    type Output = BigUint;
+forward_all_binop_to_ref_ref!(impl Mul for BigUint, mul);
 
-    fn mul(self, other: &BigUint) -> BigUint {
-        if self.is_zero() || other.is_zero() { return Zero::zero(); }
+/// Three argument multiply accumulate:
+/// acc += b * c
+fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
+    if c == 0 { return; }
 
-        let (s_len, o_len) = (self.data.len(), other.data.len());
-        if s_len == 1 { return mul_digit(other.clone(), self.data[0]);  }
-        if o_len == 1 { return mul_digit(self,  other.data[0]); }
+    let mut b_iter = b.iter();
+    let mut carry = 0;
 
-        // Using Karatsuba multiplication
-        // (a1 * base + a0) * (b1 * base + b0)
-        // = a1*b1 * base^2 +
-        //   (a1*b1 + a0*b0 - (a1-b0)*(b1-a0)) * base +
-        //   a0*b0
-        let half_len = cmp::max(s_len, o_len) / 2;
-        let (s_hi, s_lo) = cut_at(self,  half_len);
-        let (o_hi, o_lo) = cut_at(other.clone(), half_len);
-
-        let ll = &s_lo * &o_lo;
-        let hh = &s_hi * &o_hi;
-        let mm = {
-            let (s1, n1) = sub_sign(s_hi, s_lo);
-            let (s2, n2) = sub_sign(o_hi, o_lo);
-            match (s1, s2) {
-                (Equal, _) | (_, Equal) => &hh + &ll,
-                (Less, Greater) | (Greater, Less) => &hh + &ll + (n1 * n2),
-                (Less, Less) | (Greater, Greater) => &hh + &ll - (n1 * n2)
-            }
-        };
-
-        return ll + mm.shl_unit(half_len) + hh.shl_unit(half_len * 2);
-
-
-        fn mul_digit(a: BigUint, n: BigDigit) -> BigUint {
-            if n == 0 { return Zero::zero(); }
-            if n == 1 { return a; }
-
-            let mut carry = 0;
-            let mut prod = a.data;
-            for a in &mut prod {
-                let d = (*a as DoubleBigDigit)
-                    * (n as DoubleBigDigit)
-                    + (carry as DoubleBigDigit);
-                let (hi, lo) = big_digit::from_doublebigdigit(d);
-                carry = hi;
-                *a = lo;
-            }
-            if carry != 0 { prod.push(carry); }
-            BigUint::new(prod)
+    for ai in acc.iter_mut() {
+        if let Some(bi) = b_iter.next() {
+            *ai = mac_with_carry(*ai, *bi, c, &mut carry);
+        } else if carry != 0 {
+            *ai = mac_with_carry(*ai, 0, c, &mut carry);
+        } else {
+            break;
         }
+    }
 
-        #[inline]
-        fn cut_at(mut a: BigUint, n: usize) -> (BigUint, BigUint) {
-            let mid = cmp::min(a.data.len(), n);
-            let hi = BigUint::from_slice(&a.data[mid ..]);
-            a.data.truncate(mid);
-            (hi, BigUint::new(a.data))
+    assert!(carry == 0);
+}
+
+/// Three argument multiply accumulate:
+/// acc += b * c
+fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
+    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };
+
+    /*
+     * Karatsuba multiplication is slower than long multiplication for small x and y:
+     */
+    if x.len() <= 4 {
+        for (i, xi) in x.iter().enumerate() {
+            mac_digit(&mut acc[i..], y, *xi);
         }
+    } else {
+        /*
+         * Karatsuba multiplication:
+         *
+         * The idea is that we break x and y up into two smaller numbers that each have about half
+         * as many digits, like so (note that multiplying by b is just a shift):
+         *
+         * x = x0 + x1 * b
+         * y = y0 + y1 * b
+         *
+         * With some algebra, we can compute x * y with three smaller products, where the inputs to
+         * each of the smaller products have only about half as many digits as x and y:
+         *
+         * x * y = (x0 + x1 * b) * (y0 + y1 * b)
+         *
+         * x * y = x0 * y0
+         *       + x0 * y1 * b
+         *       + x1 * y0 * b
+         *       + x1 * y1 * b^2
+         *
+         * Let p0 = x0 * y0 and p2 = x1 * y1:
+         *
+         * x * y = p0
+         *       + (x0 * y1 + x1 * p0) * b
+         *       + p2 * b^2
+         *
+         * The real trick is that middle term:
+         *
+         *         x0 * y1 + x1 * y0
+         *
+         *       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
+         *
+         *       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
+         *
+         * Now we complete the square:
+         *
+         *       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
+         *
+         *       = -((x1 - x0) * (y1 - y0)) + p0 + p2
+         *
+         * Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
+         *
+         * x * y = p0
+         *       + (p0 + p2 - p1) * b
+         *       + p2 * b^2
+         *
+         * Where the three intermediate products are:
+         *
+         * p0 = x0 * y0
+         * p1 = (x1 - x0) * (y1 - y0)
+         * p2 = x1 * y1
+         *
+         * In doing the computation, we take great care to avoid unnecessary temporary variables
+         * (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
+         * bit so we can use the same temporary variable for all the intermediate products:
+         *
+         * x * y = p2 * b^2 + p2 * b
+         *       + p0 * b + p0
+         *       - p1 * b
+         *
+         * The other trick we use is instead of doing explicit shifts, we slice acc at the
+         * appropriate offset when doing the add.
+         */
 
-        #[inline]
-        fn sub_sign(a: BigUint, b: BigUint) -> (Ordering, BigUint) {
-            match a.cmp(&b) {
-                Less    => (Less,    b - a),
-                Greater => (Greater, a - b),
-                _       => (Equal,   Zero::zero())
-            }
+        /*
+         * When x is smaller than y, it's significantly faster to pick b such that x is split in
+         * half, not y:
+         */
+        let b = x.len() / 2;
+        let (x0, x1) = x.split_at(b);
+        let (y0, y1) = y.split_at(b);
+
+        /* We reuse the same BigUint for all the intermediate multiplies: */
+
+        let len = y.len() + 1;
+        let mut p: BigUint = BigUint { data: Vec::with_capacity(len) };
+        p.data.extend(repeat(0).take(len));
+
+        // p2 = x1 * y1
+        mac3(&mut p.data[..], x1, y1);
+
+        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
+        p = p.normalize();
+
+        add2(&mut acc[b..],        &p.data[..]);
+        add2(&mut acc[b * 2..],    &p.data[..]);
+
+        // Zero out p before the next multiply:
+        p.data.truncate(0);
+        p.data.extend(repeat(0).take(len));
+
+        // p0 = x0 * y0
+        mac3(&mut p.data[..], x0, y0);
+        p = p.normalize();
+
+        add2(&mut acc[..],                &p.data[..]);
+        add2(&mut acc[b..],        &p.data[..]);
+
+        // p1 = (x1 - x0) * (y1 - y0)
+        // We do this one last, since it may be negative and acc can't ever be negative:
+        let j0 = sub_sign(x1, x0);
+        let j1 = sub_sign(y1, y0);
+
+        match j0.sign * j1.sign {
+            Plus    => {
+                p.data.truncate(0);
+                p.data.extend(repeat(0).take(len));
+
+                mac3(&mut p.data[..], &j0.data.data[..], &j1.data.data[..]);
+                p = p.normalize();
+
+                sub2(&mut acc[b..], &p.data[..]);
+            },
+            Minus   => {
+                mac3(&mut acc[b..], &j0.data.data[..], &j1.data.data[..]);
+            },
+            NoSign  => (),
         }
     }
 }
 
+fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
+    let len = x.len() + y.len() + 1;
+    let mut prod: BigUint = BigUint { data: Vec::with_capacity(len) };
+
+    // resize isn't stable yet:
+    //prod.data.resize(len, 0);
+    prod.data.extend(repeat(0).take(len));
+
+    mac3(&mut prod.data[..], x, y);
+    prod.normalize()
+}
+
+impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint {
+    type Output = BigUint;
+
+    #[inline]
+    fn mul(self, other: &BigUint) -> BigUint {
+        mul3(&self.data[..], &other.data[..])
+    }
+}
 
 forward_all_binop_to_ref_ref!(impl Div for BigUint, div);
 
@@ -3131,6 +3270,16 @@ mod biguint_tests {
         // Switching u and l should fail:
         let _n: BigUint = rng.gen_biguint_range(&u, &l);
     }
+
+    #[test]
+    fn test_sub_sign() {
+        use super::sub_sign;
+        let a = BigInt::from_str_radix("265252859812191058636308480000000", 10).unwrap();
+        let b = BigInt::from_str_radix("26525285981219105863630848000000", 10).unwrap();
+
+        assert_eq!(sub_sign(&a.data.data[..], &b.data.data[..]), &a - &b);
+        assert_eq!(sub_sign(&b.data.data[..], &a.data.data[..]), &b - &a);
+    }
 }
 
 #[cfg(test)]