diff --git a/src/java.base/share/classes/java/math/BigInteger.java b/src/java.base/share/classes/java/math/BigInteger.java index 34f1953d003..81d9a9cf248 100644 --- a/src/java.base/share/classes/java/math/BigInteger.java +++ b/src/java.base/share/classes/java/math/BigInteger.java @@ -36,6 +36,9 @@ import java.io.ObjectStreamField; import java.util.Arrays; import java.util.Objects; import java.util.Random; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinWorkerThread; +import java.util.concurrent.RecursiveTask; import java.util.concurrent.ThreadLocalRandom; import jdk.internal.math.DoubleConsts; @@ -1581,7 +1584,30 @@ public class BigInteger extends Number implements Comparable<BigInteger> { * @return {@code this * val} */ public BigInteger multiply(BigInteger val) { - return multiply(val, false); + return multiply(val, false, false, 0); + } + + /** + * Returns a BigInteger whose value is {@code (this * val)}. + * When both {@code this} and {@code val} are large, typically + * in the thousands of bits, parallel multiply might be used. + * This method returns the exact same mathematical result as + * {@link #multiply}. + * + * @implNote This implementation may offer better algorithmic + * performance when {@code val == this}. + * + * @implNote Compared to {@link #multiply}, an implementation's + * parallel multiplication algorithm would typically use more + * CPU resources to compute the result faster, and may do so + * with a slight increase in memory consumption. + * + * @param val value to be multiplied by this BigInteger. + * @return {@code this * val} + * @see #multiply + */ + public BigInteger parallelMultiply(BigInteger val) { + return multiply(val, false, true, 0); } /** @@ -1590,16 +1616,17 @@ public class BigInteger extends Number implements Comparable<BigInteger> { * * @param val value to be multiplied by this BigInteger. * @param isRecursion whether this is a recursive invocation + * @param parallel whether the multiply should be done in parallel * @return {@code this * val} */ - private BigInteger multiply(BigInteger val, boolean isRecursion) { + private BigInteger multiply(BigInteger val, boolean isRecursion, boolean parallel, int depth) { if (val.signum == 0 || signum == 0) return ZERO; int xlen = mag.length; if (val == this && xlen > MULTIPLY_SQUARE_THRESHOLD) { - return square(); + return square(true, parallel, depth); } int ylen = val.mag.length; @@ -1677,7 +1704,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> { } } - return multiplyToomCook3(this, val); + return multiplyToomCook3(this, val, parallel, depth); } } } @@ -1844,6 +1871,88 @@ public class BigInteger extends Number implements Comparable<BigInteger> { } } + @SuppressWarnings("serial") + private abstract static sealed class RecursiveOp extends RecursiveTask<BigInteger> { + /** + * The threshold until when we should continue forking recursive ops + * if parallel is true. This threshold is only relevant for Toom Cook 3 + * multiply and square. + */ + private static final int PARALLEL_FORK_DEPTH_THRESHOLD = + calculateMaximumDepth(ForkJoinPool.getCommonPoolParallelism()); + + private static final int calculateMaximumDepth(int parallelism) { + return 32 - Integer.numberOfLeadingZeros(parallelism); + } + + final boolean parallel; + /** + * The current recursing depth. Since it is a logarithmic algorithm, + * we do not need an int to hold the number. 
+ */ + final byte depth; + + private RecursiveOp(boolean parallel, int depth) { + this.parallel = parallel; + this.depth = (byte) depth; + } + + private static int getParallelForkDepthThreshold() { + if (Thread.currentThread() instanceof ForkJoinWorkerThread fjwt) { + return calculateMaximumDepth(fjwt.getPool().getParallelism()); + } + else { + return PARALLEL_FORK_DEPTH_THRESHOLD; + } + } + + protected RecursiveTask<BigInteger> forkOrInvoke() { + if (parallel && depth <= getParallelForkDepthThreshold()) fork(); + else invoke(); + return this; + } + + @SuppressWarnings("serial") + private static final class RecursiveMultiply extends RecursiveOp { + private final BigInteger a; + private final BigInteger b; + + public RecursiveMultiply(BigInteger a, BigInteger b, boolean parallel, int depth) { + super(parallel, depth); + this.a = a; + this.b = b; + } + + @Override + public BigInteger compute() { + return a.multiply(b, true, parallel, depth); + } + } + + @SuppressWarnings("serial") + private static final class RecursiveSquare extends RecursiveOp { + private final BigInteger a; + + public RecursiveSquare(BigInteger a, boolean parallel, int depth) { + super(parallel, depth); + this.a = a; + } + + @Override + public BigInteger compute() { + return a.square(true, parallel, depth); + } + } + + private static RecursiveTask<BigInteger> multiply(BigInteger a, BigInteger b, boolean parallel, int depth) { + return new RecursiveMultiply(a, b, parallel, depth).forkOrInvoke(); + } + + private static RecursiveTask<BigInteger> square(BigInteger a, boolean parallel, int depth) { + return new RecursiveSquare(a, parallel, depth).forkOrInvoke(); + } + } + /** * Multiplies two BigIntegers using a 3-way Toom-Cook multiplication * algorithm. This is a recursive divide-and-conquer algorithm which is @@ -1872,7 +1981,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> { * LNCS #4547. Springer, Madrid, Spain, June 21-22, 2007. * */ - private static BigInteger multiplyToomCook3(BigInteger a, BigInteger b) { + private static BigInteger multiplyToomCook3(BigInteger a, BigInteger b, boolean parallel, int depth) { int alen = a.mag.length; int blen = b.mag.length; @@ -1896,16 +2005,20 @@ public class BigInteger extends Number implements Comparable<BigInteger> { BigInteger v0, v1, v2, vm1, vinf, t1, t2, tm1, da1, db1; - v0 = a0.multiply(b0, true); + depth++; + var v0_task = RecursiveOp.multiply(a0, b0, parallel, depth); da1 = a2.add(a0); db1 = b2.add(b0); - vm1 = da1.subtract(a1).multiply(db1.subtract(b1), true); + var vm1_task = RecursiveOp.multiply(da1.subtract(a1), db1.subtract(b1), parallel, depth); da1 = da1.add(a1); db1 = db1.add(b1); - v1 = da1.multiply(db1, true); + var v1_task = RecursiveOp.multiply(da1, db1, parallel, depth); v2 = da1.add(a2).shiftLeft(1).subtract(a0).multiply( - db1.add(b2).shiftLeft(1).subtract(b0), true); - vinf = a2.multiply(b2, true); + db1.add(b2).shiftLeft(1).subtract(b0), true, parallel, depth); + vinf = a2.multiply(b2, true, parallel, depth); + v0 = v0_task.join(); + vm1 = vm1_task.join(); + v1 = v1_task.join(); // The algorithm requires two divisions by 2 and one by 3. 
// All divisions are known to be exact, that is, they do not produce @@ -2071,7 +2184,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> { * @return <code>this<sup>2</sup></code> */ private BigInteger square() { - return square(false); + return square(false, false, 0); } /** @@ -2081,7 +2194,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> { * @param isRecursion whether this is a recursive invocation * @return <code>this<sup>2</sup></code> */ - private BigInteger square(boolean isRecursion) { + private BigInteger square(boolean isRecursion, boolean parallel, int depth) { if (signum == 0) { return ZERO; } @@ -2103,7 +2216,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> { } } - return squareToomCook3(); + return squareToomCook3(parallel, depth); } } } @@ -2237,7 +2350,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> { * that has better asymptotic performance than the algorithm used in * squareToLen or squareKaratsuba. */ - private BigInteger squareToomCook3() { + private BigInteger squareToomCook3(boolean parallel, int depth) { int len = mag.length; // k is the size (in ints) of the lower-order slices. @@ -2254,13 +2367,17 @@ public class BigInteger extends Number implements Comparable<BigInteger> { a0 = getToomSlice(k, r, 2, len); BigInteger v0, v1, v2, vm1, vinf, t1, t2, tm1, da1; - v0 = a0.square(true); + depth++; + var v0_fork = RecursiveOp.square(a0, parallel, depth); da1 = a2.add(a0); - vm1 = da1.subtract(a1).square(true); + var vm1_fork = RecursiveOp.square(da1.subtract(a1), parallel, depth); da1 = da1.add(a1); - v1 = da1.square(true); - vinf = a2.square(true); - v2 = da1.add(a2).shiftLeft(1).subtract(a0).square(true); + var v1_fork = RecursiveOp.square(da1, parallel, depth); + vinf = a2.square(true, parallel, depth); + v2 = da1.add(a2).shiftLeft(1).subtract(a0).square(true, parallel, depth); + v0 = v0_fork.join(); + vm1 = vm1_fork.join(); + v1 = v1_fork.join(); // The algorithm requires two divisions by 2 and one by 3. // All divisions are known to be exact, that is, they do not produce diff --git a/test/jdk/java/math/BigInteger/BigIntegerParallelMultiplyTest.java b/test/jdk/java/math/BigInteger/BigIntegerParallelMultiplyTest.java new file mode 100644 index 00000000000..1396ae06d96 --- /dev/null +++ b/test/jdk/java/math/BigInteger/BigIntegerParallelMultiplyTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 1998, 2020, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ + +/* + * @test + * @run main BigIntegerParallelMultiplyTest + * @summary tests parallelMultiply() method in BigInteger + * @author Heinz Kabutz heinz@javaspecialists.eu + */ + +import java.math.BigInteger; +import java.util.function.BinaryOperator; + +/** + * This is a simple test class created to ensure that the results + * of multiply() are the same as multiplyParallel(). We calculate + * the Fibonacci numbers using Dijkstra's sum of squares to get + * very large numbers (hundreds of thousands of bits). + * + * @author Heinz Kabutz, heinz@javaspecialists.eu + */ +public class BigIntegerParallelMultiplyTest { + public static BigInteger fibonacci(int n, BinaryOperator<BigInteger> multiplyOperator) { + if (n == 0) return BigInteger.ZERO; + if (n == 1) return BigInteger.ONE; + + int half = (n + 1) / 2; + BigInteger f0 = fibonacci(half - 1, multiplyOperator); + BigInteger f1 = fibonacci(half, multiplyOperator); + if (n % 2 == 1) { + BigInteger b0 = multiplyOperator.apply(f0, f0); + BigInteger b1 = multiplyOperator.apply(f1, f1); + return b0.add(b1); + } else { + BigInteger b0 = f0.shiftLeft(1).add(f1); + return multiplyOperator.apply(b0, f1); + } + } + + public static void main(String[] args) throws Exception { + compare(1000, 324); + compare(10_000, 3473); + compare(100_000, 34883); + compare(1_000_000, 347084); + } + + private static void compare(int n, int expectedBitCount) { + BigInteger multiplyResult = fibonacci(n, BigInteger::multiply); + BigInteger parallelMultiplyResult = fibonacci(n, BigInteger::parallelMultiply); + checkBitCount(n, expectedBitCount, multiplyResult); + checkBitCount(n, expectedBitCount, parallelMultiplyResult); + if (!multiplyResult.equals(parallelMultiplyResult)) + throw new AssertionError("multiply() and parallelMultiply() give different results"); + } + + private static void checkBitCount(int n, int expectedBitCount, BigInteger number) { + if (number.bitCount() != expectedBitCount) + throw new AssertionError( + "bitCount of fibonacci(" + n + ") was expected to be " + expectedBitCount + + " but was " + number.bitCount()); + } +} diff --git a/test/micro/org/openjdk/bench/java/math/BigIntegerMersennePrimeMultiply.java b/test/micro/org/openjdk/bench/java/math/BigIntegerMersennePrimeMultiply.java new file mode 100644 index 00000000000..7ac4cf54dc2 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/math/BigIntegerMersennePrimeMultiply.java @@ -0,0 +1,322 @@ +package org.openjdk.bench.java.math; + +import javax.management.MBeanServer; +import javax.management.MalformedObjectNameException; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Locale; +import java.util.LongSummaryStatistics; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinWorkerThread; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BinaryOperator; +import java.util.function.LongUnaryOperator; +import java.util.stream.Collectors; + +import static java.util.concurrent.ForkJoinPool.defaultForkJoinWorkerThreadFactory; + +/** + * Benchmark for checking performance difference between sequential and parallel + * multiply of very large Mersenne 
primes using BigInteger. We want to measure + * real time, user time, system time and the amount of memory allocated. To + * calculate this, we create our own thread factory for the common ForkJoinPool + * and then use that to measure user time, cpu time and bytes allocated. + * <p> + * We use reflection to discover all methods that match "*ultiply", and use them + * to multiply two very large Mersenne primes together. + * <p> + * <h3>Results on a 1-6-2 machine running Ubuntu linux</h3> + * <p> + * Memory allocation increased from 83.9GB to 84GB, for both the sequential and + * parallel versions. This is an increase of just 0.1%. On this machine, the + * parallel version was 3.8x faster in latency (real time), but it used 2.7x + * more CPU resources. + * <p> + * Testing multiplying Mersenne primes of 2^57885161-1 and 2^82589933-1 + * <p> + * <pre> + * openjdk version "18-internal" 2022-03-15 + * BigInteger.parallelMultiply() + * real 0m6.288s + * user 1m3.010s + * sys 0m0.027s + * mem 84.0GB + * BigInteger.multiply() + * real 0m23.682s + * user 0m23.530s + * sys 0m0.004s + * mem 84.0GB + * + * openjdk version "1.8.0_302" + * BigInteger.multiply() + * real 0m25.657s + * user 0m25.390s + * sys 0m0.001s + * mem 83.9GB + * + * openjdk version "9.0.7.1" + * BigInteger.multiply() + * real 0m24.907s + * user 0m24.700s + * sys 0m0.001s + * mem 83.9GB + * + * openjdk version "10.0.2" 2018-07-17 + * BigInteger.multiply() + * real 0m24.632s + * user 0m24.380s + * sys 0m0.004s + * mem 83.9GB + * + * openjdk version "11.0.12" 2021-07-20 LTS + * BigInteger.multiply() + * real 0m22.114s + * user 0m21.930s + * sys 0m0.001s + * mem 83.9GB + * + * openjdk version "12.0.2" 2019-07-16 + * BigInteger.multiply() + * real 0m23.015s + * user 0m22.830s + * sys 0m0.000s + * mem 83.9GB + * + * openjdk version "13.0.9" 2021-10-19 + * BigInteger.multiply() + * real 0m23.548s + * user 0m23.350s + * sys 0m0.005s + * mem 83.9GB + * + * openjdk version "14.0.2" 2020-07-14 + * BigInteger.multiply() + * real 0m22.918s + * user 0m22.530s + * sys 0m0.131s + * mem 83.9GB + * + * openjdk version "15.0.5" 2021-10-19 + * BigInteger.multiply() + * real 0m22.038s + * user 0m21.750s + * sys 0m0.003s + * mem 83.9GB + * + * openjdk version "16.0.2" 2021-07-20 + * BigInteger.multiply() + * real 0m23.049s + * user 0m22.760s + * sys 0m0.006s + * mem 83.9GB + * + * openjdk version "17" 2021-09-14 + * BigInteger.multiply() + * real 0m22.580s + * user 0m22.310s + * sys 0m0.001s + * mem 83.9GB + *</pre> + * + * @author Heinz Kabutz, heinz@javaspecialists.eu + */ +public class BigIntegerMersennePrimeMultiply implements ForkJoinPool.ForkJoinWorkerThreadFactory { + // Large Mersenne prime discovered by Curtis Cooper in 2013 + private static final int EXPONENT_1 = 57885161; + private static final BigInteger MERSENNE_1 = + BigInteger.ONE.shiftLeft(EXPONENT_1).subtract(BigInteger.ONE); + // Largest Mersenne prime number discovered by Patrick Laroche in 2018 + private static final int EXPONENT_2 = 82589933; + private static final BigInteger MERSENNE_2 = + BigInteger.ONE.shiftLeft(EXPONENT_2).subtract(BigInteger.ONE); + private static boolean DEBUG = false; + + public static void main(String... 
args) { + System.setProperty("java.util.concurrent.ForkJoinPool.common.threadFactory", + BigIntegerMersennePrimeMultiply.class.getName()); + System.out.println("Testing multiplying Mersenne primes of " + + "2^" + EXPONENT_1 + "-1 and 2^" + EXPONENT_2 + "-1"); + addCounters(Thread.currentThread()); + System.out.println("Using the following multiply methods:"); + List<Method> methods = Arrays.stream(BigInteger.class.getMethods()) + .filter(method -> method.getName().endsWith("ultiply") && + method.getParameterCount() == 1 && + method.getParameterTypes()[0] == BigInteger.class) + .peek(method -> System.out.println(" " + method)) + .collect(Collectors.toList()); + + for (int i = 0; i < 3; i++) { + System.out.println(); + methods.forEach(BigIntegerMersennePrimeMultiply::test); + } + } + + private static void test(Method method) { + BinaryOperator<BigInteger> multiplyOperator = (a, b) -> { + try { + return (BigInteger) method.invoke(a, b); + } catch (IllegalAccessException e) { + throw new AssertionError(e); + } catch (InvocationTargetException e) { + throw new AssertionError(e.getCause()); + } + }; + test(method.getName(), multiplyOperator); + } + + private static void test(String description, + BinaryOperator<BigInteger> multiplyOperator) { + System.out.println("BigInteger." + description + "()"); + resetAllCounters(); + long elapsedTimeInNanos = System.nanoTime(); + try { + BigInteger result1 = multiplyOperator.apply(MERSENNE_1, MERSENNE_2); + BigInteger result2 = multiplyOperator.apply(MERSENNE_2, MERSENNE_1); + if (result1.bitLength() != 140475094) + throw new AssertionError("Expected bitLength: 140475094, " + + "but was " + result1.bitLength()); + if (result2.bitLength() != 140475094) + throw new AssertionError("Expected bitLength: 140475094, " + + "but was " + result1.bitLength()); + } finally { + elapsedTimeInNanos = System.nanoTime() - elapsedTimeInNanos; + } + + LongSummaryStatistics userTimeStatistics = getStatistics(userTime); + LongSummaryStatistics cpuTimeStatistics = getStatistics(cpuTime); + LongSummaryStatistics memoryAllocationStatistics = getStatistics(bytes); + System.out.println("real " + formatTime(elapsedTimeInNanos)); + System.out.println("user " + formatTime(userTimeStatistics.getSum())); + System.out.println("sys " + + formatTime(cpuTimeStatistics.getSum() - userTimeStatistics.getSum())); + System.out.println("mem " + formatMemory(memoryAllocationStatistics.getSum(), 1)); + } + + private static LongSummaryStatistics getStatistics(Map<Thread, AtomicLong> timeMap) { + return timeMap.entrySet() + .stream() + .peek(entry -> { + long timeInMs = (counterExtractorMap.get(timeMap) + .applyAsLong(entry.getKey().getId()) + - entry.getValue().get()); + entry.getValue().set(timeInMs); + }) + .peek(BigIntegerMersennePrimeMultiply::printTime) + .map(Map.Entry::getValue) + .mapToLong(AtomicLong::get) + .summaryStatistics(); + } + + private static void printTime(Map.Entry<Thread, AtomicLong> threadCounter) { + if (DEBUG) + System.out.printf("%s %d%n", threadCounter.getKey(), threadCounter.getValue() + .get()); + } + + private static void addCounters(Thread thread) { + counterExtractorMap.forEach((map, timeExtractor) -> add(map, thread, timeExtractor)); + } + + private static void add(Map<Thread, AtomicLong> time, Thread thread, + LongUnaryOperator timeExtractor) { + time.put(thread, new AtomicLong(timeExtractor.applyAsLong(thread.getId()))); + } + + private static void resetAllCounters() { + counterExtractorMap.forEach(BigIntegerMersennePrimeMultiply::resetTimes); + } + + private 
static void resetTimes(Map<Thread, AtomicLong> timeMap, LongUnaryOperator timeMethod) { + timeMap.forEach((thread, time) -> + time.set(timeMethod.applyAsLong(thread.getId()))); + } + + private static final Map<Thread, AtomicLong> userTime = + new ConcurrentHashMap<>(); + private static final Map<Thread, AtomicLong> cpuTime = + new ConcurrentHashMap<>(); + private static final Map<Thread, AtomicLong> bytes = + new ConcurrentHashMap<>(); + private static final ThreadMXBean tmb = ManagementFactory.getThreadMXBean(); + + private static final Map<Map<Thread, AtomicLong>, LongUnaryOperator> counterExtractorMap = + new IdentityHashMap<>(); + + static { + counterExtractorMap.put(userTime, tmb::getThreadUserTime); + counterExtractorMap.put(cpuTime, tmb::getThreadCpuTime); + counterExtractorMap.put(bytes, BigIntegerMersennePrimeMultiply::threadAllocatedBytes); + } + + public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { + ForkJoinWorkerThread thread = defaultForkJoinWorkerThreadFactory.newThread(pool); + addCounters(thread); + return thread; + } + + private static final String[] SIGNATURE = new String[]{long.class.getName()}; + private static final MBeanServer mBeanServer; + private static final ObjectName name; + + static { + try { + name = new ObjectName(ManagementFactory.THREAD_MXBEAN_NAME); + mBeanServer = ManagementFactory.getPlatformMBeanServer(); + } catch (MalformedObjectNameException e) { + throw new ExceptionInInitializerError(e); + } + } + + public static long threadAllocatedBytes(long threadId) { + try { + return (long) mBeanServer.invoke( + name, + "getThreadAllocatedBytes", + new Object[]{threadId}, + SIGNATURE + ); + } catch (Exception e) { + throw new IllegalArgumentException(e); + } + } + + public static String formatMemory(double bytes, int decimals) { + double val; + String unitStr; + if (bytes < 1024) { + val = bytes; + unitStr = "B"; + } else if (bytes < 1024 * 1024) { + val = bytes / 1024; + unitStr = "KB"; + } else if (bytes < 1024 * 1024 * 1024) { + val = bytes / (1024 * 1024); + unitStr = "MB"; + } else if (bytes < 1024 * 1024 * 1024 * 1024L) { + val = bytes / (1024 * 1024 * 1024L); + unitStr = "GB"; + } else { + val = bytes / (1024 * 1024 * 1024 * 1024L); + unitStr = "TB"; + } + return String.format(Locale.US, "%." 
+ decimals + "f%s", val, unitStr); + } + + public static String formatTime(long nanos) { + if (nanos < 0) nanos = 0; + long timeInMs = TimeUnit.NANOSECONDS.toMillis(nanos); + long minutes = timeInMs / 60_000; + double remainingMs = (timeInMs % 60_000) / 1000.0; + return String.format(Locale.US, "%dm%.3fs", minutes, remainingMs); + } +} \ No newline at end of file diff --git a/test/micro/org/openjdk/bench/java/math/BigIntegerParallelMultiply.java b/test/micro/org/openjdk/bench/java/math/BigIntegerParallelMultiply.java new file mode 100644 index 00000000000..92a1c5a8237 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/math/BigIntegerParallelMultiply.java @@ -0,0 +1,61 @@ +package org.openjdk.bench.java.math; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.math.BigInteger; +import java.util.concurrent.TimeUnit; +import java.util.function.BinaryOperator; + +/** + * Benchmark for checking performance difference between + * sequential and parallel multiply methods in BigInteger, + * using a large Fibonacci calculation of up to n = 100 million. + * + * @author Heinz Kabutz, heinz@javaspecialists.eu + */ +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(value = 2) +@Warmup(iterations = 2) +@Measurement(iterations = 2) // only 2 iterations because each one takes very long +@State(Scope.Thread) +public class BigIntegerParallelMultiply { + private static BigInteger fibonacci(int n, BinaryOperator<BigInteger> multiplyOperator) { + if (n == 0) return BigInteger.ZERO; + if (n == 1) return BigInteger.ONE; + + int half = (n + 1) / 2; + BigInteger f0 = fibonacci(half - 1, multiplyOperator); + BigInteger f1 = fibonacci(half, multiplyOperator); + if (n % 2 == 1) { + BigInteger b0 = multiplyOperator.apply(f0, f0); + BigInteger b1 = multiplyOperator.apply(f1, f1); + return b0.add(b1); + } else { + BigInteger b0 = f0.shiftLeft(1).add(f1); + return multiplyOperator.apply(b0, f1); + } + } + + @Param({"1000000", "10000000", "100000000"}) + private int n; + + @Benchmark + public void multiply() { + fibonacci(n, BigInteger::multiply); + } + + @Benchmark + public void parallelMultiply() { + fibonacci(n, BigInteger::parallelMultiply); + } +} \ No newline at end of file