4851642: Add fused multiply add to Java math library

Reviewed-by: bpb, nadezhin
This commit is contained in:
Joe Darcy 2016-04-15 10:14:57 -07:00
parent 07cdc33e34
commit 965536262b
4 changed files with 716 additions and 1 deletions

View File

@ -25,6 +25,7 @@
package java.lang;
import java.math.BigDecimal;
import java.util.Random;
import jdk.internal.math.FloatConsts;
import jdk.internal.math.DoubleConsts;
@ -1449,6 +1450,199 @@ public final class Math {
return (a <= b) ? a : b;
}
/**
* Returns the fused multiply add of the three arguments; that is,
* returns the exact product of the first two arguments summed
* with the third argument and then rounded once to the nearest
* {@code double}.
*
* The rounding is done using the {@linkplain
* java.math.RoundingMode#HALF_EVEN round to nearest even
* rounding mode}.
*
* In contrast, if {@code a * b + c} is evaluated as a regular
* floating-point expression, two rounding errors are involved,
* the first for the multiply operation, the second for the
* addition operation.
*
* <p>Special cases:
* <ul>
* <li> If any argument is NaN, the result is NaN.
*
* <li> If one of the first two arguments is infinite and the
* other is zero, the result is NaN.
*
* <li> If the exact product of the first two arguments is infinite
* (in other words, at least one of the arguments is infinite and
* the other is neither zero nor NaN) and the third argument is an
* infinity of the opposite sign, the result is NaN.
*
* </ul>
*
* <p>Note that {@code fma(a, 1.0, c)} returns the same
* result as ({@code a + c}). However,
* {@code fma(a, b, +0.0)} does <em>not</em> always return the
* same result as ({@code a * b}) since
* {@code fma(-0.0, +0.0, +0.0)} is {@code +0.0} while
* ({@code -0.0 * +0.0}) is {@code -0.0}; {@code fma(a, b, -0.0)} is
* equivalent to ({@code a * b}) however.
*
* @apiNote This method corresponds to the fusedMultiplyAdd
* operation defined in IEEE 754-2008.
*
* @param a a value
* @param b a value
* @param c a value
*
* @return (<i>a</i>&nbsp;&times;&nbsp;<i>b</i>&nbsp;+&nbsp;<i>c</i>)
* computed, as if with unlimited range and precision, and rounded
* once to the nearest {@code double} value
*/
// @HotSpotIntrinsicCandidate
public static double fma(double a, double b, double c) {
    /*
     * The plain expression "a * b + c" rounds twice (once for the
     * multiply, once for the add), so it cannot be used directly:
     * besides the extra rounding error, a product that overflows
     * to infinity can combine with an infinite addend of the
     * opposite sign and yield a NaN that a fused operation would
     * not produce.
     */

    // Any NaN operand makes the result NaN.
    if (Double.isNaN(a) || Double.isNaN(b) || Double.isNaN(c)) {
        return Double.NaN;
    }

    final boolean aIsInfinite = Double.isInfinite(a);
    final boolean bIsInfinite = Double.isInfinite(b);

    // Non-finite operands are handled with floating-point
    // arithmetic because BigDecimal cannot represent them.
    if (aIsInfinite || bIsInfinite || Double.isInfinite(c)) {
        // infinity * zero is NaN regardless of the addend.
        if ((aIsInfinite && b == 0.0) || (bIsInfinite && a == 0.0)) {
            return Double.NaN;
        }

        // Keep the product in a double variable so that it
        // overflows even under non-strictfp evaluation.
        double roundedProduct = a * b;

        if (!aIsInfinite && !bIsInfinite && Double.isInfinite(roundedProduct)) {
            // The product of two finite values overflowed; the
            // exact product is finite, so the infinite addend
            // alone determines the result.
            assert Double.isInfinite(c);
            return c;
        }

        double sum = roundedProduct + c;
        assert !Double.isFinite(sum);
        return sum;
    }

    // From here on all three operands are finite.

    // An exact zero product (at least one of a and b is zero)
    // added to a zero addend: let floating-point arithmetic work
    // out the sign of the resulting zero.
    if (c == 0.0 && (a == 0.0 || b == 0.0)) {
        return a * b + c;
    }

    // Exact product, and exact sum when the addend is nonzero.
    // A zero addend is deliberately left out: its sign must not
    // influence the result when the exact product is nonzero,
    // even if that product underflows to zero; see IEEE-754 2008
    // section 6.3 "The sign bit".
    BigDecimal exact = new BigDecimal(a).multiply(new BigDecimal(b));
    if (c != 0.0) {
        exact = exact.add(new BigDecimal(c));
    }

    // The single rounding happens here.
    return exact.doubleValue();
}
/**
* Returns the fused multiply add of the three arguments; that is,
* returns the exact product of the first two arguments summed
* with the third argument and then rounded once to the nearest
* {@code float}.
*
* The rounding is done using the {@linkplain
* java.math.RoundingMode#HALF_EVEN round to nearest even
* rounding mode}.
*
* In contrast, if {@code a * b + c} is evaluated as a regular
* floating-point expression, two rounding errors are involved,
* the first for the multiply operation, the second for the
* addition operation.
*
* <p>Special cases:
* <ul>
* <li> If any argument is NaN, the result is NaN.
*
* <li> If one of the first two arguments is infinite and the
* other is zero, the result is NaN.
*
* <li> If the exact product of the first two arguments is infinite
* (in other words, at least one of the arguments is infinite and
* the other is neither zero nor NaN) and the third argument is an
* infinity of the opposite sign, the result is NaN.
*
* </ul>
*
* <p>Note that {@code fma(a, 1.0f, c)} returns the same
* result as ({@code a + c}). However,
* {@code fma(a, b, +0.0f)} does <em>not</em> always return the
* same result as ({@code a * b}) since
* {@code fma(-0.0f, +0.0f, +0.0f)} is {@code +0.0f} while
* ({@code -0.0f * +0.0f}) is {@code -0.0f}; {@code fma(a, b, -0.0f)} is
* equivalent to ({@code a * b}) however.
*
* @apiNote This method corresponds to the fusedMultiplyAdd
* operation defined in IEEE 754-2008.
*
* @param a a value
* @param b a value
* @param c a value
*
* @return (<i>a</i>&nbsp;&times;&nbsp;<i>b</i>&nbsp;+&nbsp;<i>c</i>)
* computed, as if with unlimited range and precision, and rounded
* once to the nearest {@code float} value
*/
// @HotSpotIntrinsicCandidate
public static float fma(float a, float b, float c) {
    /*
     * The double format has more than twice the precision of the
     * float format, so the product (double) a * (double) b is
     * always exact, and it can neither overflow nor underflow in
     * double.
     *
     * The sum of that product and c, however, is in general NOT
     * exact in double. Rounding the sum first to double and then
     * to float is a double rounding: when the double result lands
     * exactly halfway between two adjacent float values, the
     * information about which side the exact sum was on has
     * already been lost and the final float can be off by one
     * ulp. For example, for
     *
     *   a = 0x1.fffffep23f, b = 0x1.000004p28f, c = 0x1.fep5f
     *
     * the correctly rounded result is 0x1.000002p52f while
     * (float)((double) a * (double) b + (double) c) yields
     * 0x1.000004p52f.
     *
     * Therefore, for finite inputs the sum is formed exactly in
     * BigDecimal and rounded once, directly to float.
     */
    if (Float.isFinite(a) && Float.isFinite(b) && Float.isFinite(c)) {
        if (a == 0.0f || b == 0.0f) {
            // The product is an exact (signed) zero, so the sum
            // is exact too; floating-point arithmetic yields the
            // correct sign when the addend is also a zero.
            return a * b + c;
        }
        // Nonzero exact product: a zero addend contributes
        // nothing to the BigDecimal sum, so its sign cannot
        // influence the result.
        return new BigDecimal((double) a * (double) b) // exact multiply
                .add(new BigDecimal((double) c))       // exact sum
                .floatValue();                         // the one rounding
    }

    // At least one input is infinite or NaN, so the result is
    // infinite or NaN as well. Because the product of two floats
    // cannot overflow in double, double arithmetic produces the
    // same non-finite value a fused operation would. The result
    // is stored in a float variable so that a value from the
    // float extended value set is never returned.
    float result = (float) (((double) a * (double) b) + (double) c);
    return result;
}
/**
* Returns the size of an ulp of the argument. An ulp, unit in
* the last place, of a {@code double} value is the positive

View File

@ -1134,6 +1134,110 @@ public final class StrictMath {
return Math.min(a, b);
}
/**
* Returns the fused multiply add of the three arguments; that is,
* returns the exact product of the first two arguments summed
* with the third argument and then rounded once to the nearest
* {@code double}.
*
* The rounding is done using the {@linkplain
* java.math.RoundingMode#HALF_EVEN round to nearest even
* rounding mode}.
*
* In contrast, if {@code a * b + c} is evaluated as a regular
* floating-point expression, two rounding errors are involved,
* the first for the multiply operation, the second for the
* addition operation.
*
* <p>Special cases:
* <ul>
* <li> If any argument is NaN, the result is NaN.
*
* <li> If one of the first two arguments is infinite and the
* other is zero, the result is NaN.
*
* <li> If the exact product of the first two arguments is infinite
* (in other words, at least one of the arguments is infinite and
* the other is neither zero nor NaN) and the third argument is an
* infinity of the opposite sign, the result is NaN.
*
* </ul>
*
* <p>Note that {@code fma(a, 1.0, c)} returns the same
* result as ({@code a + c}). However,
* {@code fma(a, b, +0.0)} does <em>not</em> always return the
* same result as ({@code a * b}) since
* {@code fma(-0.0, +0.0, +0.0)} is {@code +0.0} while
* ({@code -0.0 * +0.0}) is {@code -0.0}; {@code fma(a, b, -0.0)} is
* equivalent to ({@code a * b}) however.
*
* @apiNote This method corresponds to the fusedMultiplyAdd
* operation defined in IEEE 754-2008.
*
* @param a a value
* @param b a value
* @param c a value
*
* @return (<i>a</i>&nbsp;&times;&nbsp;<i>b</i>&nbsp;+&nbsp;<i>c</i>)
* computed, as if with unlimited range and precision, and rounded
* once to the nearest {@code double} value
*/
public static double fma(double a, double b, double c) {
    // The operation is exactly specified, so this method simply
    // forwards to the Math implementation.
    final double fused = Math.fma(a, b, c);
    return fused;
}
/**
* Returns the fused multiply add of the three arguments; that is,
* returns the exact product of the first two arguments summed
* with the third argument and then rounded once to the nearest
* {@code float}.
*
* The rounding is done using the {@linkplain
* java.math.RoundingMode#HALF_EVEN round to nearest even
* rounding mode}.
*
* In contrast, if {@code a * b + c} is evaluated as a regular
* floating-point expression, two rounding errors are involved,
* the first for the multiply operation, the second for the
* addition operation.
*
* <p>Special cases:
* <ul>
* <li> If any argument is NaN, the result is NaN.
*
* <li> If one of the first two arguments is infinite and the
* other is zero, the result is NaN.
*
* <li> If the exact product of the first two arguments is infinite
* (in other words, at least one of the arguments is infinite and
* the other is neither zero nor NaN) and the third argument is an
* infinity of the opposite sign, the result is NaN.
*
* </ul>
*
* <p>Note that {@code fma(a, 1.0f, c)} returns the same
* result as ({@code a + c}). However,
* {@code fma(a, b, +0.0f)} does <em>not</em> always return the
* same result as ({@code a * b}) since
* {@code fma(-0.0f, +0.0f, +0.0f)} is {@code +0.0f} while
* ({@code -0.0f * +0.0f}) is {@code -0.0f}; {@code fma(a, b, -0.0f)} is
* equivalent to ({@code a * b}) however.
*
* @apiNote This method corresponds to the fusedMultiplyAdd
* operation defined in IEEE 754-2008.
*
* @param a a value
* @param b a value
* @param c a value
*
* @return (<i>a</i>&nbsp;&times;&nbsp;<i>b</i>&nbsp;+&nbsp;<i>c</i>)
* computed, as if with unlimited range and precision, and rounded
* once to the nearest {@code float} value
*/
public static float fma(float a, float b, float c) {
    // The operation is exactly specified, so this method simply
    // forwards to the Math implementation.
    final float fused = Math.fma(a, b, c);
    return fused;
}
/**
* Returns the size of an ulp of the argument. An ulp, unit in
* the last place, of a {@code double} value is the positive

View File

@ -0,0 +1,385 @@
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 4851642
* @summary Tests for Math.fma and StrictMath.fma.
* @build Tests
* @build FusedMultiplyAddTests
* @run main FusedMultiplyAddTests
*/
/**
* The specifications for both Math.fma and StrictMath.fma
* are the same and both are exactly specified. Therefore, both
* methods are tested in this file.
*/
public class FusedMultiplyAddTests {
    // Not instantiable; all tests are static.
    private FusedMultiplyAddTests() {}

