/* * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*/

/*
 * @test
 * @bug 8181594
 * @summary Test proper operation of integer field arithmetic
 * @modules java.base/sun.security.util
 *          java.base/sun.security.util.math
 *          java.base/sun.security.util.math.intpoly
 * @build BigIntegerModuloP
 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
 */

import sun.security.util.math.*;
import java.util.function.*;
import java.util.*;
import java.math.*;
import java.nio.*;

/**
 * Compares a field arithmetic implementation under test against a
 * {@code BigIntegerModuloP} baseline built over the same field size.
 *
 * <p>Every operation is applied to a (test, baseline) pair of elements and
 * the two results are required to be equal. Operations are drawn at random
 * from the function lists below using a seeded {@link Random}.
 *
 * <p>Command line arguments: {@code <field class name> <length> <seed>}.
 */
public class TestIntegerModuloP {

    static final BigInteger TWO = BigInteger.valueOf(2);

    // The test has a list of functions, and it selects randomly from that list

    // The function types

    /** Maps (mutable left operand, right operand) to a resulting element. */
    interface ElemFunction extends BiFunction
        <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }

    /** Maps (mutable left operand, right operand) to a byte array. */
    interface ElemArrayFunction extends BiFunction
        <MutableIntegerModuloP, IntegerModuloP, byte[]> { }

    /** A three-argument operation that returns no result. */
    interface TriConsumer<T, U, V> {
        void accept(T t, U u, V v);
    }

    /** Sets a mutable element from another element and/or a byte array. */
    interface ElemSetFunction extends TriConsumer
        <MutableIntegerModuloP, IntegerModuloP, byte[]> { }

    // The lists of functions. Multiple lists are needed because the test
    // respects the limitations of the arithmetic implementations.
    static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
    static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
    static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
    static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();

    /**
     * Clears and repopulates the four function lists.
     *
     * @param field  the field under test; only its size is read here
     * @param length the byte length passed to the array-producing functions
     */
    static void setUpFunctions(IntegerFieldModuloP field, int length) {

        ADD_FUNCTIONS.clear();
        MULT_FUNCTIONS.clear();
        SET_FUNCTIONS.clear();
        ARRAY_FUNCTIONS.clear();

        // 1 when the field size has more bits than fit in 'length' bytes
        byte highByte = (byte)
            (field.getSize().bitLength() > length * 8 ? 1 : 0);

        // add functions are (im)mutable add/subtract
        ADD_FUNCTIONS.add(IntegerModuloP::add);
        ADD_FUNCTIONS.add(IntegerModuloP::subtract);
        ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
        ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
        // also include functions that return the first/second argument
        ADD_FUNCTIONS.add((a, b) -> a);
        ADD_FUNCTIONS.add((a, b) -> b);

        // mult functions are (im)mutable multiply and square
        MULT_FUNCTIONS.add(IntegerModuloP::multiply);
        MULT_FUNCTIONS.add((a, b) -> a.square());
        MULT_FUNCTIONS.add((a, b) -> b.square());
        MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
        MULT_FUNCTIONS.add((a, b) -> a.setSquare());
        // also test multiplication by a small value
        // (b reduced modulo 262144 = 2^18 so that it fits in an int)
        MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
            b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));

        // set functions are setValue with various argument types
        SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
        SET_FUNCTIONS.add((a, b, c) ->
            a.setValue(c, 0, c.length, (byte) 0));
        SET_FUNCTIONS.add((a, b, c) ->
            a.setValue(ByteBuffer.wrap(c, 0, c.length)
                    .order(ByteOrder.LITTLE_ENDIAN),
                c.length, highByte));

        // array functions return the (possibly modified) value as byte array
        ARRAY_FUNCTIONS.add((a, b) -> a.asByteArray(length));
        ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
    }

    /**
     * Test entry point.
     *
     * @param args {@code args[0]} is the name of a class that is a subclass
     *             of {@code IntegerFieldModuloP} and is instantiated through
     *             its no-argument declared constructor; {@code args[1]} is
     *             the byte length; {@code args[2]} is the random seed
     * @throws RuntimeException wrapping any exception raised while loading
     *             the field class or running the test
     */
    public static void main(String[] args) {

        String className = args[0];
        final int length = Integer.parseInt(args[1]);
        int seed = Integer.parseInt(args[2]);

        Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;
        try {
            Class<? extends IntegerFieldModuloP> clazz =
                Class.forName(className).asSubclass(fieldBaseClass);
            IntegerFieldModuloP field =
                clazz.getDeclaredConstructor().newInstance();

            setUpFunctions(field, length);

            runFieldTest(field, length, seed);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
        System.out.println("All tests passed");
    }

    /**
     * Fails the test unless both elements have the same BigInteger value.
     *
     * @throws RuntimeException if the values differ
     */
    static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {

        if (!e1.asBigInteger().equals(e2.asBigInteger())) {
            throw new RuntimeException("values not equal: "
                + e1.asBigInteger() + " != " + e2.asBigInteger());
        }
    }

    // A class that holds pairs of actual/expected values, and allows
    // computation on these pairs.
    static class TestPair<T extends IntegerModuloP> {
        private final T test;
        private final T baseline;

        public TestPair(T test, T baseline) {
            this.test = test;
            this.baseline = baseline;
        }

        public T getTest() {
            return test;
        }

        public T getBaseline() {
            return baseline;
        }

        // Checks that the test value and the baseline value agree.
        private void assertEqual() {
            TestIntegerModuloP.assertEqual(test, baseline);
        }

        /** Returns a pair holding the result of mutable() on each member. */
        public TestPair<MutableIntegerModuloP> mutable() {
            return new TestPair<>(test.mutable(), baseline.mutable());
        }

        /**
         * Applies {@code func} to the test members and then to the baseline
         * members of this pair and {@code right}, returning both results
         * as a new pair.
         */
        public
        <R extends IntegerModuloP, X extends IntegerModuloP>
        TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
            X testResult = func.apply(test, right.test);
            X baselineResult = func.apply(baseline, right.baseline);
            return new TestPair<>(testResult, baselineResult);
        }

        /**
         * Applies {@code func} to the test members and then to the baseline
         * members, passing the same {@code argV} to both calls.
         */
        public
        <U extends IntegerModuloP, V>
        void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
            func.accept(test, right.test, argV);
            func.accept(baseline, right.baseline, argV);
        }

        /**
         * Applies {@code func} to the test members and to the baseline
         * members and requires the two resulting arrays to be equal.
         *
         * @throws RuntimeException if the arrays differ
         */
        public
        <R extends IntegerModuloP>
        void applyAndCheckArray(BiFunction<T, R, byte[]> func,
                                TestPair<R> right) {
            byte[] testResult = func.apply(test, right.test);
            byte[] baselineResult = func.apply(baseline, right.baseline);
            if (!Arrays.equals(testResult, baselineResult)) {
                throw new RuntimeException("Array values do not match: "
                    + byteArrayToHexString(testResult) + " != "
                    + byteArrayToHexString(baselineResult));
            }
        }

    }

    /**
     * Formats a byte array as lowercase hex, two digits per byte, in array
     * order.
     */
    static String byteArrayToHexString(byte[] arr) {
        StringBuilder result = new StringBuilder(arr.length * 2);
        for (byte curVal : arr) {
            // the mask discards the sign-extension bits of negative bytes
            result.append(Character.forDigit((curVal >> 4) & 0xF, 16));
            result.append(Character.forDigit(curVal & 0xF, 16));
        }
        return result.toString();
    }

    /**
     * Applies {@code func} to both pairs, then checks that the result and
     * both operands still agree with their baselines.
     */
    static TestPair<IntegerModuloP>
    applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
                  TestPair<IntegerModuloP> right) {

        TestPair<IntegerModuloP> result = left.apply(func, right);
        result.assertEqual();
        left.assertEqual();
        right.assertEqual();

        return result;
    }

    /**
     * Applies the set function to {@code left} (using {@code right} and
     * {@code argV}), then checks both pairs against their baselines.
     */
    static void
    setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
                TestPair<IntegerModuloP> right, byte[] argV) {

        left.apply(func, right, argV);
        left.assertEqual();
        right.assertEqual();
    }

    /**
     * Same as {@link #applyAndCheck}, but additionally converts the result
     * with mutable() and re-checks every pair involved.
     *
     * @return the mutable() form of the result pair
     */
    static TestPair<MutableIntegerModuloP>
    applyAndCheckMutable(ElemFunction func,
                         TestPair<MutableIntegerModuloP> left,
                         TestPair<IntegerModuloP> right) {

        TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);

        TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
        mutableResult.assertEqual();
        result.assertEqual();
        left.assertEqual();
        right.assertEqual();

        return mutableResult;
    }

    /**
     * Calls conditionalSwapWith on the test members and on the baseline
     * members with the same {@code swap} value, then checks both pairs.
     */
    static void
    cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
                  TestPair<MutableIntegerModuloP> right) {

        left.getTest().conditionalSwapWith(right.getTest(), swap);
        left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);

        left.assertEqual();
        right.assertEqual();
    }

    // Request arithmetic that should overflow, and ensure that overflow is
    // detected.
    // NOTE(review): an ArithmeticException is tolerated here but not
    // required; the test does not fail if none is thrown.
    static void runOverflowTest(TestPair<IntegerModuloP> elem) {

        TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();

        try {
            for (int i = 0; i < 1000; i++) {
                applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
            }
            applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
        } catch (ArithmeticException ex) {
            // this is expected
        }

        mutableElem = elem.mutable();
        try {
            for (int i = 0; i < 1000; i++) {
                elem = applyAndCheck(IntegerModuloP::add, mutableElem, elem);
            }
            applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
        } catch (ArithmeticException ex) {
            // this is expected
        }
    }

    // Run a large number of random operations and ensure that
    // results are correct
    static void runOperationsTest(Random random, int length,
                                  TestPair<IntegerModuloP> elem,
                                  TestPair<IntegerModuloP> right) {

        TestPair<MutableIntegerModuloP> left = elem.mutable();

        for (int i = 0; i < 10000; i++) {

            ElemFunction addFunc1 =
                ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
            TestPair<MutableIntegerModuloP> result1 =
                applyAndCheckMutable(addFunc1, left, right);

            // left could have been modified, so turn it back into a summand
            applyAndCheckMutable((a, b) -> a.setSquare(), left, right);

            ElemFunction addFunc2 =
                ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
            TestPair<IntegerModuloP> result2 =
                applyAndCheck(addFunc2, left, right);

            ElemFunction multFunc2 =
                MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
            TestPair<MutableIntegerModuloP> multResult =
                applyAndCheckMutable(multFunc2, result1, result2);

            int swap = random.nextInt(2);
            cswapAndCheck(swap, left, multResult);

            ElemSetFunction setFunc =
                SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
            byte[] valueArr = new byte[length];
            random.nextBytes(valueArr);
            setAndCheck(setFunc, result1, result2, valueArr);

            // left could have been modified, so turn it back into a summand
            applyAndCheckMutable((a, b) -> a.setSquare(), left, right);

            ElemArrayFunction arrayFunc =
                ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
            left.applyAndCheckArray(arrayFunc, right);
        }
    }

    // Run all the tests for a given field
    static void runFieldTest(IntegerFieldModuloP testField,
                             int length, int seed) {
        System.out.println("Testing: " + testField.getClass().getSimpleName());

        Random random = new Random(seed);

        IntegerFieldModuloP baselineField =
            new BigIntegerModuloP(testField.getSize());

        int numBits = testField.getSize().bitLength();
        BigInteger r =
            new BigInteger(numBits, random).mod(testField.getSize());
        TestPair<IntegerModuloP> rand =
            new TestPair<>(testField.getElement(r),
                baselineField.getElement(r));

        runOverflowTest(rand);

        // check combinations of operations for different kinds of elements
        List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
        testElements.add(rand);
        testElements.add(
            new TestPair<>(testField.get0(), baselineField.get0()));
        testElements.add(
            new TestPair<>(testField.get1(), baselineField.get1()));

        // a short fixed array
        byte[] testArr = {121, 37, -100, -5, 76, 33};
        testElements.add(
            new TestPair<>(testField.getElement(testArr),
                baselineField.getElement(testArr)));

        // a random array of the full requested length
        testArr = new byte[length];
        random.nextBytes(testArr);
        testElements.add(
            new TestPair<>(testField.getElement(testArr),
                baselineField.getElement(testArr)));

        // a random array combined with an explicit high byte, which is 1
        // when the field size has more bits than fit in 'length' bytes
        random.nextBytes(testArr);
        byte highByte = (byte) (numBits > length * 8 ? 1 : 0);
        testElements.add(
            new TestPair<>(
                testField.getElement(testArr, 0, testArr.length, highByte),
                baselineField.getElement(testArr, 0, testArr.length, highByte)
            )
        );

        // every ordered pair of elements, including each with itself
        for (TestPair<IntegerModuloP> leftElem : testElements) {
            for (TestPair<IntegerModuloP> rightElem : testElements) {
                runOperationsTest(random, length, leftElem, rightElem);
            }
        }
    }
}