8282252: Improve BigInteger/Decimal validation

Reviewed-by: jboes, rhalade, skoivu, bpb, smarks
This commit is contained in:
Joe Darcy 2022-03-14 23:39:22 +00:00 committed by Henry Jen
parent ecfb6bce5a
commit 25e88b21af
5 changed files with 378 additions and 57 deletions

View File

@ -31,6 +31,10 @@ package java.math;
import static java.math.BigInteger.LONG_MASK; import static java.math.BigInteger.LONG_MASK;
import java.io.IOException; import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.StreamCorruptedException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
@ -1047,6 +1051,15 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
this.precision = prec; this.precision = prec;
} }
/**
* Accept no subclasses.
*/
private static BigInteger toStrictBigInteger(BigInteger val) {
return (val.getClass() == BigInteger.class) ?
val :
new BigInteger(val.toByteArray().clone());
}
/** /**
* Translates a {@code BigInteger} into a {@code BigDecimal}. * Translates a {@code BigInteger} into a {@code BigDecimal}.
* The scale of the {@code BigDecimal} is zero. * The scale of the {@code BigDecimal} is zero.
@ -1056,8 +1069,8 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
*/ */
public BigDecimal(BigInteger val) { public BigDecimal(BigInteger val) {
scale = 0; scale = 0;
intVal = val; intVal = toStrictBigInteger(val);
intCompact = compactValFor(val); intCompact = compactValFor(intVal);
} }
/** /**
@ -1071,7 +1084,7 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
* @since 1.5 * @since 1.5
*/ */
public BigDecimal(BigInteger val, MathContext mc) { public BigDecimal(BigInteger val, MathContext mc) {
this(val,0,mc); this(toStrictBigInteger(val), 0, mc);
} }
/** /**
@ -1085,8 +1098,8 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
*/ */
public BigDecimal(BigInteger unscaledVal, int scale) { public BigDecimal(BigInteger unscaledVal, int scale) {
// Negative scales are now allowed // Negative scales are now allowed
this.intVal = unscaledVal; this.intVal = toStrictBigInteger(unscaledVal);
this.intCompact = compactValFor(unscaledVal); this.intCompact = compactValFor(this.intVal);
this.scale = scale; this.scale = scale;
} }
@ -1104,6 +1117,7 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
* @since 1.5 * @since 1.5
*/ */
public BigDecimal(BigInteger unscaledVal, int scale, MathContext mc) { public BigDecimal(BigInteger unscaledVal, int scale, MathContext mc) {
unscaledVal = toStrictBigInteger(unscaledVal);
long compactVal = compactValFor(unscaledVal); long compactVal = compactValFor(unscaledVal);
int mcp = mc.precision; int mcp = mc.precision;
int prec = 0; int prec = 0;
@ -4253,9 +4267,13 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
= unsafe.objectFieldOffset(BigDecimal.class, "intCompact"); = unsafe.objectFieldOffset(BigDecimal.class, "intCompact");
private static final long intValOffset private static final long intValOffset
= unsafe.objectFieldOffset(BigDecimal.class, "intVal"); = unsafe.objectFieldOffset(BigDecimal.class, "intVal");
private static final long scaleOffset
= unsafe.objectFieldOffset(BigDecimal.class, "scale");
static void setIntCompact(BigDecimal bd, long val) { static void setIntValAndScale(BigDecimal bd, BigInteger intVal, int scale) {
unsafe.putLong(bd, intCompactOffset, val); unsafe.putReference(bd, intValOffset, intVal);
unsafe.putInt(bd, scaleOffset, scale);
unsafe.putLong(bd, intCompactOffset, compactValFor(intVal));
} }
static void setIntValVolatile(BigDecimal bd, BigInteger val) { static void setIntValVolatile(BigDecimal bd, BigInteger val) {
@ -4274,15 +4292,30 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
@java.io.Serial @java.io.Serial
private void readObject(java.io.ObjectInputStream s) private void readObject(java.io.ObjectInputStream s)
throws IOException, ClassNotFoundException { throws IOException, ClassNotFoundException {
// Read in all fields // prepare to read the fields
s.defaultReadObject(); ObjectInputStream.GetField fields = s.readFields();
// validate possibly bad fields BigInteger serialIntVal = (BigInteger) fields.get("intVal", null);
if (intVal == null) {
String message = "BigDecimal: null intVal in stream"; // Validate field data
throw new java.io.StreamCorruptedException(message); if (serialIntVal == null) {
// [all values of scale are now allowed] throw new StreamCorruptedException("Null or missing intVal in BigDecimal stream");
} }
UnsafeHolder.setIntCompact(this, compactValFor(intVal)); // Validate provenance of serialIntVal object
serialIntVal = toStrictBigInteger(serialIntVal);
// Any integer value is valid for scale
int serialScale = fields.get("scale", 0);
UnsafeHolder.setIntValAndScale(this, serialIntVal, serialScale);
}
/**
* Serialization without data not supported for this class.
*/
@java.io.Serial
private void readObjectNoData()
throws ObjectStreamException {
throw new InvalidObjectException("Deserialized BigDecimal objects need data");
} }
/** /**

View File

@ -30,9 +30,11 @@
package java.math; package java.math;
import java.io.IOException; import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream; import java.io.ObjectInputStream;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.io.ObjectStreamField; import java.io.ObjectStreamField;
import java.io.ObjectStreamException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
import java.util.Random; import java.util.Random;
@ -4836,17 +4838,21 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
// prepare to read the alternate persistent fields // prepare to read the alternate persistent fields
ObjectInputStream.GetField fields = s.readFields(); ObjectInputStream.GetField fields = s.readFields();
// Read the alternate persistent fields that we care about // Read and validate the alternate persistent fields that we
int sign = fields.get("signum", -2); // care about, signum and magnitude
byte[] magnitude = (byte[])fields.get("magnitude", null);
// Validate signum // Read and validate signum
int sign = fields.get("signum", -2);
if (sign < -1 || sign > 1) { if (sign < -1 || sign > 1) {
String message = "BigInteger: Invalid signum value"; String message = "BigInteger: Invalid signum value";
if (fields.defaulted("signum")) if (fields.defaulted("signum"))
message = "BigInteger: Signum not present in stream"; message = "BigInteger: Signum not present in stream";
throw new java.io.StreamCorruptedException(message); throw new java.io.StreamCorruptedException(message);
} }
// Read and validate magnitude
byte[] magnitude = (byte[])fields.get("magnitude", null);
magnitude = magnitude.clone(); // defensive copy
int[] mag = stripLeadingZeroBytes(magnitude, 0, magnitude.length); int[] mag = stripLeadingZeroBytes(magnitude, 0, magnitude.length);
if ((mag.length == 0) != (sign == 0)) { if ((mag.length == 0) != (sign == 0)) {
String message = "BigInteger: signum-magnitude mismatch"; String message = "BigInteger: signum-magnitude mismatch";
@ -4855,18 +4861,24 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
throw new java.io.StreamCorruptedException(message); throw new java.io.StreamCorruptedException(message);
} }
// Commit final fields via Unsafe // Equivalent to checkRange() on mag local without assigning
UnsafeHolder.putSign(this, sign); // this.mag field
if (mag.length > MAX_MAG_LENGTH ||
// Calculate mag field from magnitude and discard magnitude (mag.length == MAX_MAG_LENGTH && mag[0] < 0)) {
UnsafeHolder.putMag(this, mag); throw new java.io.StreamCorruptedException("BigInteger: Out of the supported range");
if (mag.length >= MAX_MAG_LENGTH) {
try {
checkRange();
} catch (ArithmeticException e) {
throw new java.io.StreamCorruptedException("BigInteger: Out of the supported range");
}
} }
// Commit final fields via Unsafe
UnsafeHolder.putSignAndMag(this, sign, mag);
}
/**
* Serialization without data not supported for this class.
*/
@java.io.Serial
private void readObjectNoData()
throws ObjectStreamException {
throw new InvalidObjectException("Deserialized BigInteger objects need data");
} }
// Support for resetting final fields while deserializing // Support for resetting final fields while deserializing
@ -4878,11 +4890,8 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
private static final long magOffset private static final long magOffset
= unsafe.objectFieldOffset(BigInteger.class, "mag"); = unsafe.objectFieldOffset(BigInteger.class, "mag");
static void putSign(BigInteger bi, int sign) { static void putSignAndMag(BigInteger bi, int sign, int[] magnitude) {
unsafe.putInt(bi, signumOffset, sign); unsafe.putInt(bi, signumOffset, sign);
}
static void putMag(BigInteger bi, int[] magnitude) {
unsafe.putReference(bi, magOffset, magnitude); unsafe.putReference(bi, magOffset, magnitude);
} }
} }

View File

@ -0,0 +1,64 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8282252
* @summary Test constructors of BigDecimal to replace BigInteger subclasses
*/
import java.math.*;
public class ConstructorUnscaledValue {
public static void main(String... args) {
TestBigInteger tbi = new TestBigInteger(BigInteger.ONE);
// Create BigDecimal's using each of the three constructors
// with guards on the class of unscaledValue
BigDecimal[] values = {
new BigDecimal(tbi),
new BigDecimal(tbi, 2),
new BigDecimal(tbi, 3, MathContext.DECIMAL32),
};
for (var bd : values) {
BigInteger unscaledValue = bd.unscaledValue();
if (unscaledValue.getClass() != BigInteger.class) {
throw new RuntimeException("Bad class for unscaledValue");
}
if (!unscaledValue.equals(BigInteger.ONE)) {
throw new RuntimeException("Bad value for unscaledValue");
}
}
}
private static class TestBigInteger extends BigInteger {
public TestBigInteger(BigInteger bi) {
super(bi.toByteArray());
}
@Override
public String toString() {
return java.util.Arrays.toString(toByteArray());
}
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 2005, 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
@ -23,54 +23,128 @@
/* /*
* @test * @test
* @bug 6177836 * @bug 6177836 8282252
* @summary Verify BigDecimal objects with collapsed values are serialized properly. * @summary Verify BigDecimal objects with collapsed values are serialized properly.
* @author Joseph D. Darcy
*/ */
import java.math.*; import java.math.*;
import java.io.*; import java.io.*;
import java.util.List;
public class SerializationTests { public class SerializationTests {
static void checkSerialForm(BigDecimal bd) throws Exception { public static void main(String... args) throws Exception {
checkBigDecimalSerialRoundTrip();
checkBigDecimalSubSerialRoundTrip();
}
private static void checkSerialForm(BigDecimal bd) throws Exception {
checkSerialForm0(bd);
checkSerialForm0(bd.negate());
}
private static void checkSerialForm0(BigDecimal bd) throws Exception {
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(bos); try(ObjectOutputStream oos = new ObjectOutputStream(bos)) {
oos.writeObject(bd); oos.writeObject(bd);
oos.flush(); oos.flush();
oos.close(); }
ObjectInputStream ois = new ObjectInputStream ois = new
ObjectInputStream(new ByteArrayInputStream(bos.toByteArray())); ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()));
BigDecimal tmp = (BigDecimal)ois.readObject(); BigDecimal tmp = (BigDecimal)ois.readObject();
if (!bd.equals(tmp) || if (!bd.equals(tmp) ||
bd.hashCode() != tmp.hashCode()) { bd.hashCode() != tmp.hashCode() ||
bd.getClass() != tmp.getClass() ||
// Directly test equality of components
bd.scale() != tmp.scale() ||
!bd.unscaledValue().equals(tmp.unscaledValue())) {
System.err.print(" original : " + bd); System.err.print(" original : " + bd);
System.err.println(" (hash: 0x" + Integer.toHexString(bd.hashCode()) + ")"); System.err.println(" (hash: 0x" + Integer.toHexString(bd.hashCode()) + ")");
System.err.print("serialized : " + tmp); System.err.print("serialized : " + tmp);
System.err.println(" (hash: 0x" + Integer.toHexString(tmp.hashCode()) + ")"); System.err.println(" (hash: 0x" + Integer.toHexString(tmp.hashCode()) + ")");
throw new RuntimeException("Bad serial roundtrip"); throw new RuntimeException("Bad serial roundtrip");
} }
// If the class of the deserialized number is BigDecimal,
// verify the implementation constraint on the unscaled value
// having BigInteger class
if (tmp.getClass() == BigDecimal.class) {
if (tmp.unscaledValue().getClass() != BigInteger.class) {
throw new RuntimeException("Not using genuine BigInteger as an unscaled value");
}
}
} }
public static void main(String[] args) throws Exception { private static class BigIntegerSub extends BigInteger {
BigDecimal values[] = { public BigIntegerSub(BigInteger bi) {
BigDecimal.ZERO, super(bi.toByteArray());
BigDecimal.ONE, }
BigDecimal.TEN,
new BigDecimal(0), @Override
new BigDecimal(1), public String toString() {
new BigDecimal(10), return java.util.Arrays.toString(toByteArray());
new BigDecimal(Integer.MAX_VALUE), }
new BigDecimal(Long.MAX_VALUE-1), }
new BigDecimal(BigInteger.valueOf(1), 1), private static void checkBigDecimalSerialRoundTrip() throws Exception {
new BigDecimal(BigInteger.valueOf(100), 50), var values =
}; List.of(BigDecimal.ZERO,
BigDecimal.ONE,
BigDecimal.TEN,
new BigDecimal(0),
new BigDecimal(1),
new BigDecimal(10),
new BigDecimal(Integer.MAX_VALUE),
new BigDecimal(Long.MAX_VALUE-1),
new BigDecimal(BigInteger.valueOf(1), 1),
new BigDecimal(BigInteger.valueOf(100), 50),
new BigDecimal(new BigInteger("9223372036854775808"), // Long.MAX_VALUE + 1
Integer.MAX_VALUE),
new BigDecimal(new BigInteger("9223372036854775808"), // Long.MAX_VALUE + 1
Integer.MIN_VALUE),
new BigDecimal(new BigIntegerSub(BigInteger.ONE), 2));
for(BigDecimal value : values) { for(BigDecimal value : values) {
checkSerialForm(value); checkSerialForm(value);
checkSerialForm(value.negate()); }
}
private static class BigDecimalSub extends BigDecimal {
public BigDecimalSub(BigDecimal bd) {
super(bd.unscaledValue(), bd.scale());
} }
@Override
public String toString() {
return unscaledValue() + "x10^" + (-scale());
}
}
// Subclass defining a serialVersionUID
private static class BigDecimalSubSVUID extends BigDecimal {
@java.io.Serial
private static long serialVesionUID = 0x0123_4567_89ab_cdefL;
public BigDecimalSubSVUID(BigDecimal bd) {
super(bd.unscaledValue(), bd.scale());
}
}
private static void checkBigDecimalSubSerialRoundTrip() throws Exception {
var values =
List.of(BigDecimal.ZERO,
BigDecimal.ONE,
BigDecimal.TEN,
new BigDecimal(BigInteger.TEN, 1234),
new BigDecimal(new BigInteger("9223372036854775808"), // Long.MAX_VALUE + 1
Integer.MAX_VALUE),
new BigDecimal(new BigInteger("9223372036854775808"), // Long.MAX_VALUE + 1
Integer.MIN_VALUE));
for(var value : values) {
checkSerialForm(new BigDecimalSub(value));
checkSerialForm(new BigDecimalSubSVUID(value));
}
} }
} }

View File

@ -0,0 +1,141 @@
/*
* Copyright (c) 2005, 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8282252
* @summary Verify BigInteger objects are serialized properly.
*/
import java.math.*;
import java.io.*;
import java.util.Arrays;
import java.util.List;
public class SerializationTests {
public static void main(String... args) throws Exception {
checkBigIntegerSerialRoundTrip();
checkBigIntegerSubSerialRoundTrip();
}
private static void checkSerialForm(BigInteger bi) throws Exception {
checkSerialForm0(bi);
checkSerialForm0(bi.negate());
}
private static void checkSerialForm0(BigInteger bi) throws Exception {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try(ObjectOutputStream oos = new ObjectOutputStream(bos)) {
oos.writeObject(bi);
oos.flush();
}
ObjectInputStream ois = new
ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()));
BigInteger tmp = (BigInteger)ois.readObject();
if (!bi.equals(tmp) ||
bi.hashCode() != tmp.hashCode() ||
bi.getClass() != tmp.getClass() ||
// For extra measure, directly test equality of components
bi.signum() != tmp.signum() ||
!Arrays.equals(bi.toByteArray(), (tmp.toByteArray())) ) {
System.err.print(" original : " + bi);
System.err.println(" (hash: 0x" + Integer.toHexString(bi.hashCode()) + ")");
System.err.print("serialized : " + tmp);
System.err.println(" (hash: 0x" + Integer.toHexString(tmp.hashCode()) + ")");
throw new RuntimeException("Bad serial roundtrip");
}
}
private static void checkBigIntegerSerialRoundTrip() throws Exception {
var values =
List.of(BigInteger.ZERO,
BigInteger.ONE,
BigInteger.TWO,
BigInteger.TEN,
BigInteger.valueOf(100),
BigInteger.valueOf(Integer.MAX_VALUE),
BigInteger.valueOf(Long.MAX_VALUE-1),
new BigInteger("9223372036854775808")); // Long.MAX_VALUE + 1
for(BigInteger value : values) {
checkSerialForm(value);
}
}
// Subclass with specialized toString output
private static class BigIntegerSub extends BigInteger {
public BigIntegerSub(BigInteger bi) {
super(bi.toByteArray());
}
@Override
public String toString() {
return Arrays.toString(toByteArray());
}
}
// Subclass defining a serialVersionUID
private static class BigIntegerSubSVUID extends BigInteger {
@java.io.Serial
private static long serialVesionUID = 0x0123_4567_89ab_cdefL;
public BigIntegerSubSVUID(BigInteger bi) {
super(bi.toByteArray());
}
@Override
public String toString() {
return Arrays.toString(toByteArray());
}
}
// Subclass defining writeReplace
private static class BigIntegerSubWR extends BigInteger {
public BigIntegerSubWR(BigInteger bi) {
super(bi.toByteArray());
}
// Just return this; could use a serial proxy instead
@java.io.Serial
private Object writeReplace() throws ObjectStreamException {
return this;
}
}
private static void checkBigIntegerSubSerialRoundTrip() throws Exception {
var values = List.of(BigInteger.ZERO,
BigInteger.ONE,
BigInteger.TEN,
new BigInteger("9223372036854775808")); // Long.MAX_VALUE + 1
for(var value : values) {
checkSerialForm(new BigIntegerSub(value));
checkSerialForm(new BigIntegerSubSVUID(value));
checkSerialForm(new BigIntegerSubWR(value));
}
}
}