jdk-24/test/jdk/java/foreign/TestReshape.java

193 lines
7.8 KiB
Java
Raw Normal View History

/*
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @run testng TestReshape
*/
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.SequenceLayout;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.LongStream;
import jdk.incubator.foreign.ValueLayout;
import org.testng.annotations.*;
import static org.testng.Assert.*;
public class TestReshape {
@Test(dataProvider = "shapes")
public void testReshape(MemoryLayout layout, long[] expectedShape) {
long flattenedSize = LongStream.of(expectedShape).reduce(1L, Math::multiplyExact);
SequenceLayout seq_flattened = MemoryLayout.sequenceLayout(flattenedSize, layout);
assertDimensions(seq_flattened, flattenedSize);
for (long[] shape : new Shape(expectedShape)) {
SequenceLayout seq_shaped = seq_flattened.reshape(shape);
assertDimensions(seq_shaped, expectedShape);
assertEquals(seq_shaped.flatten(), seq_flattened);
}
}
@Test(expectedExceptions = IllegalArgumentException.class)
public void testInvalidReshape() {
SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
seq.reshape(3, 2);
}
@Test(expectedExceptions = IllegalArgumentException.class)
public void testBadReshapeInference() {
SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
seq.reshape(-1, -1);
}
@Test(expectedExceptions = IllegalArgumentException.class)
public void testBadReshapeParameterZero() {
SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
seq.reshape(0, 4);
}
@Test(expectedExceptions = IllegalArgumentException.class)
public void testBadReshapeParameterNegative() {
SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
seq.reshape(-2, 2);
}
@Test(expectedExceptions = UnsupportedOperationException.class)
public void testReshapeOnUnboundSequence() {
SequenceLayout seq = MemoryLayout.sequenceLayout(ValueLayout.JAVA_INT);
seq.reshape(3, 2);
}
@Test(expectedExceptions = UnsupportedOperationException.class)
public void testFlattenOnUnboundSequence() {
SequenceLayout seq = MemoryLayout.sequenceLayout(ValueLayout.JAVA_INT);
seq.flatten();
}
@Test(expectedExceptions = UnsupportedOperationException.class)
public void testFlattenOnUnboundNestedSequence() {
SequenceLayout seq = MemoryLayout.sequenceLayout(4, MemoryLayout.sequenceLayout(ValueLayout.JAVA_INT));
seq.flatten();
}
static void assertDimensions(SequenceLayout layout, long... dims) {
SequenceLayout prev = null;
for (int i = 0 ; i < dims.length ; i++) {
if (prev != null) {
layout = (SequenceLayout)prev.elementLayout();
}
assertEquals(layout.elementCount().getAsLong(), dims[i]);
prev = layout;
}
}
static class Shape implements Iterable<long[]> {
long[] shape;
Shape(long... shape) {
this.shape = shape;
}
public Iterator<long[]> iterator() {
List<long[]> shapes = new ArrayList<>();
shapes.add(shape);
for (int i = 0 ; i < shape.length ; i++) {
long[] inferredShape = shape.clone();
inferredShape[i] = -1;
shapes.add(inferredShape);
}
return shapes.iterator();
}
}
static MemoryLayout POINT = MemoryLayout.structLayout(
ValueLayout.JAVA_INT,
ValueLayout.JAVA_INT
);
@DataProvider(name = "shapes")
Object[][] shapes() {
return new Object[][] {
{ ValueLayout.JAVA_BYTE, new long[] { 256 } },
{ ValueLayout.JAVA_BYTE, new long[] { 16, 16 } },
{ ValueLayout.JAVA_BYTE, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_BYTE, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_BYTE, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_BYTE, new long[] { 8, 16, 2 } },
{ ValueLayout.JAVA_SHORT, new long[] { 256 } },
{ ValueLayout.JAVA_SHORT, new long[] { 16, 16 } },
{ ValueLayout.JAVA_SHORT, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_SHORT, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_SHORT, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_SHORT, new long[] { 8, 16, 2 } },
{ ValueLayout.JAVA_CHAR, new long[] { 256 } },
{ ValueLayout.JAVA_CHAR, new long[] { 16, 16 } },
{ ValueLayout.JAVA_CHAR, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_CHAR, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_CHAR, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_CHAR, new long[] { 8, 16, 2 } },
{ ValueLayout.JAVA_INT, new long[] { 256 } },
{ ValueLayout.JAVA_INT, new long[] { 16, 16 } },
{ ValueLayout.JAVA_INT, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_INT, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_INT, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_INT, new long[] { 8, 16, 2 } },
{ ValueLayout.JAVA_LONG, new long[] { 256 } },
{ ValueLayout.JAVA_LONG, new long[] { 16, 16 } },
{ ValueLayout.JAVA_LONG, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_LONG, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_LONG, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_LONG, new long[] { 8, 16, 2 } },
{ ValueLayout.JAVA_FLOAT, new long[] { 256 } },
{ ValueLayout.JAVA_FLOAT, new long[] { 16, 16 } },
{ ValueLayout.JAVA_FLOAT, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_FLOAT, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_FLOAT, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_FLOAT, new long[] { 8, 16, 2 } },
{ ValueLayout.JAVA_DOUBLE, new long[] { 256 } },
{ ValueLayout.JAVA_DOUBLE, new long[] { 16, 16 } },
{ ValueLayout.JAVA_DOUBLE, new long[] { 4, 4, 4, 4 } },
{ ValueLayout.JAVA_DOUBLE, new long[] { 2, 8, 16 } },
{ ValueLayout.JAVA_DOUBLE, new long[] { 16, 8, 2 } },
{ ValueLayout.JAVA_DOUBLE, new long[] { 8, 16, 2 } },
{ POINT, new long[] { 256 } },
{ POINT, new long[] { 16, 16 } },
{ POINT, new long[] { 4, 4, 4, 4 } },
{ POINT, new long[] { 2, 8, 16 } },
{ POINT, new long[] { 16, 8, 2 } },
{ POINT, new long[] { 8, 16, 2 } },
};
}
}