8334755: Asymptotically faster implementation of square root algorithm

Reviewed-by: rgiulietti
This commit is contained in:
fabioromano1 2024-08-03 13:08:54 +00:00 committed by Raffaello Giulietti
parent 34edc7358f
commit 367e0a6556
4 changed files with 390 additions and 85 deletions

View File

@ -2723,7 +2723,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
throw new ArithmeticException("Negative BigInteger");
}
return new MutableBigInteger(this.mag).sqrt().toBigInteger();
return new MutableBigInteger(this.mag).sqrtRem(false)[0].toBigInteger();
}
/**
@ -2742,10 +2742,12 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
* @since 9
*/
public BigInteger[] sqrtAndRemainder() {
BigInteger s = sqrt();
BigInteger r = this.subtract(s.square());
assert r.compareTo(BigInteger.ZERO) >= 0;
return new BigInteger[] {s, r};
if (this.signum < 0) {
throw new ArithmeticException("Negative BigInteger");
}
MutableBigInteger[] sqrtRem = new MutableBigInteger(this.mag).sqrtRem(true);
return new BigInteger[] { sqrtRem[0].toBigInteger(), sqrtRem[1].toBigInteger() };
}
/**

View File

@ -109,9 +109,26 @@ class MutableBigInteger {
* the int val.
*/
MutableBigInteger(int val) {
value = new int[1];
intLen = 1;
value[0] = val;
init(val);
}
/**
* Construct a new MutableBigInteger with a magnitude specified by
* the long val.
*/
MutableBigInteger(long val) {
int hi = (int) (val >>> 32);
if (hi == 0) {
init((int) val);
} else {
value = new int[] { hi, (int) val };
intLen = 2;
}
}
private void init(int val) {
value = new int[] { val };
intLen = val != 0 ? 1 : 0;
}
/**
@ -260,6 +277,7 @@ class MutableBigInteger {
* Compare the magnitude of two MutableBigIntegers. Returns -1, 0 or 1
* as this MutableBigInteger is numerically less than, equal to, or
* greater than {@code b}.
* Assumes no leading unnecessary zeros.
*/
final int compare(MutableBigInteger b) {
int blen = b.intLen;
@ -285,6 +303,7 @@ class MutableBigInteger {
/**
* Returns a value equal to what {@code b.leftShift(32*ints); return compare(b);}
* would return, but doesn't change the value of {@code b}.
* Assumes no leading unnecessary zeros.
*/
private int compareShifted(MutableBigInteger b, int ints) {
int blen = b.intLen;
@ -538,6 +557,7 @@ class MutableBigInteger {
/**
* Right shift this MutableBigInteger n bits. The MutableBigInteger is left
* in normal form.
* Assumes {@code Math.ceilDiv(n, 32) <= intLen || intLen == 0}
*/
void rightShift(int n) {
if (intLen == 0)
@ -911,6 +931,58 @@ class MutableBigInteger {
add(a);
}
/**
* Shifts {@code this} of {@code n} ints to the left and adds {@code addend}.
* Assumes {@code n > 0} for speed.
*/
void shiftAdd(MutableBigInteger addend, int n) {
// Fast cases
if (addend.intLen <= n) {
shiftAddDisjoint(addend, n);
} else if (intLen == 0) {
copyValue(addend);
} else {
leftShift(n << 5);
add(addend);
}
}
/**
* Shifts {@code this} of {@code n} ints to the left and adds {@code addend}.
* Assumes {@code addend.intLen <= n}.
*/
void shiftAddDisjoint(MutableBigInteger addend, int n) {
if (intLen == 0) { // Avoid unnormal values
copyValue(addend);
return;
}
int[] res;
final int resLen = intLen + n, resOffset;
if (resLen > value.length) {
res = new int[resLen];
System.arraycopy(value, offset, res, 0, intLen);
resOffset = 0;
} else {
res = value;
if (offset + resLen > value.length) {
System.arraycopy(value, offset, res, 0, intLen);
resOffset = 0;
} else {
resOffset = offset;
}
// Clear words where necessary
if (addend.intLen < n)
Arrays.fill(res, resOffset + intLen, resOffset + resLen - addend.intLen, 0);
}
System.arraycopy(addend.value, addend.offset, res, resOffset + resLen - addend.intLen, addend.intLen);
value = res;
offset = resOffset;
intLen = resLen;
}
/**
* Subtracts the smaller of this and b from the larger and places the
* result into this MutableBigInteger.
@ -1003,6 +1075,7 @@ class MutableBigInteger {
/**
* Multiply the contents of two MutableBigInteger objects. The result is
* placed into MutableBigInteger z. The contents of y are not changed.
* Assume {@code intLen > 0}
*/
void multiply(MutableBigInteger y, MutableBigInteger z) {
int xLen = intLen;
@ -1793,93 +1866,169 @@ class MutableBigInteger {
}
/**
* Calculate the integer square root {@code floor(sqrt(this))} where
* {@code sqrt(.)} denotes the mathematical square root. The contents of
* {@code this} are <b>not</b> changed. The value of {@code this} is assumed
* to be non-negative.
* Calculate the integer square root {@code floor(sqrt(this))} and the remainder
* if needed, where {@code sqrt(.)} denotes the mathematical square root.
* The contents of {@code this} are <em>not</em> changed.
* The value of {@code this} is assumed to be non-negative.
*
* @implNote The implementation is based on the material in Henry S. Warren,
* Jr., <i>Hacker's Delight (2nd ed.)</i> (Addison Wesley, 2013), 279-282.
*
* @throws ArithmeticException if the value returned by {@code bitLength()}
* overflows the range of {@code int}.
* @return the integer square root of {@code this}
* @since 9
* @return the integer square root of {@code this} and the remainder if needed
*/
MutableBigInteger sqrt() {
MutableBigInteger[] sqrtRem(boolean needRemainder) {
// Special cases.
if (this.isZero()) {
return new MutableBigInteger(0);
} else if (this.value.length == 1
&& (this.value[0] & LONG_MASK) < 4) { // result is unity
return ONE;
if (this.intLen <= 2) {
final long x = this.toLong(); // unsigned
long s = unsignedLongSqrt(x);
return new MutableBigInteger[] {
new MutableBigInteger((int) s),
needRemainder ? new MutableBigInteger(x - s * s) : null
};
}
if (bitLength() <= 63) {
// Initial estimate is the square root of the positive long value.
long v = new BigInteger(this.value, 1).longValueExact();
long xk = (long)Math.floor(Math.sqrt(v));
// Normalize
MutableBigInteger x = this;
final int shift = (Integer.numberOfLeadingZeros(x.value[x.offset]) & ~1) // shift must be even
+ ((x.intLen & 1) << 5); // x.intLen must be even
// Refine the estimate.
do {
long xk1 = (xk + v/xk)/2;
// Terminate when non-decreasing.
if (xk1 >= xk) {
return new MutableBigInteger(new int[] {
(int)(xk >>> 32), (int)(xk & LONG_MASK)
});
if (shift != 0) {
x = new MutableBigInteger(x);
x.leftShift(shift);
}
xk = xk1;
} while (true);
// Compute sqrt and remainder
MutableBigInteger[] sqrtRem = x.sqrtRemKaratsuba(x.intLen, needRemainder);
// Unnormalize
if (shift != 0) {
final int halfShift = shift >> 1;
if (needRemainder) {
// shift <= 62, so s0 is at most 31 bit long
final long s0 = sqrtRem[0].value[sqrtRem[0].offset + sqrtRem[0].intLen - 1]
& (-1 >>> -halfShift); // Remove excess bits
if (s0 != 0L) { // An optimization
MutableBigInteger doubleProd = new MutableBigInteger();
sqrtRem[0].mul((int) (s0 << 1), doubleProd);
sqrtRem[1].add(doubleProd);
sqrtRem[1].subtract(new MutableBigInteger(s0 * s0));
}
sqrtRem[1].rightShift(shift);
}
sqrtRem[0].primitiveRightShift(halfShift);
}
return sqrtRem;
}
private static long unsignedLongSqrt(long x) {
/* For every long value s in [0, 2^32) such that x == s * s,
* it is true that s - 1 <= (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) <= s,
* and if x == 2^64 - 1, then (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) == 2^32.
* Since both cast to long and `Math.sqrt()` are (weakly) increasing,
* this means that the value returned by Math.sqrt()
* for a long value in the range [0, 2^64) is either correct,
* or rounded up/down by one if the value is too high
* and too close to a perfect square.
*/
long s = (long) Math.sqrt(x >= 0 ? x : x + 0x1p64);
long s2 = s * s; // overflows iff s == 2^32
return Long.compareUnsigned(x, s2) < 0 || s > LONG_MASK
? s - 1
: (Long.compareUnsigned(x, s2 + (s << 1)) <= 0 // x <= (s + 1)^2 - 1, does not overflow
? s
: s + 1);
}
/**
* Assumes {@code 2 <= len <= intLen && len % 2 == 0
* && Integer.numberOfLeadingZeros(value[offset]) <= 1}
* @implNote The implementation is based on Zimmermann's works available
* <a href="https://inria.hal.science/inria-00072854v1/document"> here</a> and
* <a href="https://inria.hal.science/inria-00072113/document"> here</a>
*/
private MutableBigInteger[] sqrtRemKaratsuba(int len, boolean needRemainder) {
if (len == 2) { // Base case
long x = ((value[offset] & LONG_MASK) << 32) | (value[offset + 1] & LONG_MASK);
long s = unsignedLongSqrt(x);
// Allocate sufficient space to hold the final square root, assuming intLen % 2 == 0
MutableBigInteger sqrt = new MutableBigInteger(new int[intLen >> 1]);
// Place the partial square root
sqrt.intLen = 1;
sqrt.value[0] = (int) s;
return new MutableBigInteger[] { sqrt, new MutableBigInteger(x - s * s) };
}
// Recursive step (len >= 4)
final int halfLen = len >> 1;
// Recursive invocation
MutableBigInteger[] sr = sqrtRemKaratsuba(halfLen + (halfLen & 1), true);
final int blockLen = halfLen >> 1;
MutableBigInteger dividend = sr[1];
dividend.shiftAddDisjoint(getBlockForSqrt(1, len, blockLen), blockLen);
// Compute dividend / (2*sqrt)
MutableBigInteger sqrt = sr[0];
MutableBigInteger q = new MutableBigInteger();
MutableBigInteger u = dividend.divide(sqrt, q);
if (q.isOdd())
u.add(sqrt);
q.rightShift(1);
sqrt.shiftAdd(q, blockLen);
// Corresponds to ub + a_0 in the paper
u.shiftAddDisjoint(getBlockForSqrt(0, len, blockLen), blockLen);
BigInteger qBig = q.toBigInteger(); // Cast to BigInteger to use fast multiplication
MutableBigInteger qSqr = new MutableBigInteger(qBig.multiply(qBig).mag);
MutableBigInteger rem;
if (needRemainder) {
rem = u;
if (rem.subtract(qSqr) < 0) {
MutableBigInteger twiceSqrt = new MutableBigInteger(sqrt);
twiceSqrt.leftShift(1);
// Since subtract() performs an absolute difference, to get the correct algebraic sum
// we must first add the sum of absolute values of addends concordant with the sign of rem
// and then subtract the sum of absolute values of addends that are discordant
rem.add(ONE);
rem.subtract(twiceSqrt);
sqrt.subtract(ONE);
}
} else {
// Set up the initial estimate of the iteration.
// Obtain the bitLength > 63.
int bitLength = (int) this.bitLength();
if (bitLength != this.bitLength()) {
throw new ArithmeticException("bitLength() integer overflow");
rem = null;
if (u.compare(qSqr) < 0)
sqrt.subtract(ONE);
}
// Determine an even valued right shift into positive long range.
int shift = bitLength - 63;
if (shift % 2 == 1) {
shift++;
sr[1] = rem;
return sr;
}
// Shift the value into positive long range.
MutableBigInteger xk = new MutableBigInteger(this);
xk.rightShift(shift);
xk.normalize();
/**
* Returns a {@code MutableBigInteger} obtained by taking {@code blockLen} ints from
* {@code this} number, ending at {@code blockIndex*blockLen} (exclusive).<br/>
* Used in Karatsuba square root.
* @param blockIndex the block index, starting from the lowest
* @param len the logical length of the input value in units of 32 bits
* @param blockLen the length of the block in units of 32 bits
*
* @return a {@code MutableBigInteger} obtained by taking {@code blockLen} ints from
* {@code this} number, ending at {@code blockIndex*blockLen} (exclusive).
*/
private MutableBigInteger getBlockForSqrt(int blockIndex, int len, int blockLen) {
final int to = offset + len - blockIndex * blockLen;
// Use the square root of the shifted value as an approximation.
double d = new BigInteger(xk.value, 1).doubleValue();
BigInteger bi = BigInteger.valueOf((long)Math.ceil(Math.sqrt(d)));
xk = new MutableBigInteger(bi.mag);
// Skip leading zeros
int from;
for (from = to - blockLen; from < to && value[from] == 0; from++);
// Shift the approximate square root back into the original range.
xk.leftShift(shift / 2);
// Refine the estimate.
MutableBigInteger xk1 = new MutableBigInteger();
do {
// xk1 = (xk + n/xk)/2
this.divide(xk, xk1, false);
xk1.add(xk);
xk1.rightShift(1);
// Terminate when non-decreasing.
if (xk1.compare(xk) >= 0) {
return xk;
}
// xk = xk1
xk.copyValue(xk1);
xk1.reset();
} while (true);
}
return from == to
? new MutableBigInteger()
: new MutableBigInteger(Arrays.copyOfRange(value, from, to));
}
/**

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 1998, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1998, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -293,8 +293,30 @@ public class BigIntegerTest {
report("squareRootSmall", failCount);
}
private static void perfectSquaresLong() {
/* For every long value n in [0, 2^32) such that x == n * n,
* n - 1 <= (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) <= n
* must be true.
* This property is used to implement MutableBigInteger.unsignedLongSqrt().
*/
int failCount = 0;
long limit = 1L << 32;
for (long n = 0; n < limit; n++) {
long x = n * n;
long s = (long) Math.sqrt(x >= 0 ? x : x + 0x1p64);
if (!(s == n || s == n - 1)) {
failCount++;
System.err.println(s + "^2 != " + x + " && (" + s + "+1)^2 != " + x);
}
}
report("perfectSquaresLong", failCount);
}
public static void squareRoot() {
squareRootSmall();
perfectSquaresLong();
ToIntFunction<BigInteger> f = (n) -> {
int failCount = 0;

View File

@ -0,0 +1,132 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.math;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.math.BigInteger;
import java.util.Random;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 3)
public class BigIntegerSquareRoot {
private BigInteger[] xsArray, sArray, mArray, lArray, xlArray;
private static final int TESTSIZE = 1000;
@Setup
public void setup() {
Random r = new Random(1123);
xsArray = new BigInteger[TESTSIZE]; /*
* Each array entry is atmost 64 bits
* in size
*/
sArray = new BigInteger[TESTSIZE]; /*
* Each array entry is atmost 256 bits
* in size
*/
mArray = new BigInteger[TESTSIZE]; /*
* Each array entry is atmost 1024 bits
* in size
*/
lArray = new BigInteger[TESTSIZE]; /*
* Each array entry is atmost 4096 bits
* in size
*/
xlArray = new BigInteger[TESTSIZE]; /*
* Each array entry is atmost 16384 bits
* in size
*/
for (int i = 0; i < TESTSIZE; i++) {
xsArray[i] = new BigInteger(r.nextInt(64), r);
sArray[i] = new BigInteger(r.nextInt(256), r);
mArray[i] = new BigInteger(r.nextInt(1024), r);
lArray[i] = new BigInteger(r.nextInt(4096), r);
xlArray[i] = new BigInteger(r.nextInt(16384), r);
}
}
/** Test BigInteger.sqrt() with numbers long at most 64 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testSqrtXS(Blackhole bh) {
for (BigInteger s : xsArray) {
bh.consume(s.sqrt());
}
}
/** Test BigInteger.sqrt() with numbers long at most 256 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testSqrtS(Blackhole bh) {
for (BigInteger s : sArray) {
bh.consume(s.sqrt());
}
}
/** Test BigInteger.sqrt() with numbers long at most 1024 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testSqrtM(Blackhole bh) {
for (BigInteger s : mArray) {
bh.consume(s.sqrt());
}
}
/** Test BigInteger.sqrt() with numbers long at most 4096 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testSqrtL(Blackhole bh) {
for (BigInteger s : lArray) {
bh.consume(s.sqrt());
}
}
/** Test BigInteger.sqrt() with numbers long at most 16384 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testSqrtXL(Blackhole bh) {
for (BigInteger s : xlArray) {
bh.consume(s.sqrt());
}
}
}