/* * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*/ /* * @test * @modules java.base/jdk.internal.ref * @run testng/othervm * --enable-native-access=ALL-UNNAMED * TestNulls */ import java.lang.foreign.*; import jdk.internal.ref.CleanerFactory; import org.testng.annotations.DataProvider; import org.testng.annotations.NoInjection; import org.testng.annotations.Test; import java.lang.constant.Constable; import java.lang.foreign.Arena; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.lang.invoke.VarHandle; import java.lang.ref.Cleaner; import java.lang.reflect.Array; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.nio.charset.Charset; import java.nio.file.Path; import java.util.*; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.function.UnaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; import static java.lang.foreign.ValueLayout.JAVA_INT; import static java.lang.foreign.ValueLayout.JAVA_LONG; import static org.testng.Assert.*; import static org.testng.Assert.fail; /** * This test makes sure that public API classes (listed in {@link TestNulls#CLASSES}) throw NPEs whenever * nulls are provided. The test looks at all the public methods in all the listed classes, and injects * values automatically. If an API takes a reference, the test will try to inject nulls. For APIs taking * either reference arrays or collections, the framework will also generate additional replacements * (e.g. other than just replacing the array or collection with null), such as an array or collection * with null elements. 
The test can be customized by adding/removing classes to the {@link #CLASSES} array,
 * by adding/removing default mappings for standard carrier types (see {@link #DEFAULT_VALUES}) or by
 * adding/removing custom replacements (see {@link #REPLACEMENT_VALUES}).
 */
public class TestNulls {

    /** Public API classes whose public methods are invoked with null arguments. */
    static final Class<?>[] CLASSES = new Class<?>[] {
            Arena.class,
            MemorySegment.class,
            MemoryLayout.class,
            MemoryLayout.PathElement.class,
            SequenceLayout.class,
            ValueLayout.class,
            ValueLayout.OfBoolean.class,
            ValueLayout.OfByte.class,
            ValueLayout.OfChar.class,
            ValueLayout.OfShort.class,
            ValueLayout.OfInt.class,
            ValueLayout.OfFloat.class,
            ValueLayout.OfLong.class,
            ValueLayout.OfDouble.class,
            AddressLayout.class,
            PaddingLayout.class,
            GroupLayout.class,
            StructLayout.class,
            UnionLayout.class,
            Linker.class,
            Linker.Option.class,
            FunctionDescriptor.class,
            SegmentAllocator.class,
            MemorySegment.Scope.class,
            SymbolLookup.class
    };

    /**
     * Generated test names (format {@code className/signature/paramIndex/replacementIndex}, as built
     * by {@link #cases()}) that are not emitted as test cases.
     */
    static final Set<String> EXCLUDE_LIST = Set.of(
            "java.lang.foreign.MemorySegment/reinterpret(java.lang.foreign.Arena,java.util.function.Consumer)/1/0",
            "java.lang.foreign.MemorySegment/reinterpret(long,java.lang.foreign.Arena,java.util.function.Consumer)/2/0"
    );

    /** Names of the public methods of {@link Object}; any method with one of these names is skipped. */
    static final Set<String> OBJECT_METHODS = Stream.of(Object.class.getMethods())
            .map(Method::getName)
            .collect(Collectors.toSet());

    /**
     * Non-null values used for every parameter (and receiver) that is not currently being replaced,
     * keyed by the exact declared carrier type. Array carriers are handled in {@link #defaultValue(Class)}.
     */
    static final Map<Class<?>, Object> DEFAULT_VALUES = new HashMap<>();

    /** Registers {@code value} as the default for parameters declared with type {@code carrier}. */
    static <Z> void addDefaultMapping(Class<Z> carrier, Z value) {
        DEFAULT_VALUES.put(carrier, value);
    }