    private static final double Infinity = Double.POSITIVE_INFINITY;
    private static final float InfinityF = Float.POSITIVE_INFINITY;
    private static final double NaN = Double.NaN;
    private static final float NaNf = Float.NaN;

    /**
     * Runs every double and float test group and throws a
     * RuntimeException if any individual check failed.
     */
    public static void main(String... args) {
        int failures = 0;

        failures += testNonFiniteD();
        failures += testZeroesD();
        failures += testSimpleD();

        failures += testNonFiniteF();
        failures += testZeroesF();
        failures += testSimpleF();

        if (failures > 0) {
            String message = "Testing fma incurred "
                + failures + " failures.";
            System.err.println(message);
            throw new RuntimeException(message);
        }
    }

    /**
     * Double cases with at least one infinite or NaN operand.
     * Each row is {a, b, c, expected}.
     */
    private static int testNonFiniteD() {
        int failures = 0;

        double [][] testCases = {
            {Infinity,       Infinity,  Infinity,
             Infinity,
            },

            {-Infinity,       Infinity,  -Infinity,
             -Infinity,
            },

            {-Infinity,       Infinity,  Infinity,
             NaN,
            },

            {Infinity,       Infinity,  -Infinity,
             NaN,
            },

            {1.0,       Infinity,  2.0,
             Infinity,
            },

            {1.0,       2.0,       Infinity,
             Infinity,
            },

            {Infinity,  1.0,       Infinity,
             Infinity,
            },

            // Finite product that overflows; must not turn into NaN.
            {Double.MAX_VALUE, 2.0, -Infinity,
             -Infinity},

            {Infinity,  1.0,      -Infinity,
             NaN,
            },

            {-Infinity, 1.0,       Infinity,
             NaN,
            },

            {1.0,       NaN,       2.0,
             NaN,
            },

            {1.0,       2.0,       NaN,
             NaN,
            },

            {Infinity,  2.0,       NaN,
             NaN,
            },

            {NaN,       2.0,       Infinity,
             NaN,
            },
        };

        for (double[] testCase: testCases)
            failures += testFusedMultiplyAddCase(testCase[0], testCase[1], testCase[2], testCase[3]);

        return failures;
    }

