8247696: Incorrect tail computation for large segments in AbstractMemorySegmentImpl::mismatch

Reviewed-by: psandoz, mcimadamore
This commit is contained in:
Chris Hegarty 2020-06-23 10:09:26 +01:00
parent 6469685285
commit 7f69acc778
4 changed files with 99 additions and 18 deletions

View File

@ -163,19 +163,24 @@ public class ArraysSupport {
/**
* Mismatch over long lengths.
*/
public static long vectorizedMismatchLarge(Object a, long aOffset,
Object b, long bOffset,
long length,
int log2ArrayIndexScale) {
public static long vectorizedMismatchLargeForBytes(Object a, long aOffset,
Object b, long bOffset,
long length) {
long off = 0;
long remaining = length;
int i ;
while (remaining > 7) {
int size = (int) Math.min(Integer.MAX_VALUE, remaining);
int i, size;
boolean lastSubRange = false;
while (remaining > 7 && !lastSubRange) {
if (remaining > Integer.MAX_VALUE) {
size = Integer.MAX_VALUE;
} else {
size = (int) remaining;
lastSubRange = true;
}
i = vectorizedMismatch(
a, aOffset + off,
b, bOffset + off,
size, log2ArrayIndexScale);
size, LOG2_ARRAY_BYTE_INDEX_SCALE);
if (i >= 0)
return off + i;
@ -183,7 +188,7 @@ public class ArraysSupport {
off += i;
remaining -= i;
}
return ~off;
return ~remaining;
}
// Booleans

View File

@ -149,14 +149,19 @@ public abstract class AbstractMemorySegmentImpl implements MemorySegment, Memory
long i = 0;
if (length > 7) {
i = ArraysSupport.vectorizedMismatchLarge(
if ((byte) BYTE_HANDLE.get(this.baseAddress(), 0) != (byte) BYTE_HANDLE.get(that.baseAddress(), 0)) {
return 0;
}
i = ArraysSupport.vectorizedMismatchLargeForBytes(
this.base(), this.min(),
that.base(), that.min(),
length, ArraysSupport.LOG2_ARRAY_BYTE_INDEX_SCALE);
length);
if (i >= 0) {
return i;
}
i = length - ~i;
long remaining = ~i;
assert remaining < 8 : "remaining greater than 7: " + remaining;
i = length - remaining;
}
MemoryAddress thisAddress = this.baseAddress();
MemoryAddress thatAddress = that.baseAddress();

View File

@ -117,12 +117,28 @@ public class TestMismatch {
assertEquals(s1.mismatch(s2), -1);
assertEquals(s2.mismatch(s1), -1);
for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
BYTE_HANDLE.set(s2.baseAddress().addOffset(i), (byte) 0xFF);
long expectedMismatchOffset = i;
assertEquals(s1.mismatch(s2), expectedMismatchOffset);
assertEquals(s2.mismatch(s1), expectedMismatchOffset);
}
testLargeAcrossMaxBoundary(s1, s2);
testLargeMismatchAcrossMaxBoundary(s1, s2);
}
}
private void testLargeAcrossMaxBoundary(MemorySegment s1, MemorySegment s2) {
for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
var s3 = s1.asSlice(0, i);
var s4 = s2.asSlice(0, i);
assertEquals(s3.mismatch(s3), -1);
assertEquals(s3.mismatch(s4), -1);
assertEquals(s4.mismatch(s3), -1);
}
}
private void testLargeMismatchAcrossMaxBoundary(MemorySegment s1, MemorySegment s2) {
for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
BYTE_HANDLE.set(s2.baseAddress().addOffset(i), (byte) 0xFF);
long expectedMismatchOffset = i;
assertEquals(s1.mismatch(s2), expectedMismatchOffset);
assertEquals(s2.mismatch(s1), expectedMismatchOffset);
}
}

View File

@ -35,6 +35,7 @@ import org.openjdk.jmh.annotations.Warmup;
import sun.misc.Unsafe;
import jdk.incubator.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import static jdk.incubator.foreign.MemoryLayouts.JAVA_INT;
@ -60,6 +61,36 @@ public class BulkOps {
static final MemorySegment bytesSegment = MemorySegment.ofArray(bytes);
static final int UNSAFE_INT_OFFSET = unsafe.arrayBaseOffset(int[].class);
// large(ish) segments/buffers with same content, 0, for mismatch, non-multiple-of-8 sized
static final int SIZE_WITH_TAIL = (1024 * 1024) + 7;
static final MemorySegment mismatchSegmentLarge1 = MemorySegment.allocateNative(SIZE_WITH_TAIL);
static final MemorySegment mismatchSegmentLarge2 = MemorySegment.allocateNative(SIZE_WITH_TAIL);
static final ByteBuffer mismatchBufferLarge1 = ByteBuffer.allocateDirect(SIZE_WITH_TAIL);
static final ByteBuffer mismatchBufferLarge2 = ByteBuffer.allocateDirect(SIZE_WITH_TAIL);
// mismatch at first byte
static final MemorySegment mismatchSegmentSmall1 = MemorySegment.allocateNative(7);
static final MemorySegment mismatchSegmentSmall2 = MemorySegment.allocateNative(7);
static final ByteBuffer mismatchBufferSmall1 = ByteBuffer.allocateDirect(7);
static final ByteBuffer mismatchBufferSmall2 = ByteBuffer.allocateDirect(7);
static {
mismatchSegmentSmall1.fill((byte) 0xFF);
mismatchBufferSmall1.put((byte) 0xFF).clear();
// verify expected mismatch indices
long si = mismatchSegmentLarge1.mismatch(mismatchSegmentLarge2);
if (si != -1)
throw new AssertionError("Unexpected mismatch index:" + si);
int bi = mismatchBufferLarge1.mismatch(mismatchBufferLarge2);
if (bi != -1)
throw new AssertionError("Unexpected mismatch index:" + bi);
si = mismatchSegmentSmall1.mismatch(mismatchSegmentSmall2);
if (si != 0)
throw new AssertionError("Unexpected mismatch index:" + si);
bi = mismatchBufferSmall1.mismatch(mismatchBufferSmall2);
if (bi != 0)
throw new AssertionError("Unexpected mismatch index:" + bi);
}
static {
for (int i = 0 ; i < bytes.length ; i++) {
bytes[i] = i;
@ -89,4 +120,28 @@ public class BulkOps {
public void segment_copy() {
segment.copyFrom(bytesSegment);
}
@Benchmark
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public long mismatch_large_segment() {
return mismatchSegmentLarge1.mismatch(mismatchSegmentLarge2);
}
@Benchmark
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public int mismatch_large_bytebuffer() {
return mismatchBufferLarge1.mismatch(mismatchBufferLarge2);
}
@Benchmark
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public long mismatch_small_segment() {
return mismatchSegmentSmall1.mismatch(mismatchSegmentSmall2);
}
@Benchmark
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public int mismatch_small_bytebuffer() {
return mismatchBufferSmall1.mismatch(mismatchBufferSmall2);
}
}