    static {
        addDefaultMapping(char.class, (char)0);
        addDefaultMapping(byte.class, (byte)0);
        addDefaultMapping(short.class, (short)0);
        addDefaultMapping(int.class, 0);
        addDefaultMapping(float.class, 0f);
        addDefaultMapping(long.class, 0L);
        addDefaultMapping(double.class, 0d);
        addDefaultMapping(boolean.class, true);
        addDefaultMapping(ByteOrder.class, ByteOrder.nativeOrder());
        addDefaultMapping(Thread.class, Thread.currentThread());
        addDefaultMapping(Cleaner.class, CleanerFactory.cleaner());
        addDefaultMapping(Buffer.class, ByteBuffer.wrap(new byte[10]));
        addDefaultMapping(ByteBuffer.class, ByteBuffer.wrap(new byte[10]));
        addDefaultMapping(Path.class, Path.of("nonExistent"));
        addDefaultMapping(FileChannel.MapMode.class, FileChannel.MapMode.PRIVATE);
        addDefaultMapping(UnaryOperator.class, UnaryOperator.identity());
        addDefaultMapping(String.class, "Hello!");
        addDefaultMapping(Constable.class, "Hello!");
        addDefaultMapping(Class.class, String.class);
        addDefaultMapping(Runnable.class, () -> {});
        addDefaultMapping(Object.class, new Object());
        addDefaultMapping(VarHandle.class, JAVA_INT.varHandle());
        addDefaultMapping(MethodHandle.class, MethodHandles.identity(int.class));
        addDefaultMapping(List.class, List.of());
        addDefaultMapping(Charset.class, Charset.defaultCharset());
        addDefaultMapping(Consumer.class, x -> {});
        addDefaultMapping(MethodType.class, MethodType.methodType(void.class));
        addDefaultMapping(MemoryLayout.class, ValueLayout.JAVA_INT);
        addDefaultMapping(ValueLayout.class, ValueLayout.JAVA_INT);
        addDefaultMapping(AddressLayout.class, ValueLayout.ADDRESS);
        addDefaultMapping(ValueLayout.OfByte.class, ValueLayout.JAVA_BYTE);
        addDefaultMapping(ValueLayout.OfBoolean.class, ValueLayout.JAVA_BOOLEAN);
        addDefaultMapping(ValueLayout.OfChar.class, ValueLayout.JAVA_CHAR);
        addDefaultMapping(ValueLayout.OfShort.class, ValueLayout.JAVA_SHORT);
        addDefaultMapping(ValueLayout.OfInt.class, ValueLayout.JAVA_INT);
        addDefaultMapping(ValueLayout.OfFloat.class, ValueLayout.JAVA_FLOAT);
        addDefaultMapping(ValueLayout.OfLong.class, JAVA_LONG);
        addDefaultMapping(ValueLayout.OfDouble.class, ValueLayout.JAVA_DOUBLE);
        addDefaultMapping(PaddingLayout.class, MemoryLayout.paddingLayout(4));
        addDefaultMapping(GroupLayout.class, MemoryLayout.structLayout(ValueLayout.JAVA_INT));
        addDefaultMapping(StructLayout.class, MemoryLayout.structLayout(ValueLayout.JAVA_INT));
        addDefaultMapping(UnionLayout.class, MemoryLayout.unionLayout(ValueLayout.JAVA_INT));
        addDefaultMapping(SequenceLayout.class, MemoryLayout.sequenceLayout(1, ValueLayout.JAVA_INT));
        addDefaultMapping(SymbolLookup.class, SymbolLookup.loaderLookup());
        addDefaultMapping(MemorySegment.class, MemorySegment.ofArray(new byte[10]));
        addDefaultMapping(FunctionDescriptor.class, FunctionDescriptor.ofVoid());
        addDefaultMapping(Linker.class, Linker.nativeLinker());
        addDefaultMapping(Arena.class, Arena.ofConfined());
        addDefaultMapping(MemorySegment.Scope.class, Arena.ofAuto().scope());
        addDefaultMapping(SegmentAllocator.class, SegmentAllocator.prefixAllocator(MemorySegment.ofArray(new byte[10])));
        addDefaultMapping(Supplier.class, () -> null);
        addDefaultMapping(ClassLoader.class, TestNulls.class.getClassLoader());
        addDefaultMapping(Thread.UncaughtExceptionHandler.class, (thread, ex) -> {});
    }

    /**
     * Values injected into the parameter under test, keyed by declared carrier type. Carriers
     * without an entry get only {@code null} (see {@link #replacements(Class)}).
     */
    static final Map<Class<?>, Object[]> REPLACEMENT_VALUES = new HashMap<>();

    /** Registers the values to inject for parameters declared with type {@code carrier}. */
    @SafeVarargs
    static <Z> void addReplacements(Class<Z> carrier, Z... value) {
        REPLACEMENT_VALUES.put(carrier, value);
    }

    static {
        // For collections, test both a null collection and a collection holding a single null element.
        addReplacements(Collection.class, null, Stream.of(new Object[] { null }).collect(Collectors.toList()));
        addReplacements(List.class, null, Stream.of(new Object[] { null }).collect(Collectors.toList()));
        addReplacements(Set.class, null, Stream.of(new Object[] { null }).collect(Collectors.toSet()));
    }

