8284493: Improve computeNextExponential tail performance and accuracy

Reviewed-by: darcy
This commit is contained in:
Chris Hennick 2023-07-17 19:19:13 +00:00 committed by Joe Darcy
parent e737968792
commit f975be44a8
4 changed files with 144 additions and 16 deletions

View File

@ -1137,10 +1137,33 @@ public class RandomSupport {
}
}
/**
* Largest value that {@link #computeNextExponential} can ever return.
*/
private static final double MAX_EXPONENTIAL = 0x1.0p63 * DoubleZigguratTables.exponentialX0;
/**
* Implementation support for the {@code nextExponential} method of
* {@link java.util.random.RandomGenerator}.
*
* @param rng an instance of {@code RandomGenerator}, used to generate uniformly
* pseudorandomly chosen {@code long} values
*
* @return a nonnegative {@code double} value chosen pseudorandomly
* from an exponential distribution whose mean is 1
*/
public static double computeNextExponential(RandomGenerator rng) {
return computeNextExponentialSoftCapped(rng, MAX_EXPONENTIAL);
}
/**
* Generates a pseudorandom value {@code x} such that {@code Math.min(x, maxValue)}
* follows the same distribution as it would if {@code x} was exponentially distributed
* with mean 1, but with a worst-case number of calls to {@link
* RandomGenerator#nextLong()} that's linear with {@code maxValue}. {@code maxValue} is
* a "soft" cap in that a value larger than {@code maxValue} may be returned in order
* to save a calculation.
*
* Certain details of the algorithm used in this method may depend critically
* on the quality of the low-order bits delivered by {@code nextLong()}. This method
* should not be used with RNG algorithms (such as a simple Linear Congruential
@ -1157,7 +1180,7 @@ public class RandomSupport {
* @return a nonnegative {@code double} value chosen pseudorandomly
* from an exponential distribution whose mean is 1
*/
public static double computeNextExponential(RandomGenerator rng) {
public static double computeNextExponentialSoftCapped(RandomGenerator rng, double maxValue) {
/*
* The tables themselves, as well as a number of associated parameters, are
* defined in class DoubleZigguratTables, which is automatically
@ -1186,9 +1209,18 @@ public class RandomSupport {
// This is the fast path (occurring more than 98% of the time). Make an early exit.
return DoubleZigguratTables.exponentialX[(int)i] * (U1 >>> 1);
}
// We didn't use the upper part of U1 after all. We'll be able to use it later.
for (double extra = 0.0; ; ) {
// We didn't use the upper part of U1 after all. We'll probably be able to use it later.
if (maxValue <= 0.0) {
return 0.0;
}
final long maxExtraMinus1;
if (maxValue >= MAX_EXPONENTIAL) {
maxExtraMinus1 = Long.MAX_VALUE;
} else {
// Conversion to long rounds toward zero
maxExtraMinus1 = (long) (maxValue / DoubleZigguratTables.exponentialX0);
}
for (long extra = 0; ; ) {
// Use Walker's alias method to sample an (unsigned) integer j from a discrete
// probability distribution that includes the tail and all the ziggurat overhangs;
// j will be less than DoubleZigguratTables.exponentialNumberOfLayers + 1.
@ -1219,25 +1251,29 @@ public class RandomSupport {
// Compute the actual x-coordinate of the randomly chosen point.
double x = (X[j] * 0x1.0p63) + ((X[j-1] - X[j]) * (double)U1);
if (Udiff >= DoubleZigguratTables.exponentialConvexMargin) {
return x + extra; // The chosen point is way below the curve; accept it.
return Math.fma(extra, DoubleZigguratTables.exponentialX0, x); // The chosen point is way below the curve; accept it.
}
// Compute the actual y-coordinate of the randomly chosen point.
double y = (Y[j] * 0x1.0p63) + ((Y[j-1] - Y[j]) * (double)U2);
// Now see how that y-coordinate compares to the curve
if (y <= Math.exp(-x)) {
return x + extra; // The chosen point is below the curve; accept it.
return Math.fma(extra, DoubleZigguratTables.exponentialX0, x); // The chosen point is below the curve; accept it.
}
// Otherwise, we reject this sample and have to try again.
}
}
if (extra == maxExtraMinus1) {
// We've reached the maximum, so don't waste any more time
return maxValue;
}
// We are now committed to sampling from the tail. We could do a recursive call
// and then add X[0], but we save some time and stack space by using an iterative loop.
extra += DoubleZigguratTables.exponentialX0;
// This is like the first five lines of this method, but if it returns, it first adds "extra".
extra++;
// This is like the first five lines of this method, but if it returns, it first adds "extra" times X0.
U1 = rng.nextLong();
i = U1 & DoubleZigguratTables.exponentialLayerMask;
if (i < DoubleZigguratTables.exponentialNumberOfLayers) {
return DoubleZigguratTables.exponentialX[(int)i] * (U1 >>> 1) + extra;
return Math.fma(extra, DoubleZigguratTables.exponentialX0, DoubleZigguratTables.exponentialX[(int)i] * (U1 >>> 1));
}
}
}
@ -1359,9 +1395,11 @@ public class RandomSupport {
// ACM Comput. Surv. 39, 4, Article 11 (November 2007). See Algorithm 16.
// http://doi.org/10.1145/1287620.1287622
// Compute two separate random exponential samples and then compare them in certain way.
double limit;
do {
x = (1.0 / DoubleZigguratTables.normalX0) * computeNextExponential(rng);
} while (computeNextExponential(rng) < 0.5*x*x);
limit = 0.5*x*x;
} while (computeNextExponentialSoftCapped(rng, limit) < limit);
x += DoubleZigguratTables.normalX0;
} else if (j < DoubleZigguratTables.normalInflectionIndex) { // Convex overhang
for (;; U1 = (rng.nextLong() >>> 1)) {
@ -1392,7 +1430,7 @@ public class RandomSupport {
for (;; U1 = (rng.nextLong() >>> 1)) {
long U2 = (rng.nextLong() >>> 1);
// Compute the actual x-coordinate of the randomly chosen point.
x = (X[j] * 0x1.0p63) + ((X[j-1] - X[j]) * (double)U1);
x = Math.fma(X[j-1] - X[j], (double)U1, X[j] * 0x1.0p63);
// Does the point lie below the curve?
long Udiff = U2 - U1;
if (Udiff >= DoubleZigguratTables.normalConvexMargin) {
@ -1402,7 +1440,7 @@ public class RandomSupport {
continue; // The chosen point is way above the curve; reject it.
}
// Compute the actual y-coordinate of the randomly chosen point.
double y = (Y[j] * 0x1.0p63) + ((Y[j-1] - Y[j]) * (double)U2);
double y = Math.fma(Y[j-1] - Y[j], (double)U2, Y[j] * 0x1.0p63);
// Now see how that y-coordinate compares to the curve
if (y <= Math.exp(-0.5*x*x)) {
break; // The chosen point is below the curve; accept it.

View File

@ -0,0 +1,90 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.util.random;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Param;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;
import java.util.concurrent.TimeUnit;
/**
* Tests java.util.random.RandomGenerator's implementations of nextExponential and nextGaussian
*/
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
public class RandomGeneratorExponentialGaussian {
RandomGenerator randomGenerator;
@Param({"L64X128MixRandom", "L64X1024MixRandom"})
String randomGeneratorName;
@Param({"false","true"})
boolean fixedSeed;
double[] buffer;
@Param("1024")
int size;
@Setup
public void setup() {
buffer = new double[size];
RandomGeneratorFactory factory = RandomGeneratorFactory.of(randomGeneratorName);
randomGenerator = fixedSeed ? factory.create(randomGeneratorName.hashCode()) : factory.create();
}
@Benchmark
@BenchmarkMode({Mode.SampleTime, Mode.AverageTime})
public double testNextGaussian() {
return randomGenerator.nextGaussian();
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
public double[] testFillBufferWithNextGaussian() {
for (int i = 0; i < size; i++) buffer[i] = randomGenerator.nextGaussian();
return buffer;
}
@Benchmark
@BenchmarkMode({Mode.SampleTime, Mode.AverageTime})
public double testNextExponential() {
return randomGenerator.nextExponential();
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
public double[] testFillBufferWithNextExponential() {
for (int i = 0; i < size; i++) buffer[i] = randomGenerator.nextExponential();
return buffer;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -20,7 +20,7 @@
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.util;
package org.openjdk.bench.java.util.random;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2014, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2014, 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -20,7 +20,7 @@
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.util;
package org.openjdk.bench.java.util.random;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;