/* * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*
 */

/*
 * @test
 * @run testng/othervm -Dforeign.restricted=permit TestSharedAccess
 */

import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemoryLayouts;
import jdk.incubator.foreign.MemorySegment;
import jdk.incubator.foreign.SequenceLayout;
import org.testng.annotations.Test;

import java.lang.invoke.VarHandle;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Spliterator;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

/**
 * Exercises multi-threaded access to memory segments: ownership hand-off of a
 * confined segment, parallel access through spliterators, access through a
 * restricted native view, and the failures expected when closing or handing off
 * a segment while a spliterator is still using it.
 */
public class TestSharedAccess {

    static final VarHandle intHandle = MemoryLayouts.JAVA_INT.varHandle(int.class);

    /**
     * Hands a confined 4-byte native segment to 1000 threads, one at a time. The
     * test thread transfers ownership to a worker and starts it. The worker reads
     * the stored value and transfers ownership back to the test thread before the
     * next worker is started.
     */
    @Test
    public void testConfined() throws Throwable {
        Thread owner = Thread.currentThread();
        MemorySegment s = MemorySegment.allocateNative(4);
        AtomicReference<MemorySegment> confined = new AtomicReference<>(s);
        setInt(s.baseAddress(), 42);
        assertEquals(getInt(s.baseAddress()), 42);

        AtomicReference<Throwable> failure = new AtomicReference<>();
        List<Thread> threads = new ArrayList<>();
        for (int i = 0 ; i < 1000 ; i++) {
            threads.add(newRecordingThread(() -> {
                assertEquals(getInt(confined.get().baseAddress()), 42);
                confined.set(confined.get().withOwnerThread(owner));
            }, failure));
        }
        for (Thread t : threads) {
            // Ownership must move to t before t touches the segment.
            confined.set(confined.get().withOwnerThread(t));
            t.start();
            joinThread(t);
            // Report a worker failure directly. Otherwise it would only show up as a
            // secondary ownership error on the next hand-off.
            rethrowIfFailed(failure);
        }
        confined.get().close();
    }