    /**
     * Invokes {@code meth} reflectively and checks that it throws {@link NullPointerException}.
     *
     * @param testName descriptive case name (reported by TestNG; not used in the body)
     * @param meth     the method under test
     * @param receiver the receiver, or {@code null} for static methods
     * @param args     arguments, one of which is a null (or null-containing) replacement
     */
    @Test(dataProvider = "cases")
    public void testNulls(String testName, @NoInjection Method meth, Object receiver, Object[] args) {
        try {
            meth.invoke(receiver, args);
        } catch (InvocationTargetException ex) {
            // The exception thrown by the invoked method itself is wrapped as the cause.
            Class<?> cause = ex.getCause().getClass();
            assertEquals(cause, NullPointerException.class, "got " + cause.getName() + " - expected NullPointerException");
            return;
        } catch (Throwable ex) {
            fail("Unexpected exception: " + ex);
        }
        // Kept outside the try: inside it, this AssertionError would be caught by the
        // Throwable handler above and re-reported as an "Unexpected exception".
        fail("Method invocation completed normally");
    }

    /**
     * Generates one case per (method, reference parameter, replacement value) triple.
     *
     * For each public method of each class in {@link #CLASSES}, this skips methods named like an
     * {@link Object} method. For every non-primitive parameter, it emits a case where that
     * parameter holds each of its replacements and every other parameter holds its default value.
     * Cases whose name is in {@link #EXCLUDE_LIST} are skipped.
     *
     * @return iterator over {@code {testName, method, receiver, args}} rows
     * @throws UnsupportedOperationException if a needed default value is not registered
     */
    @DataProvider(name = "cases")
    static Iterator<Object[]> cases() {
        List<Object[]> cases = new ArrayList<>();
        for (Class<?> clazz : CLASSES) {
            for (Method m : clazz.getMethods()) {
                if (OBJECT_METHODS.contains(m.getName())) continue;
                boolean isStatic = Modifier.isStatic(m.getModifiers());
                // getParameterTypes() clones the array on every call, so fetch it once per method.
                Class<?>[] paramTypes = m.getParameterTypes();
                for (int i = 0; i < paramTypes.length; i++) {
                    // Only reference parameters can receive a null.
                    if (paramTypes[i].isPrimitive()) continue;
                    Object[] replacements = replacements(paramTypes[i]);
                    for (int r = 0; r < replacements.length; r++) {
                        String testName = clazz.getName() + "/" + shortSig(m) + "/" + i + "/" + r;
                        if (EXCLUDE_LIST.contains(testName)) continue;
                        Object[] args = new Object[paramTypes.length];
                        for (int j = 0; j < args.length; j++) {
                            args[j] = defaultValue(paramTypes[j]);
                        }
                        args[i] = replacements[r];
                        Object receiver = isStatic ? null : defaultValue(clazz);
                        cases.add(new Object[] { testName, m, receiver, args });
                    }
                }
            }
        }
        return cases.iterator();
    }

    /** Returns {@code name(fqType1,fqType2,...)} for {@code m}, used to build test names. */
    static String shortSig(Method m) {
        StringJoiner sj = new StringJoiner(",", m.getName() + "(", ")");
        for (Class<?> parameterType : m.getParameterTypes()) {
            sj.add(parameterType.getTypeName());
        }
        return sj.toString();
    }

    /**
     * Returns the non-null value to pass for a parameter of type {@code carrier}: an empty array
     * for array types, otherwise the {@link #DEFAULT_VALUES} entry for exactly that class.
     *
     * @throws UnsupportedOperationException if {@code carrier} has no registered default
     */
    static Object defaultValue(Class<?> carrier) {
        if (carrier.isArray()) {
            return Array.newInstance(carrier.componentType(), 0);
        }
        Object value = DEFAULT_VALUES.get(carrier);
        if (value == null) {
            throw new UnsupportedOperationException(carrier.getName());
        }
        return value;
    }

    /**
     * Returns the values to inject into a parameter of type {@code carrier}.
     *
     * Reference arrays get {@code null} and a one-element array whose element is null. Other
     * carriers get their {@link #REPLACEMENT_VALUES} entry, or just {@code null} if there is none.
     */
    static Object[] replacements(Class<?> carrier) {
        if (carrier.isArray() && !carrier.componentType().isPrimitive()) {
            // A freshly created reference array is already filled with nulls.
            Object arr = Array.newInstance(carrier.componentType(), 1);
            return new Object[] { null, arr };
        }
        return REPLACEMENT_VALUES.getOrDefault(carrier, new Object[] { null });
    }
}