    /**
     * Double cases checking the sign of a zero result.
     * Each row is {a, b, c, expected}.
     */
    private static int testZeroesD() {
        int failures = 0;

        double [][] testCases = {
            {+0.0, +0.0, +0.0,
             +0.0,
            },

            {-0.0, +0.0, +0.0,
             +0.0,
            },

            {+0.0, +0.0, -0.0,
             +0.0,
            },

            // Positive zero product from two negative zeros.
            {-0.0, -0.0, +0.0,
             +0.0,
            },

            {-0.0, +0.0, -0.0,
             -0.0,
            },

            {-0.0, -0.0, -0.0,
             +0.0,
            },

            {-1.0, +0.0, -0.0,
             -0.0,
            },

            {-1.0, +0.0, +0.0,
             +0.0,
            },

            {-2.0, +0.0, -0.0,
             -0.0,
            },

            {-2.0, +0.0, +0.0,
             +0.0,
            },
        };

        for (double[] testCase: testCases)
            failures += testFusedMultiplyAddCase(testCase[0], testCase[1], testCase[2], testCase[3]);

        return failures;
    }

    /**
     * Double cases with finite operands, including overflow and
     * underflow of the result. Each row is {a, b, c, expected}.
     */
    private static int testSimpleD() {
        int failures = 0;

        double [][] testCases = {
            {1.0, 2.0, 3.0,
             5.0,},

            {1.0, 2.0, -2.0,
             0.0,},

            {5.0, 5.0, -25.0,
             0.0,},

            {Double.MAX_VALUE, 2.0, -Double.MAX_VALUE,
             Double.MAX_VALUE},

            {Double.MAX_VALUE, 2.0, 1.0,
             Infinity},

            {Double.MIN_VALUE, -Double.MIN_VALUE, +0.0,
             -0.0},

            {Double.MIN_VALUE, -Double.MIN_VALUE, -0.0,
             -0.0},

            {Double.MIN_VALUE, Double.MIN_VALUE, +0.0,
             +0.0},

            {Double.MIN_VALUE, Double.MIN_VALUE, -0.0,
             +0.0},

            {Double.MIN_VALUE, +0.0, -0.0,
             +0.0},

            {Double.MIN_VALUE, -0.0, -0.0,
             -0.0},

            {Double.MIN_VALUE, +0.0, +0.0,
             +0.0},

            {Double.MIN_VALUE, -0.0, +0.0,
             +0.0},
        };

        for (double[] testCase: testCases)
            failures += testFusedMultiplyAddCase(testCase[0], testCase[1], testCase[2], testCase[3]);

        return failures;
    }