    /**
     * Fills a 1024-int native segment with 42. It then splits a spliterator over
     * it until a full pass produces no further split, and advances each resulting
     * spliterator once on its own thread. The final count of 1024 accesses expects
     * every fully-split spliterator to yield exactly one element.
     */
    @Test
    public void testShared() throws Throwable {
        SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT);
        try (MemorySegment s = MemorySegment.allocateNative(layout)) {
            for (int i = 0 ; i < layout.elementCount().getAsLong() ; i++) {
                setInt(s.baseAddress().addOffset(i * 4), 42);
            }

            List<Spliterator<MemorySegment>> spliterators = new ArrayList<>();
            spliterators.add(MemorySegment.spliterator(s, layout));
            // Keep splitting every spliterator until a whole pass yields no new split.
            boolean progress;
            do {
                progress = false;
                List<Spliterator<MemorySegment>> newSpliterators = new ArrayList<>();
                for (Spliterator<MemorySegment> spliterator : spliterators) {
                    Spliterator<MemorySegment> sub = spliterator.trySplit();
                    if (sub != null) {
                        progress = true;
                        newSpliterators.add(sub);
                    }
                }
                spliterators.addAll(newSpliterators);
            } while (progress);

            AtomicInteger accessCount = new AtomicInteger();
            AtomicReference<Throwable> failure = new AtomicReference<>();
            List<Thread> threads = new ArrayList<>();
            for (Spliterator<MemorySegment> spliterator : spliterators) {
                threads.add(newRecordingThread(() -> spliterator.tryAdvance(local -> {
                    assertEquals(getInt(local.baseAddress()), 42);
                    accessCount.incrementAndGet();
                }), failure));
            }
            threads.forEach(Thread::start);
            threads.forEach(TestSharedAccess::joinThread);
            rethrowIfFailed(failure);
            assertEquals(accessCount.get(), 1024);
        }
    }

    /**
     * Creates a second view of a native segment with
     * {@code MemorySegment.ofNativeRestricted}. The trailing arguments are null
     * (owner thread, cleanup action and attachment in this API version — confirm).
     * The view is then read from 1000 threads other than the one that allocated it.
     */
    @Test
    public void testSharedUnsafe() throws Throwable {
        try (MemorySegment s = MemorySegment.allocateNative(4)) {
            setInt(s.baseAddress(), 42);
            assertEquals(getInt(s.baseAddress()), 42);
            MemorySegment sharedSegment = MemorySegment.ofNativeRestricted(
                    s.baseAddress(), s.byteSize(), null, null, null);

            AtomicReference<Throwable> failure = new AtomicReference<>();
            List<Thread> threads = new ArrayList<>();
            for (int i = 0 ; i < 1000 ; i++) {
                threads.add(newRecordingThread(
                        () -> assertEquals(getInt(sharedSegment.baseAddress()), 42), failure));
            }
            threads.forEach(Thread::start);
            threads.forEach(TestSharedAccess::joinThread);
            rethrowIfFailed(failure);
        }
    }

    /** Closing a segment while another thread is inside a spliterator action must fail. */
    @Test(expectedExceptions=IllegalStateException.class)
    public void testBadCloseWithPendingAcquire() {
        withAcquired(MemorySegment::close);
    }

    /**
     * Closing a segment re-derived from the original's byte buffer must also fail while
     * the original is still in use by a spliterator action.
     */
    @Test(expectedExceptions=IllegalStateException.class)
    public void testBadCloseWithPendingAcquireBuffer() {
        withAcquired(segment -> {
            segment = MemorySegment.ofByteBuffer(segment.asByteBuffer()); // original segment is lost
            segment.close(); // this should still fail
        });
    }

    /** Handing off a segment while another thread is inside a spliterator action must fail. */
    @Test(expectedExceptions=IllegalStateException.class)
    public void testBadHandoffWithPendingAcquire() {
        withAcquired(segment -> segment.withOwnerThread(new Thread()));
    }

    /**
     * Handing off a segment re-derived from the original's byte buffer must also fail
     * while the original is still in use by a spliterator action.
     */
    @Test(expectedExceptions=IllegalStateException.class)
    public void testBadHandoffWithPendingAcquireBuffer() {
        withAcquired(segment -> {
            segment = MemorySegment.ofByteBuffer(segment.asByteBuffer()); // original segment is lost
            segment.withOwnerThread(new Thread()); // this should still fail
        });
    }

    /** Handing a segment off to its current owner thread is rejected. */
    @Test(expectedExceptions=IllegalArgumentException.class)
    public void testBadHandoffSameThread() {
        MemorySegment.ofArray(new int[4]).withOwnerThread(Thread.currentThread());
    }

    /** Handing a segment off to a null thread is rejected. */
    @Test(expectedExceptions=NullPointerException.class)
    public void testBadHandoffNullThread() {
        MemorySegment.ofArray(new int[4]).withOwnerThread(null);
    }

    /**
     * Runs {@code acquiredAction} on the test thread against a 16-byte native segment.
     * While the action runs, a background thread is blocked inside a
     * {@code tryAdvance} action of a spliterator over that same segment. The
     * background thread is released in {@code finally}, whether or not the action
     * throws. The segment is not closed here.
     *
     * @param acquiredAction operation to attempt while the spliterator action is pending
     */
    private void withAcquired(Consumer<MemorySegment> acquiredAction) {
        CountDownLatch holder = new CountDownLatch(1);
        MemorySegment segment = MemorySegment.allocateNative(16);
        Spliterator<MemorySegment> spliterator = MemorySegment.spliterator(segment,
                MemoryLayout.ofSequence(16, MemoryLayouts.JAVA_BYTE));
        CountDownLatch acquired = new CountDownLatch(1);
        Runnable r = () -> spliterator.tryAdvance(s -> {
            try {
                acquired.countDown();
                // Stay inside the spliterator action until the test thread is done.
                holder.await();
            } catch (InterruptedException ex) {
                throw new AssertionError(ex);
            }
        });
        new Thread(r).start();
        try {
            acquired.await();
            acquiredAction.accept(segment);
        } catch (InterruptedException ex) {
            throw new AssertionError(ex);
        } finally {
            holder.countDown();
        }
    }

    /**
     * A segment obtained via {@code ofByteBuffer} from {@code s1}'s byte buffer on a pool
     * thread is written from that pool thread. The write is expected to fail with an
     * {@link IllegalStateException} mentioning "owning thread". Any other outcome
     * (including no exception) surfaces through {@code r.get()}.
     */
    @Test
    public void testOutsideConfinementThread() throws Throwable {
        CountDownLatch a = new CountDownLatch(1);
        CountDownLatch b = new CountDownLatch(1);
        CompletableFuture<Void> r;
        try (MemorySegment s1 = MemorySegment.allocateNative(MemoryLayout.ofSequence(2, MemoryLayouts.JAVA_INT))) {
            r = CompletableFuture.runAsync(() -> {
                try {
                    ByteBuffer bb = s1.asByteBuffer();
                    MemorySegment s2 = MemorySegment.ofByteBuffer(bb);
                    a.countDown();

                    try {
                        b.await();
                    } catch (InterruptedException e) {
                        throw new AssertionError(e);
                    }

                    MemoryAddress base = s2.baseAddress();
                    setInt(base.addOffset(4), -42);
                    fail();
                } catch (IllegalStateException ex) {
                    assertTrue(ex.getMessage().contains("owning thread"));
                } finally {
                    // Release the test thread even if this task failed before reaching the
                    // countDown above, so a.await() cannot hang; an extra countDown is a no-op.
                    a.countDown();
                }
            });

            a.await();
            MemoryAddress base = s1.baseAddress();
            setInt(base.addOffset(4), 42);
        }

        b.countDown();
        r.get();
    }

    /**
     * Creates, without starting, a thread that runs {@code task}. The first throwable
     * escaping any thread created with the same {@code failure} holder is recorded there.
     * Without this, an assertion failing on a worker thread would not by itself fail
     * the test.
     */
    private static Thread newRecordingThread(Runnable task, AtomicReference<Throwable> failure) {
        return new Thread(() -> {
            try {
                task.run();
            } catch (Throwable ex) {
                failure.compareAndSet(null, ex);
            }
        });
    }

    /** Rethrows the first recorded worker failure, if any, wrapped in an {@link AssertionError}. */
    private static void rethrowIfFailed(AtomicReference<Throwable> failure) {
        Throwable ex = failure.get();
        if (ex != null) {
            throw new AssertionError("worker thread failed", ex);
        }
    }

    /**
     * Waits for {@code t} to terminate. On interruption, it restores the interrupt flag
     * and throws an unchecked {@link IllegalStateException}.
     */
    private static void joinThread(Thread t) {
        try {
            t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    /** Reads an int at {@code address} with volatile semantics through {@link #intHandle}. */
    static int getInt(MemoryAddress address) {
        return (int)intHandle.getVolatile(address);
    }

    /** Writes {@code value} at {@code address} with volatile semantics through {@link #intHandle}. */
    static void setInt(MemoryAddress address, int value) {
        intHandle.setVolatile(address, value);
    }
}