f15ab37909
Field arithmetic library for crypto algorithms like Poly1305 and X25519 Reviewed-by: xuelei
378 lines
14 KiB
Java
378 lines
14 KiB
Java
/*
|
|
* Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
|
|
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
|
*
|
|
* This code is free software; you can redistribute it and/or modify it
|
|
* under the terms of the GNU General Public License version 2 only, as
|
|
* published by the Free Software Foundation.
|
|
*
|
|
* This code is distributed in the hope that it will be useful, but WITHOUT
|
|
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
|
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
|
* version 2 for more details (a copy is included in the LICENSE file that
|
|
* accompanied this code).
|
|
*
|
|
* You should have received a copy of the GNU General Public License version
|
|
* 2 along with this work; if not, write to the Free Software Foundation,
|
|
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
*
|
|
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
|
* or visit www.oracle.com if you need additional information or have any
|
|
* questions.
|
|
*/
|
|
|
|
/*
|
|
* @test
|
|
* @bug 8181594
|
|
* @summary Test proper operation of integer field arithmetic
|
|
* @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly
|
|
* @build BigIntegerModuloP
|
|
* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
|
|
* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
|
|
* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
|
|
*/
|
|
|
|
import sun.security.util.math.*;
|
|
import java.util.function.*;
|
|
|
|
import java.util.*;
|
|
import java.math.*;
|
|
import java.nio.*;
|
|
|
|
public class TestIntegerModuloP {
|
|
|
|
static BigInteger TWO = BigInteger.valueOf(2);
|
|
|
|
// The test has a list of functions, and it selects randomly from that list
|
|
|
|
// The function types
|
|
interface ElemFunction extends BiFunction
|
|
<MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }
|
|
interface ElemArrayFunction extends BiFunction
|
|
<MutableIntegerModuloP, IntegerModuloP, byte[]> { }
|
|
interface TriConsumer <T, U, V> {
|
|
void accept(T t, U u, V v);
|
|
}
|
|
interface ElemSetFunction extends TriConsumer
|
|
<MutableIntegerModuloP, IntegerModuloP, byte[]> { }
|
|
|
|
// The lists of functions. Multiple lists are needed because the test
|
|
// respects the limitations of the arithmetic implementations.
|
|
static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
|
|
static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
|
|
static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
|
|
static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();
|
|
|
|
static void setUpFunctions(IntegerFieldModuloP field, int length) {
|
|
|
|
ADD_FUNCTIONS.clear();
|
|
MULT_FUNCTIONS.clear();
|
|
SET_FUNCTIONS.clear();
|
|
ARRAY_FUNCTIONS.clear();
|
|
|
|
byte highByte = (byte)
|
|
(field.getSize().bitLength() > length * 8 ? 1 : 0);
|
|
|
|
// add functions are (im)mutable add/subtract
|
|
ADD_FUNCTIONS.add(IntegerModuloP::add);
|
|
ADD_FUNCTIONS.add(IntegerModuloP::subtract);
|
|
ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
|
|
ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
|
|
// also include functions that return the first/second argument
|
|
ADD_FUNCTIONS.add((a, b) -> a);
|
|
ADD_FUNCTIONS.add((a, b) -> b);
|
|
|
|
// mult functions are (im)mutable multiply and square
|
|
MULT_FUNCTIONS.add(IntegerModuloP::multiply);
|
|
MULT_FUNCTIONS.add((a, b) -> a.square());
|
|
MULT_FUNCTIONS.add((a, b) -> b.square());
|
|
MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
|
|
MULT_FUNCTIONS.add((a, b) -> a.setSquare());
|
|
// also test multiplication by a small value
|
|
MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
|
|
b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));
|
|
|
|
// set functions are setValue with various argument types
|
|
SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
|
|
SET_FUNCTIONS.add((a, b, c) ->
|
|
a.setValue(c, 0, c.length, (byte) 0));
|
|
SET_FUNCTIONS.add((a, b, c) ->
|
|
a.setValue(ByteBuffer.wrap(c, 0, c.length).order(ByteOrder.LITTLE_ENDIAN),
|
|
c.length, highByte));
|
|
|
|
// array functions return the (possibly modified) value as byte array
|
|
ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length));
|
|
ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
|
|
}
|
|
|
|
public static void main(String[] args) {
|
|
|
|
String className = args[0];
|
|
final int length = Integer.parseInt(args[1]);
|
|
int seed = Integer.parseInt(args[2]);
|
|
|
|
Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;
|
|
try {
|
|
Class<? extends IntegerFieldModuloP> clazz =
|
|
Class.forName(className).asSubclass(fieldBaseClass);
|
|
IntegerFieldModuloP field =
|
|
clazz.getDeclaredConstructor().newInstance();
|
|
|
|
setUpFunctions(field, length);
|
|
|
|
runFieldTest(field, length, seed);
|
|
} catch (Exception ex) {
|
|
throw new RuntimeException(ex);
|
|
}
|
|
|
|
System.out.println("All tests passed");
|
|
}
|
|
|
|
|
|
static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {
|
|
|
|
if (!e1.asBigInteger().equals(e2.asBigInteger())) {
|
|
throw new RuntimeException("values not equal: "
|
|
+ e1.asBigInteger() + " != " + e2.asBigInteger());
|
|
}
|
|
}
|
|
|
|
// A class that holds pairs of actual/expected values, and allows
|
|
// computation on these pairs.
|
|
static class TestPair<T extends IntegerModuloP> {
|
|
private final T test;
|
|
private final T baseline;
|
|
|
|
public TestPair(T test, T baseline) {
|
|
this.test = test;
|
|
this.baseline = baseline;
|
|
}
|
|
|
|
public T getTest() {
|
|
return test;
|
|
}
|
|
public T getBaseline() {
|
|
return baseline;
|
|
}
|
|
|
|
private void assertEqual() {
|
|
TestIntegerModuloP.assertEqual(test, baseline);
|
|
}
|
|
|
|
public TestPair<MutableIntegerModuloP> mutable() {
|
|
return new TestPair<>(test.mutable(), baseline.mutable());
|
|
}
|
|
|
|
public
|
|
<R extends IntegerModuloP, X extends IntegerModuloP>
|
|
TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
|
|
X testResult = func.apply(test, right.test);
|
|
X baselineResult = func.apply(baseline, right.baseline);
|
|
return new TestPair(testResult, baselineResult);
|
|
}
|
|
|
|
public
|
|
<U extends IntegerModuloP, V>
|
|
void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
|
|
func.accept(test, right.test, argV);
|
|
func.accept(baseline, right.baseline, argV);
|
|
}
|
|
|
|
public
|
|
<R extends IntegerModuloP>
|
|
void applyAndCheckArray(BiFunction<T, R, byte[]> func,
|
|
TestPair<R> right) {
|
|
byte[] testResult = func.apply(test, right.test);
|
|
byte[] baselineResult = func.apply(baseline, right.baseline);
|
|
if (!Arrays.equals(testResult, baselineResult)) {
|
|
throw new RuntimeException("Array values do not match: "
|
|
+ byteArrayToHexString(testResult) + " != "
|
|
+ byteArrayToHexString(baselineResult));
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
static String byteArrayToHexString(byte[] arr) {
|
|
StringBuilder result = new StringBuilder();
|
|
for (int i = 0; i < arr.length; ++i) {
|
|
byte curVal = arr[i];
|
|
result.append(Character.forDigit(curVal >> 4 & 0xF, 16));
|
|
result.append(Character.forDigit(curVal & 0xF, 16));
|
|
}
|
|
return result.toString();
|
|
}
|
|
|
|
static TestPair<IntegerModuloP>
|
|
applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
|
|
TestPair<IntegerModuloP> right) {
|
|
|
|
TestPair<IntegerModuloP> result = left.apply(func, right);
|
|
result.assertEqual();
|
|
left.assertEqual();
|
|
right.assertEqual();
|
|
|
|
return result;
|
|
}
|
|
|
|
static void
|
|
setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
|
|
TestPair<IntegerModuloP> right, byte[] argV) {
|
|
|
|
left.apply(func, right, argV);
|
|
left.assertEqual();
|
|
right.assertEqual();
|
|
}
|
|
|
|
static TestPair<MutableIntegerModuloP>
|
|
applyAndCheckMutable(ElemFunction func,
|
|
TestPair<MutableIntegerModuloP> left,
|
|
TestPair<IntegerModuloP> right) {
|
|
|
|
TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);
|
|
|
|
TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
|
|
mutableResult.assertEqual();
|
|
result.assertEqual();
|
|
left.assertEqual();
|
|
right.assertEqual();
|
|
|
|
return mutableResult;
|
|
}
|
|
|
|
static void
|
|
cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
|
|
TestPair<MutableIntegerModuloP> right) {
|
|
|
|
left.getTest().conditionalSwapWith(right.getTest(), swap);
|
|
left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);
|
|
|
|
left.assertEqual();
|
|
right.assertEqual();
|
|
|
|
}
|
|
|
|
// Request arithmetic that should overflow, and ensure that overflow is
|
|
// detected.
|
|
static void runOverflowTest(TestPair<IntegerModuloP> elem) {
|
|
|
|
TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();
|
|
|
|
try {
|
|
for (int i = 0; i < 1000; i++) {
|
|
applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
|
|
}
|
|
applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
|
|
} catch (ArithmeticException ex) {
|
|
// this is expected
|
|
}
|
|
|
|
mutableElem = elem.mutable();
|
|
try {
|
|
for (int i = 0; i < 1000; i++) {
|
|
elem = applyAndCheck(IntegerModuloP::add,
|
|
mutableElem, elem);
|
|
}
|
|
applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
|
|
} catch (ArithmeticException ex) {
|
|
// this is expected
|
|
}
|
|
}
|
|
|
|
// Run a large number of random operations and ensure that
|
|
// results are correct
|
|
static void runOperationsTest(Random random, int length,
|
|
TestPair<IntegerModuloP> elem,
|
|
TestPair<IntegerModuloP> right) {
|
|
|
|
TestPair<MutableIntegerModuloP> left = elem.mutable();
|
|
|
|
for (int i = 0; i < 10000; i++) {
|
|
|
|
ElemFunction addFunc1 =
|
|
ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
|
|
TestPair<MutableIntegerModuloP> result1 =
|
|
applyAndCheckMutable(addFunc1, left, right);
|
|
|
|
// left could have been modified, so turn it back into a summand
|
|
applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
|
|
|
|
ElemFunction addFunc2 =
|
|
ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
|
|
TestPair<IntegerModuloP> result2 =
|
|
applyAndCheck(addFunc2, left, right);
|
|
|
|
ElemFunction multFunc2 =
|
|
MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
|
|
TestPair<MutableIntegerModuloP> multResult =
|
|
applyAndCheckMutable(multFunc2, result1, result2);
|
|
|
|
int swap = random.nextInt(2);
|
|
cswapAndCheck(swap, left, multResult);
|
|
|
|
ElemSetFunction setFunc =
|
|
SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
|
|
byte[] valueArr = new byte[length];
|
|
random.nextBytes(valueArr);
|
|
setAndCheck(setFunc, result1, result2, valueArr);
|
|
|
|
// left could have been modified, so to turn it back into a summand
|
|
applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
|
|
|
|
ElemArrayFunction arrayFunc =
|
|
ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
|
|
left.applyAndCheckArray(arrayFunc, right);
|
|
}
|
|
}
|
|
|
|
// Run all the tests for a given field
|
|
static void runFieldTest(IntegerFieldModuloP testField,
|
|
int length, int seed) {
|
|
System.out.println("Testing: " + testField.getClass().getSimpleName());
|
|
|
|
Random random = new Random(seed);
|
|
|
|
IntegerFieldModuloP baselineField =
|
|
new BigIntegerModuloP(testField.getSize());
|
|
|
|
int numBits = testField.getSize().bitLength();
|
|
BigInteger r =
|
|
new BigInteger(numBits, random).mod(testField.getSize());
|
|
TestPair<IntegerModuloP> rand =
|
|
new TestPair(testField.getElement(r), baselineField.getElement(r));
|
|
|
|
runOverflowTest(rand);
|
|
|
|
// check combinations of operations for different kinds of elements
|
|
List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
|
|
testElements.add(rand);
|
|
testElements.add(new TestPair(testField.get0(), baselineField.get0()));
|
|
testElements.add(new TestPair(testField.get1(), baselineField.get1()));
|
|
byte[] testArr = {121, 37, -100, -5, 76, 33};
|
|
testElements.add(new TestPair(testField.getElement(testArr),
|
|
baselineField.getElement(testArr)));
|
|
|
|
testArr = new byte[length];
|
|
random.nextBytes(testArr);
|
|
testElements.add(new TestPair(testField.getElement(testArr),
|
|
baselineField.getElement(testArr)));
|
|
|
|
random.nextBytes(testArr);
|
|
byte highByte = (byte) (numBits > length * 8 ? 1 : 0);
|
|
testElements.add(
|
|
new TestPair(
|
|
testField.getElement(testArr, 0, testArr.length, highByte),
|
|
baselineField.getElement(testArr, 0, testArr.length, highByte)
|
|
)
|
|
);
|
|
|
|
for (int i = 0; i < testElements.size(); i++) {
|
|
for (int j = 0; j < testElements.size(); j++) {
|
|
runOperationsTest(random, length, testElements.get(i),
|
|
testElements.get(j));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|