    /**
     * Float cases with at least one infinite or NaN operand;
     * mirrors the double table. Each row is {a, b, c, expected}.
     */
    private static int testNonFiniteF() {
        int failures = 0;

        float [][] testCases = {
            {InfinityF,       InfinityF,  InfinityF,
             InfinityF,
            },

            {-InfinityF,       InfinityF,  -InfinityF,
             -InfinityF,
            },

            {-InfinityF,       InfinityF,  InfinityF,
             NaNf,
            },

            {InfinityF,       InfinityF,  -InfinityF,
             NaNf,
            },

            {1.0f,       InfinityF,  2.0f,
             InfinityF,
            },

            {1.0f,       2.0f,       InfinityF,
             InfinityF,
            },

            {InfinityF,  1.0f,       InfinityF,
             InfinityF,
            },

            {Float.MAX_VALUE, 2.0f, -InfinityF,
             -InfinityF},

            {InfinityF,  1.0f,      -InfinityF,
             NaNf,
            },

            {-InfinityF, 1.0f,       InfinityF,
             NaNf,
            },

            {1.0f,       NaNf,       2.0f,
             NaNf,
            },

            {1.0f,       2.0f,       NaNf,
             NaNf,
            },

            {InfinityF,  2.0f,       NaNf,
             NaNf,
            },

            {NaNf,       2.0f,       InfinityF,
             NaNf,
            },
        };

        for (float[] testCase: testCases)
            failures += testFusedMultiplyAddCase(testCase[0], testCase[1], testCase[2], testCase[3]);

        return failures;
    }

