From 367e0a65561f95aad61b40930d5f46843fee3444 Mon Sep 17 00:00:00 2001 From: fabioromano1 <51378941+fabioromano1@users.noreply.github.com> Date: Sat, 3 Aug 2024 13:08:54 +0000 Subject: [PATCH] 8334755: Asymptotically faster implementation of square root algorithm Reviewed-by: rgiulietti --- .../share/classes/java/math/BigInteger.java | 12 +- .../classes/java/math/MutableBigInteger.java | 307 +++++++++++++----- .../java/math/BigInteger/BigIntegerTest.java | 24 +- .../bench/java/math/BigIntegerSquareRoot.java | 132 ++++++++ 4 files changed, 390 insertions(+), 85 deletions(-) create mode 100644 test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java diff --git a/src/java.base/share/classes/java/math/BigInteger.java b/src/java.base/share/classes/java/math/BigInteger.java index c69e291d0e2..3a5fd143937 100644 --- a/src/java.base/share/classes/java/math/BigInteger.java +++ b/src/java.base/share/classes/java/math/BigInteger.java @@ -2723,7 +2723,7 @@ public class BigInteger extends Number implements Comparable { throw new ArithmeticException("Negative BigInteger"); } - return new MutableBigInteger(this.mag).sqrt().toBigInteger(); + return new MutableBigInteger(this.mag).sqrtRem(false)[0].toBigInteger(); } /** @@ -2742,10 +2742,12 @@ public class BigInteger extends Number implements Comparable { * @since 9 */ public BigInteger[] sqrtAndRemainder() { - BigInteger s = sqrt(); - BigInteger r = this.subtract(s.square()); - assert r.compareTo(BigInteger.ZERO) >= 0; - return new BigInteger[] {s, r}; + if (this.signum < 0) { + throw new ArithmeticException("Negative BigInteger"); + } + + MutableBigInteger[] sqrtRem = new MutableBigInteger(this.mag).sqrtRem(true); + return new BigInteger[] { sqrtRem[0].toBigInteger(), sqrtRem[1].toBigInteger() }; } /** diff --git a/src/java.base/share/classes/java/math/MutableBigInteger.java b/src/java.base/share/classes/java/math/MutableBigInteger.java index 30ea8e130fc..b84e50f567e 100644 --- 
a/src/java.base/share/classes/java/math/MutableBigInteger.java +++ b/src/java.base/share/classes/java/math/MutableBigInteger.java @@ -109,9 +109,26 @@ class MutableBigInteger { * the int val. */ MutableBigInteger(int val) { - value = new int[1]; - intLen = 1; - value[0] = val; + init(val); + } + + /** + * Construct a new MutableBigInteger with a magnitude specified by + * the long val. + */ + MutableBigInteger(long val) { + int hi = (int) (val >>> 32); + if (hi == 0) { + init((int) val); + } else { + value = new int[] { hi, (int) val }; + intLen = 2; + } + } + + private void init(int val) { + value = new int[] { val }; + intLen = val != 0 ? 1 : 0; } /** @@ -260,6 +277,7 @@ class MutableBigInteger { * Compare the magnitude of two MutableBigIntegers. Returns -1, 0 or 1 * as this MutableBigInteger is numerically less than, equal to, or * greater than {@code b}. + * Assumes no leading unnecessary zeros. */ final int compare(MutableBigInteger b) { int blen = b.intLen; @@ -285,6 +303,7 @@ class MutableBigInteger { /** * Returns a value equal to what {@code b.leftShift(32*ints); return compare(b);} * would return, but doesn't change the value of {@code b}. + * Assumes no leading unnecessary zeros. */ private int compareShifted(MutableBigInteger b, int ints) { int blen = b.intLen; @@ -538,6 +557,7 @@ class MutableBigInteger { /** * Right shift this MutableBigInteger n bits. The MutableBigInteger is left * in normal form. + * Assumes {@code Math.ceilDiv(n, 32) <= intLen || intLen == 0} */ void rightShift(int n) { if (intLen == 0) @@ -911,6 +931,58 @@ class MutableBigInteger { add(a); } + /** + * Shifts {@code this} of {@code n} ints to the left and adds {@code addend}. + * Assumes {@code n > 0} for speed. 
+ */ + void shiftAdd(MutableBigInteger addend, int n) { + // Fast cases + if (addend.intLen <= n) { + shiftAddDisjoint(addend, n); + } else if (intLen == 0) { + copyValue(addend); + } else { + leftShift(n << 5); + add(addend); + } + } + + /** + * Shifts {@code this} of {@code n} ints to the left and adds {@code addend}. + * Assumes {@code addend.intLen <= n}. + */ + void shiftAddDisjoint(MutableBigInteger addend, int n) { + if (intLen == 0) { // Avoid unnormal values + copyValue(addend); + return; + } + + int[] res; + final int resLen = intLen + n, resOffset; + if (resLen > value.length) { + res = new int[resLen]; + System.arraycopy(value, offset, res, 0, intLen); + resOffset = 0; + } else { + res = value; + if (offset + resLen > value.length) { + System.arraycopy(value, offset, res, 0, intLen); + resOffset = 0; + } else { + resOffset = offset; + } + // Clear words where necessary + if (addend.intLen < n) + Arrays.fill(res, resOffset + intLen, resOffset + resLen - addend.intLen, 0); + } + + System.arraycopy(addend.value, addend.offset, res, resOffset + resLen - addend.intLen, addend.intLen); + + value = res; + offset = resOffset; + intLen = resLen; + } + /** * Subtracts the smaller of this and b from the larger and places the * result into this MutableBigInteger. @@ -1003,6 +1075,7 @@ class MutableBigInteger { /** * Multiply the contents of two MutableBigInteger objects. The result is * placed into MutableBigInteger z. The contents of y are not changed. + * Assume {@code intLen > 0} */ void multiply(MutableBigInteger y, MutableBigInteger z) { int xLen = intLen; @@ -1793,93 +1866,169 @@ class MutableBigInteger { } /** - * Calculate the integer square root {@code floor(sqrt(this))} where - * {@code sqrt(.)} denotes the mathematical square root. The contents of - * {@code this} are not changed. The value of {@code this} is assumed - * to be non-negative. 
+ * Calculate the integer square root {@code floor(sqrt(this))} and the remainder + * if needed, where {@code sqrt(.)} denotes the mathematical square root. + * The contents of {@code this} are not changed. + * The value of {@code this} is assumed to be non-negative. * - * @implNote The implementation is based on the material in Henry S. Warren, - * Jr., Hacker's Delight (2nd ed.) (Addison Wesley, 2013), 279-282. - * - * @throws ArithmeticException if the value returned by {@code bitLength()} - * overflows the range of {@code int}. - * @return the integer square root of {@code this} - * @since 9 + * @return the integer square root of {@code this} and the remainder if needed */ - MutableBigInteger sqrt() { + MutableBigInteger[] sqrtRem(boolean needRemainder) { // Special cases. - if (this.isZero()) { - return new MutableBigInteger(0); - } else if (this.value.length == 1 - && (this.value[0] & LONG_MASK) < 4) { // result is unity - return ONE; + if (this.intLen <= 2) { + final long x = this.toLong(); // unsigned + long s = unsignedLongSqrt(x); + + return new MutableBigInteger[] { + new MutableBigInteger((int) s), + needRemainder ? new MutableBigInteger(x - s * s) : null + }; } - if (bitLength() <= 63) { - // Initial estimate is the square root of the positive long value. - long v = new BigInteger(this.value, 1).longValueExact(); - long xk = (long)Math.floor(Math.sqrt(v)); + // Normalize + MutableBigInteger x = this; + final int shift = (Integer.numberOfLeadingZeros(x.value[x.offset]) & ~1) // shift must be even + + ((x.intLen & 1) << 5); // x.intLen must be even - // Refine the estimate. - do { - long xk1 = (xk + v/xk)/2; + if (shift != 0) { + x = new MutableBigInteger(x); + x.leftShift(shift); + } - // Terminate when non-decreasing. 
- if (xk1 >= xk) { - return new MutableBigInteger(new int[] { - (int)(xk >>> 32), (int)(xk & LONG_MASK) - }); + // Compute sqrt and remainder + MutableBigInteger[] sqrtRem = x.sqrtRemKaratsuba(x.intLen, needRemainder); + + // Unnormalize + if (shift != 0) { + final int halfShift = shift >> 1; + if (needRemainder) { + // shift <= 62, so s0 is at most 31 bit long + final long s0 = sqrtRem[0].value[sqrtRem[0].offset + sqrtRem[0].intLen - 1] + & (-1 >>> -halfShift); // Remove excess bits + if (s0 != 0L) { // An optimization + MutableBigInteger doubleProd = new MutableBigInteger(); + sqrtRem[0].mul((int) (s0 << 1), doubleProd); + + sqrtRem[1].add(doubleProd); + sqrtRem[1].subtract(new MutableBigInteger(s0 * s0)); } + sqrtRem[1].rightShift(shift); + } + sqrtRem[0].primitiveRightShift(halfShift); + } + return sqrtRem; + } - xk = xk1; - } while (true); + private static long unsignedLongSqrt(long x) { + /* For every long value s in [0, 2^32) such that x == s * s, + * it is true that s - 1 <= (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) <= s, + * and if x == 2^64 - 1, then (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) == 2^32. + * Since both cast to long and `Math.sqrt()` are (weakly) increasing, + * this means that the value returned by Math.sqrt() + * for a long value in the range [0, 2^64) is either correct, + * or rounded up/down by one if the value is too high + * and too close to a perfect square. + */ + long s = (long) Math.sqrt(x >= 0 ? x : x + 0x1p64); + long s2 = s * s; // overflows iff s == 2^32 + return Long.compareUnsigned(x, s2) < 0 || s > LONG_MASK + ? s - 1 + : (Long.compareUnsigned(x, s2 + (s << 1)) <= 0 // x <= (s + 1)^2 - 1, does not overflow + ? 
s + : s + 1); + } + + /** + * Assumes {@code 2 <= len <= intLen && len % 2 == 0 + * && Integer.numberOfLeadingZeros(value[offset]) <= 1} + * @implNote The implementation is based on Zimmermann's works: + * "Karatsuba Square Root" (INRIA Research Report RR-3805, 1999) and + * "A Proof of GMP Square Root" (Bertot, Magaud, Zimmermann, 2002) + */ + private MutableBigInteger[] sqrtRemKaratsuba(int len, boolean needRemainder) { + if (len == 2) { // Base case + long x = ((value[offset] & LONG_MASK) << 32) | (value[offset + 1] & LONG_MASK); + long s = unsignedLongSqrt(x); + + // Allocate sufficient space to hold the final square root, assuming intLen % 2 == 0 + MutableBigInteger sqrt = new MutableBigInteger(new int[intLen >> 1]); + + // Place the partial square root + sqrt.intLen = 1; + sqrt.value[0] = (int) s; + + return new MutableBigInteger[] { sqrt, new MutableBigInteger(x - s * s) }; + } + + // Recursive step (len >= 4) + + final int halfLen = len >> 1; + // Recursive invocation + MutableBigInteger[] sr = sqrtRemKaratsuba(halfLen + (halfLen & 1), true); + + final int blockLen = halfLen >> 1; + MutableBigInteger dividend = sr[1]; + dividend.shiftAddDisjoint(getBlockForSqrt(1, len, blockLen), blockLen); + + // Compute dividend / (2*sqrt) + MutableBigInteger sqrt = sr[0]; + MutableBigInteger q = new MutableBigInteger(); + MutableBigInteger u = dividend.divide(sqrt, q); + if (q.isOdd()) + u.add(sqrt); + q.rightShift(1); + + sqrt.shiftAdd(q, blockLen); + // Corresponds to ub + a_0 in the paper + u.shiftAddDisjoint(getBlockForSqrt(0, len, blockLen), blockLen); + BigInteger qBig = q.toBigInteger(); // Cast to BigInteger to use fast multiplication + MutableBigInteger qSqr = new MutableBigInteger(qBig.multiply(qBig).mag); + + MutableBigInteger rem; + if (needRemainder) { + rem = u; + if (rem.subtract(qSqr) < 0) { + MutableBigInteger twiceSqrt = new MutableBigInteger(sqrt); + twiceSqrt.leftShift(1); + + // Since subtract() performs an absolute difference, to get the correct algebraic sum + // we must first add the sum of absolute values of addends concordant with the sign of rem +
// and then subtract the sum of absolute values of addends that are discordant + rem.add(ONE); + rem.subtract(twiceSqrt); + sqrt.subtract(ONE); + } } else { - // Set up the initial estimate of the iteration. - - // Obtain the bitLength > 63. - int bitLength = (int) this.bitLength(); - if (bitLength != this.bitLength()) { - throw new ArithmeticException("bitLength() integer overflow"); - } - - // Determine an even valued right shift into positive long range. - int shift = bitLength - 63; - if (shift % 2 == 1) { - shift++; - } - - // Shift the value into positive long range. - MutableBigInteger xk = new MutableBigInteger(this); - xk.rightShift(shift); - xk.normalize(); - - // Use the square root of the shifted value as an approximation. - double d = new BigInteger(xk.value, 1).doubleValue(); - BigInteger bi = BigInteger.valueOf((long)Math.ceil(Math.sqrt(d))); - xk = new MutableBigInteger(bi.mag); - - // Shift the approximate square root back into the original range. - xk.leftShift(shift / 2); - - // Refine the estimate. - MutableBigInteger xk1 = new MutableBigInteger(); - do { - // xk1 = (xk + n/xk)/2 - this.divide(xk, xk1, false); - xk1.add(xk); - xk1.rightShift(1); - - // Terminate when non-decreasing. - if (xk1.compare(xk) >= 0) { - return xk; - } - - // xk = xk1 - xk.copyValue(xk1); - - xk1.reset(); - } while (true); + rem = null; + if (u.compare(qSqr) < 0) + sqrt.subtract(ONE); } + + sr[1] = rem; + return sr; + } + + /** + * Returns a {@code MutableBigInteger} obtained by taking {@code blockLen} ints from + * {@code this} number, ending at {@code blockIndex*blockLen} (exclusive).
+ * Used in Karatsuba square root. + * @param blockIndex the block index, starting from the lowest + * @param len the logical length of the input value in units of 32 bits + * @param blockLen the length of the block in units of 32 bits + * + * @return a {@code MutableBigInteger} obtained by taking {@code blockLen} ints from + * {@code this} number, ending at {@code blockIndex*blockLen} (exclusive). + */ + private MutableBigInteger getBlockForSqrt(int blockIndex, int len, int blockLen) { + final int to = offset + len - blockIndex * blockLen; + + // Skip leading zeros + int from; + for (from = to - blockLen; from < to && value[from] == 0; from++); + + return from == to + ? new MutableBigInteger() + : new MutableBigInteger(Arrays.copyOfRange(value, from, to)); } /** diff --git a/test/jdk/java/math/BigInteger/BigIntegerTest.java b/test/jdk/java/math/BigInteger/BigIntegerTest.java index 2ac4750e43f..7da3fdac618 100644 --- a/test/jdk/java/math/BigInteger/BigIntegerTest.java +++ b/test/jdk/java/math/BigInteger/BigIntegerTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 1998, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1998, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -293,8 +293,30 @@ public class BigIntegerTest { report("squareRootSmall", failCount); } + private static void perfectSquaresLong() { + /* For every long value n in [0, 2^32) such that x == n * n, + * n - 1 <= (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) <= n + * must be true. + * This property is used to implement MutableBigInteger.unsignedLongSqrt(). + */ + int failCount = 0; + + long limit = 1L << 32; + for (long n = 0; n < limit; n++) { + long x = n * n; + long s = (long) Math.sqrt(x >= 0 ? 
x : x + 0x1p64); + if (!(s == n || s == n - 1)) { + failCount++; + System.err.println(s + "^2 != " + x + " && (" + s + "+1)^2 != " + x); + } + } + + report("perfectSquaresLong", failCount); + } + public static void squareRoot() { squareRootSmall(); + perfectSquaresLong(); ToIntFunction f = (n) -> { int failCount = 0; diff --git a/test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java b/test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java new file mode 100644 index 00000000000..4b78b4cd8fa --- /dev/null +++ b/test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ +package org.openjdk.bench.java.math; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.math.BigInteger; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Thread) +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 5, time = 1) +@Fork(value = 3) +public class BigIntegerSquareRoot { + + private BigInteger[] xsArray, sArray, mArray, lArray, xlArray; + private static final int TESTSIZE = 1000; + + @Setup + public void setup() { + Random r = new Random(1123); + + xsArray = new BigInteger[TESTSIZE]; /* + * Each array entry is at most 64 bits + * in size + */ + sArray = new BigInteger[TESTSIZE]; /* + * Each array entry is at most 256 bits + * in size + */ + mArray = new BigInteger[TESTSIZE]; /* + * Each array entry is at most 1024 bits + * in size + */ + lArray = new BigInteger[TESTSIZE]; /* + * Each array entry is at most 4096 bits + * in size + */ + xlArray = new BigInteger[TESTSIZE]; /* + * Each array entry is at most 16384 bits + * in size + */ + + for (int i = 0; i < TESTSIZE; i++) { + xsArray[i] = new BigInteger(r.nextInt(64), r); + sArray[i] = new BigInteger(r.nextInt(256), r); + mArray[i] = new BigInteger(r.nextInt(1024), r); + lArray[i] = new BigInteger(r.nextInt(4096), r); + xlArray[i] = new BigInteger(r.nextInt(16384), r); + } + } + + /** Test BigInteger.sqrt() with numbers long at most 64
bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testSqrtXS(Blackhole bh) { + for (BigInteger s : xsArray) { + bh.consume(s.sqrt()); + } + } + + /** Test BigInteger.sqrt() with numbers long at most 256 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testSqrtS(Blackhole bh) { + for (BigInteger s : sArray) { + bh.consume(s.sqrt()); + } + } + + /** Test BigInteger.sqrt() with numbers long at most 1024 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testSqrtM(Blackhole bh) { + for (BigInteger s : mArray) { + bh.consume(s.sqrt()); + } + } + + /** Test BigInteger.sqrt() with numbers long at most 4096 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testSqrtL(Blackhole bh) { + for (BigInteger s : lArray) { + bh.consume(s.sqrt()); + } + } + + /** Test BigInteger.sqrt() with numbers long at most 16384 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testSqrtXL(Blackhole bh) { + for (BigInteger s : xlArray) { + bh.consume(s.sqrt()); + } + } +}