001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
005 * in compliance with the License. You may obtain a copy of the License at
006 *
007 * http://www.apache.org/licenses/LICENSE-2.0
008 *
009 * Unless required by applicable law or agreed to in writing, software distributed under the License
010 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
011 * or implied. See the License for the specific language governing permissions and limitations under
012 * the License.
013 */
014
015package com.google.common.math;
016
017import static com.google.common.base.Preconditions.checkArgument;
018import static com.google.common.base.Preconditions.checkNotNull;
019import static com.google.common.math.MathPreconditions.checkNonNegative;
020import static com.google.common.math.MathPreconditions.checkPositive;
021import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
022import static java.math.RoundingMode.CEILING;
023import static java.math.RoundingMode.FLOOR;
024import static java.math.RoundingMode.HALF_DOWN;
025import static java.math.RoundingMode.HALF_EVEN;
026import static java.math.RoundingMode.UNNECESSARY;
027
028import com.google.common.annotations.Beta;
029import com.google.common.annotations.GwtCompatible;
030import com.google.common.annotations.GwtIncompatible;
031import com.google.common.annotations.VisibleForTesting;
032import java.math.BigDecimal;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035import java.util.ArrayList;
036import java.util.List;
037
038/**
039 * A class for arithmetic on values of type {@code BigInteger}.
040 *
041 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
042 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
043 *
044 * <p>Similar functionality for {@code int} and for {@code long} can be found in {@link IntMath} and
045 * {@link LongMath} respectively.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051@ElementTypesAreNonnullByDefault
052public final class BigIntegerMath {
053  /**
054   * Returns the smallest power of two greater than or equal to {@code x}. This is equivalent to
055   * {@code BigInteger.valueOf(2).pow(log2(x, CEILING))}.
056   *
057   * @throws IllegalArgumentException if {@code x <= 0}
058   * @since 20.0
059   */
060  @Beta
061  public static BigInteger ceilingPowerOfTwo(BigInteger x) {
062    return BigInteger.ZERO.setBit(log2(x, CEILING));
063  }
064
065  /**
066   * Returns the largest power of two less than or equal to {@code x}. This is equivalent to {@code
067   * BigInteger.valueOf(2).pow(log2(x, FLOOR))}.
068   *
069   * @throws IllegalArgumentException if {@code x <= 0}
070   * @since 20.0
071   */
072  @Beta
073  public static BigInteger floorPowerOfTwo(BigInteger x) {
074    return BigInteger.ZERO.setBit(log2(x, FLOOR));
075  }
076
077  /** Returns {@code true} if {@code x} represents a power of two. */
078  public static boolean isPowerOfTwo(BigInteger x) {
079    checkNotNull(x);
080    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
081  }
082
083  /**
084   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
085   *
086   * @throws IllegalArgumentException if {@code x <= 0}
087   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
088   *     is not a power of two
089   */
090  @SuppressWarnings("fallthrough")
091  // TODO(kevinb): remove after this warning is disabled globally
092  public static int log2(BigInteger x, RoundingMode mode) {
093    checkPositive("x", checkNotNull(x));
094    int logFloor = x.bitLength() - 1;
095    switch (mode) {
096      case UNNECESSARY:
097        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
098      case DOWN:
099      case FLOOR:
100        return logFloor;
101
102      case UP:
103      case CEILING:
104        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
105
106      case HALF_DOWN:
107      case HALF_UP:
108      case HALF_EVEN:
109        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
110          BigInteger halfPower =
111              SQRT2_PRECOMPUTED_BITS.shiftRight(SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
112          if (x.compareTo(halfPower) <= 0) {
113            return logFloor;
114          } else {
115            return logFloor + 1;
116          }
117        }
118        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
119        //
120        // To determine which side of logFloor.5 the logarithm is,
121        // we compare x^2 to 2^(2 * logFloor + 1).
122        BigInteger x2 = x.pow(2);
123        int logX2Floor = x2.bitLength() - 1;
124        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
125
126      default:
127        throw new AssertionError();
128    }
129  }
130
131  /*
132   * The maximum number of bits in a square root for which we'll precompute an explicit half power
133   * of two. This can be any value, but higher values incur more class load time and linearly
134   * increasing memory consumption.
135   */
136  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
137
138  @VisibleForTesting
139  static final BigInteger SQRT2_PRECOMPUTED_BITS =
140      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
141
142  /**
143   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
144   *
145   * @throws IllegalArgumentException if {@code x <= 0}
146   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
147   *     is not a power of ten
148   */
149  @GwtIncompatible // TODO
150  @SuppressWarnings("fallthrough")
151  public static int log10(BigInteger x, RoundingMode mode) {
152    checkPositive("x", x);
153    if (fitsInLong(x)) {
154      return LongMath.log10(x.longValue(), mode);
155    }
156
157    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
158    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
159    int approxCmp = approxPow.compareTo(x);
160
161    /*
162     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
163     * 10^floor(log10(x)).
164     */
165
166    if (approxCmp > 0) {
167      /*
168       * The code is written so that even completely incorrect approximations will still yield the
169       * correct answer eventually, but in practice this branch should almost never be entered, and
170       * even then the loop should not run more than once.
171       */
172      do {
173        approxLog10--;
174        approxPow = approxPow.divide(BigInteger.TEN);
175        approxCmp = approxPow.compareTo(x);
176      } while (approxCmp > 0);
177    } else {
178      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
179      int nextCmp = nextPow.compareTo(x);
180      while (nextCmp <= 0) {
181        approxLog10++;
182        approxPow = nextPow;
183        approxCmp = nextCmp;
184        nextPow = BigInteger.TEN.multiply(approxPow);
185        nextCmp = nextPow.compareTo(x);
186      }
187    }
188
189    int floorLog = approxLog10;
190    BigInteger floorPow = approxPow;
191    int floorCmp = approxCmp;
192
193    switch (mode) {
194      case UNNECESSARY:
195        checkRoundingUnnecessary(floorCmp == 0);
196        // fall through
197      case FLOOR:
198      case DOWN:
199        return floorLog;
200
201      case CEILING:
202      case UP:
203        return floorPow.equals(x) ? floorLog : floorLog + 1;
204
205      case HALF_DOWN:
206      case HALF_UP:
207      case HALF_EVEN:
208        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
209        BigInteger x2 = x.pow(2);
210        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
211        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
212      default:
213        throw new AssertionError();
214    }
215  }
216
217  private static final double LN_10 = Math.log(10);
218  private static final double LN_2 = Math.log(2);
219
220  /**
221   * Returns the square root of {@code x}, rounded with the specified rounding mode.
222   *
223   * @throws IllegalArgumentException if {@code x < 0}
224   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code
225   *     sqrt(x)} is not an integer
226   */
227  @GwtIncompatible // TODO
228  @SuppressWarnings("fallthrough")
229  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
230    checkNonNegative("x", x);
231    if (fitsInLong(x)) {
232      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
233    }
234    BigInteger sqrtFloor = sqrtFloor(x);
235    switch (mode) {
236      case UNNECESSARY:
237        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
238      case FLOOR:
239      case DOWN:
240        return sqrtFloor;
241      case CEILING:
242      case UP:
243        int sqrtFloorInt = sqrtFloor.intValue();
244        boolean sqrtFloorIsExact =
245            (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
246                && sqrtFloor.pow(2).equals(x); // slow exact check
247        return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
248      case HALF_DOWN:
249      case HALF_UP:
250      case HALF_EVEN:
251        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
252        /*
253         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
254         * and halfSquare are integers, this is equivalent to testing whether or not x <=
255         * halfSquare.
256         */
257        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
258      default:
259        throw new AssertionError();
260    }
261  }
262
263  @GwtIncompatible // TODO
264  private static BigInteger sqrtFloor(BigInteger x) {
265    /*
266     * Adapted from Hacker's Delight, Figure 11-1.
267     *
268     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
269     * then we can get a double approximation of the square root. Then, we iteratively improve this
270     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
271     * This iteration has the following two properties:
272     *
273     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
274     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
275     * and the arithmetic mean is always higher than the geometric mean.
276     *
277     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
278     * with each iteration, so this algorithm takes O(log(digits)) iterations.
279     *
280     * We start out with a double-precision approximation, which may be higher or lower than the
281     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
282     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
283     */
284    BigInteger sqrt0;
285    int log2 = log2(x, FLOOR);
286    if (log2 < Double.MAX_EXPONENT) {
287      sqrt0 = sqrtApproxWithDoubles(x);
288    } else {
289      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
290      /*
291       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
292       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
293       */
294      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
295    }
296    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
297    if (sqrt0.equals(sqrt1)) {
298      return sqrt0;
299    }
300    do {
301      sqrt0 = sqrt1;
302      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
303    } while (sqrt1.compareTo(sqrt0) < 0);
304    return sqrt0;
305  }
306
307  @GwtIncompatible // TODO
308  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
309    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
310  }
311
312  /**
313   * Returns {@code x}, rounded to a {@code double} with the specified rounding mode. If {@code x}
314   * is precisely representable as a {@code double}, its {@code double} value will be returned;
315   * otherwise, the rounding will choose between the two nearest representable values with {@code
316   * mode}.
317   *
318   * <p>For the case of {@link RoundingMode#HALF_DOWN}, {@code HALF_UP}, and {@code HALF_EVEN},
319   * infinite {@code double} values are considered infinitely far away. For example, 2^2000 is not
320   * representable as a double, but {@code roundToDouble(BigInteger.valueOf(2).pow(2000), HALF_UP)}
321   * will return {@code Double.MAX_VALUE}, not {@code Double.POSITIVE_INFINITY}.
322   *
323   * <p>For the case of {@link RoundingMode#HALF_EVEN}, this implementation uses the IEEE 754
324   * default rounding mode: if the two nearest representable values are equally near, the one with
325   * the least significant bit zero is chosen. (In such cases, both of the nearest representable
326   * values are even integers; this method returns the one that is a multiple of a greater power of
327   * two.)
328   *
329   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
330   *     is not precisely representable as a {@code double}
331   * @since 30.0
332   */
333  @GwtIncompatible
334  public static double roundToDouble(BigInteger x, RoundingMode mode) {
335    return BigIntegerToDoubleRounder.INSTANCE.roundToDouble(x, mode);
336  }
337
338  @GwtIncompatible
339  private static class BigIntegerToDoubleRounder extends ToDoubleRounder<BigInteger> {
340    static final BigIntegerToDoubleRounder INSTANCE = new BigIntegerToDoubleRounder();
341
342    private BigIntegerToDoubleRounder() {}
343
344    @Override
345    double roundToDoubleArbitrarily(BigInteger bigInteger) {
346      return DoubleUtils.bigToDouble(bigInteger);
347    }
348
349    @Override
350    int sign(BigInteger bigInteger) {
351      return bigInteger.signum();
352    }
353
354    @Override
355    BigInteger toX(double d, RoundingMode mode) {
356      return DoubleMath.roundToBigInteger(d, mode);
357    }
358
359    @Override
360    BigInteger minus(BigInteger a, BigInteger b) {
361      return a.subtract(b);
362    }
363  }
364
365  /**
366   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified {@code
367   * RoundingMode}.
368   *
369   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
370   *     is not an integer multiple of {@code b}
371   */
372  @GwtIncompatible // TODO
373  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
374    BigDecimal pDec = new BigDecimal(p);
375    BigDecimal qDec = new BigDecimal(q);
376    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
377  }
378
379  /**
380   * Returns {@code n!}, that is, the product of the first {@code n} positive integers, or {@code 1}
381   * if {@code n == 0}.
382   *
383   * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
384   *
385   * <p>This uses an efficient binary recursive algorithm to compute the factorial with balanced
386   * multiplies. It also removes all the 2s from the intermediate products (shifting them back in at
387   * the end).
388   *
389   * @throws IllegalArgumentException if {@code n < 0}
390   */
391  public static BigInteger factorial(int n) {
392    checkNonNegative("n", n);
393
394    // If the factorial is small enough, just use LongMath to do it.
395    if (n < LongMath.factorials.length) {
396      return BigInteger.valueOf(LongMath.factorials[n]);
397    }
398
399    // Pre-allocate space for our list of intermediate BigIntegers.
400    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
401    ArrayList<BigInteger> bignums = new ArrayList<>(approxSize);
402
403    // Start from the pre-computed maximum long factorial.
404    int startingNumber = LongMath.factorials.length;
405    long product = LongMath.factorials[startingNumber - 1];
406    // Strip off 2s from this value.
407    int shift = Long.numberOfTrailingZeros(product);
408    product >>= shift;
409
410    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
411    int productBits = LongMath.log2(product, FLOOR) + 1;
412    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
413    // Check for the next power of two boundary, to save us a CLZ operation.
414    int nextPowerOfTwo = 1 << (bits - 1);
415
416    // Iteratively multiply the longs as big as they can go.
417    for (long num = startingNumber; num <= n; num++) {
418      // Check to see if the floor(log2(num)) + 1 has changed.
419      if ((num & nextPowerOfTwo) != 0) {
420        nextPowerOfTwo <<= 1;
421        bits++;
422      }
423      // Get rid of the 2s in num.
424      int tz = Long.numberOfTrailingZeros(num);
425      long normalizedNum = num >> tz;
426      shift += tz;
427      // Adjust floor(log2(num)) + 1.
428      int normalizedBits = bits - tz;
429      // If it won't fit in a long, then we store off the intermediate product.
430      if (normalizedBits + productBits >= Long.SIZE) {
431        bignums.add(BigInteger.valueOf(product));
432        product = 1;
433        productBits = 0;
434      }
435      product *= normalizedNum;
436      productBits = LongMath.log2(product, FLOOR) + 1;
437    }
438    // Check for leftovers.
439    if (product > 1) {
440      bignums.add(BigInteger.valueOf(product));
441    }
442    // Efficiently multiply all the intermediate products together.
443    return listProduct(bignums).shiftLeft(shift);
444  }
445
446  static BigInteger listProduct(List<BigInteger> nums) {
447    return listProduct(nums, 0, nums.size());
448  }
449
450  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
451    switch (end - start) {
452      case 0:
453        return BigInteger.ONE;
454      case 1:
455        return nums.get(start);
456      case 2:
457        return nums.get(start).multiply(nums.get(start + 1));
458      case 3:
459        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
460      default:
461        // Otherwise, split the list in half and recursively do this.
462        int m = (end + start) >>> 1;
463        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
464    }
465  }
466
467  /**
468   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
469   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
470   *
471   * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
472   *
473   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
474   */
475  public static BigInteger binomial(int n, int k) {
476    checkNonNegative("n", n);
477    checkNonNegative("k", k);
478    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
479    if (k > (n >> 1)) {
480      k = n - k;
481    }
482    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
483      return BigInteger.valueOf(LongMath.binomial(n, k));
484    }
485
486    BigInteger accum = BigInteger.ONE;
487
488    long numeratorAccum = n;
489    long denominatorAccum = 1;
490
491    int bits = LongMath.log2(n, CEILING);
492
493    int numeratorBits = bits;
494
495    for (int i = 1; i < k; i++) {
496      int p = n - i;
497      int q = i + 1;
498
499      // log2(p) >= bits - 1, because p >= n/2
500
501      if (numeratorBits + bits >= Long.SIZE - 1) {
502        // The numerator is as big as it can get without risking overflow.
503        // Multiply numeratorAccum / denominatorAccum into accum.
504        accum =
505            accum
506                .multiply(BigInteger.valueOf(numeratorAccum))
507                .divide(BigInteger.valueOf(denominatorAccum));
508        numeratorAccum = p;
509        denominatorAccum = q;
510        numeratorBits = bits;
511      } else {
512        // We can definitely multiply into the long accumulators without overflowing them.
513        numeratorAccum *= p;
514        denominatorAccum *= q;
515        numeratorBits += bits;
516      }
517    }
518    return accum
519        .multiply(BigInteger.valueOf(numeratorAccum))
520        .divide(BigInteger.valueOf(denominatorAccum));
521  }
522
523  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
524  @GwtIncompatible // TODO
525  static boolean fitsInLong(BigInteger x) {
526    return x.bitLength() <= Long.SIZE - 1;
527  }
528
529  private BigIntegerMath() {}
530}