    /**
     * Float cases checking the sign of a zero result; mirrors the
     * double table. Each row is {a, b, c, expected}.
     */
    private static int testZeroesF() {
        int failures = 0;

        float [][] testCases = {
            {+0.0f, +0.0f, +0.0f,
             +0.0f,
            },

            {-0.0f, +0.0f, +0.0f,
             +0.0f,
            },

            {+0.0f, +0.0f, -0.0f,
             +0.0f,
            },

            // Positive zero product from two negative zeros.
            {-0.0f, -0.0f, +0.0f,
             +0.0f,
            },

            {-0.0f, +0.0f, -0.0f,
             -0.0f,
            },

            {-0.0f, -0.0f, -0.0f,
             +0.0f,
            },

            {-1.0f, +0.0f, -0.0f,
             -0.0f,
            },

            {-1.0f, +0.0f, +0.0f,
             +0.0f,
            },

            {-2.0f, +0.0f, -0.0f,
             -0.0f,
            },

            {-2.0f, +0.0f, +0.0f,
             +0.0f,
            },
        };

        for (float[] testCase: testCases)
            failures += testFusedMultiplyAddCase(testCase[0], testCase[1], testCase[2], testCase[3]);

        return failures;
    }

    /**
     * Float cases with finite operands, including overflow of the
     * result. Each row is {a, b, c, expected}.
     */
    private static int testSimpleF() {
        int failures = 0;

        float [][] testCases = {
            {1.0f, 2.0f, 3.0f,
             5.0f,},

            {1.0f, 2.0f, -2.0f,
             0.0f,},

            {5.0f, 5.0f, -25.0f,
             0.0f,},

            {Float.MAX_VALUE, 2.0f, -Float.MAX_VALUE,
             Float.MAX_VALUE},

            {Float.MAX_VALUE, 2.0f, 1.0f,
             InfinityF},
        };

        for (float[] testCase: testCases)
            failures += testFusedMultiplyAddCase(testCase[0], testCase[1], testCase[2], testCase[3]);

        return failures;
    }

    /**
     * Checks one double case against both Math.fma and
     * StrictMath.fma, with the first two inputs in the given order
     * and swapped (the product is commutative).
     *
     * @return the number of failed checks
     */
    private static int testFusedMultiplyAddCase(double input1, double input2, double input3, double expected) {
        int failures = 0;
        failures += Tests.test("Math.fma(double)", input1, input2, input3,
                               Math.fma(input1, input2, input3), expected);
        failures += Tests.test("StrictMath.fma(double)", input1, input2, input3,
                               StrictMath.fma(input1, input2, input3), expected);

        // Permute first two inputs
        failures += Tests.test("Math.fma(double)", input2, input1, input3,
                               Math.fma(input2, input1, input3), expected);
        failures += Tests.test("StrictMath.fma(double)", input2, input1, input3,
                               StrictMath.fma(input2, input1, input3), expected);
        return failures;
    }

    /**
     * Checks one float case against both Math.fma and
     * StrictMath.fma, with the first two inputs in the given order
     * and swapped (the product is commutative).
     *
     * @return the number of failed checks
     */
    private static int testFusedMultiplyAddCase(float input1, float input2, float input3, float expected) {
        int failures = 0;
        failures += Tests.test("Math.fma(float)", input1, input2, input3,
                               Math.fma(input1, input2, input3), expected);
        failures += Tests.test("StrictMath.fma(float)", input1, input2, input3,
                               StrictMath.fma(input1, input2, input3), expected);

        // Permute first two inputs
        failures += Tests.test("Math.fma(float)", input2, input1, input3,
                               Math.fma(input2, input1, input3), expected);
        failures += Tests.test("StrictMath.fma(float)", input2, input1, input3,
                               StrictMath.fma(input2, input1, input3), expected);
        return failures;
    }
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2003, 2012, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2003, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -391,6 +391,38 @@ public class Tests {
return 0;
}
/**
 * Compares the float result of a three-argument operation with
 * the expected value using {@code Float.compare}, so that NaN
 * equals NaN and {@code +0.0f} differs from {@code -0.0f}. On a
 * mismatch a description of the inputs, the expected value and
 * the actual value is printed to {@code System.err}.
 *
 * @param testName name of the operation under test, used in the message
 * @param input1 first input of the operation
 * @param input2 second input of the operation
 * @param input3 third input of the operation
 * @param result value actually produced
 * @param expected value that should have been produced
 * @return 1 if the values differ, 0 otherwise
 */
public static int test(String testName,
                       float input1, float input2, float input3,
                       float result, float expected) {
    if (Float.compare(expected, result) == 0) {
        return 0;
    }

    // Note the blank after each "and" so the inputs do not run
    // together in the message.
    System.err.println("Failure for " + testName + ":\n" +
                       "\tFor inputs " + input1 + "\t(" + toHexString(input1) + ") and "
                       + input2 + "\t(" + toHexString(input2) + ") and "
                       + input3 + "\t(" + toHexString(input3) + ")\n" +
                       "\texpected  " + expected + "\t(" + toHexString(expected) + ")\n" +
                       "\tgot       " + result + "\t(" + toHexString(result) + ").");
    return 1;
}
/**
 * Compares the double result of a three-argument operation with
 * the expected value using {@code Double.compare}, so that NaN
 * equals NaN and {@code +0.0} differs from {@code -0.0}. On a
 * mismatch a description of the inputs, the expected value and
 * the actual value is printed to {@code System.err}.
 *
 * @param testName name of the operation under test, used in the message
 * @param input1 first input of the operation
 * @param input2 second input of the operation
 * @param input3 third input of the operation
 * @param result value actually produced
 * @param expected value that should have been produced
 * @return 1 if the values differ, 0 otherwise
 */
public static int test(String testName,
                       double input1, double input2, double input3,
                       double result, double expected) {
    if (Double.compare(expected, result) == 0) {
        return 0;
    }

    // Note the blank after each "and" so the inputs do not run
    // together in the message.
    System.err.println("Failure for " + testName + ":\n" +
                       "\tFor inputs " + input1 + "\t(" + toHexString(input1) + ") and "
                       + input2 + "\t(" + toHexString(input2) + ") and "
                       + input3 + "\t(" + toHexString(input3) + ")\n" +
                       "\texpected  " + expected + "\t(" + toHexString(expected) + ")\n" +
                       "\tgot       " + result + "\t(" + toHexString(result) + ").");
    return 1;
}
static int testUlpCore(double result, double expected, double ulps) {
// We assume we won't be unlucky and have an inexact expected
// be nextDown(2^i) when 2^i would be the correctly rounded