diff --git a/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java b/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java index b90aef6206e..b007db93d88 100644 --- a/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java +++ b/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java @@ -256,8 +256,12 @@ public class StampedLock implements java.io.Serializable { * method validate()) requires stricter ordering rules than apply * to normal volatile reads (of "state"). To force orderings of * reads before a validation and the validation itself in those - * cases where this is not already forced, we use - * VarHandle.acquireFence. + * cases where this is not already forced, we use acquireFence. + * Unlike in that paper, we allow writers to use plain writes. + * One would not expect reorderings of such writes with the lock + * acquisition CAS because there is a "control dependency", but it + * is theoretically possible, so we additionally add a + * storeStoreFence after lock acquisition CAS. * * The memory layout keeps lock state and queue pointers together * (normally on the same cache line). This usually works well for @@ -355,6 +359,20 @@ public class StampedLock implements java.io.Serializable { state = ORIGIN; } + private boolean casState(long expectedValue, long newValue) { + return STATE.compareAndSet(this, expectedValue, newValue); + } + + private long tryWriteLock(long s) { + // assert (s & ABITS) == 0L; + long next; + if (casState(s, next = s | WBIT)) { + VarHandle.storeStoreFence(); + return next; + } + return 0L; + } + /** * Exclusively acquires the lock, blocking if necessary * until available. @@ -363,10 +381,8 @@ public class StampedLock implements java.io.Serializable { */ @ReservedStackAccess public long writeLock() { - long s, next; // bypass acquireWrite in fully unlocked case only - return ((((s = state) & ABITS) == 0L && - STATE.compareAndSet(this, s, next = s + WBIT)) ? - next : acquireWrite(false, 0L)); + long next; + return ((next = tryWriteLock()) != 0L) ? next : acquireWrite(false, 0L); } /** @@ -377,10 +393,8 @@ public class StampedLock implements java.io.Serializable { */ @ReservedStackAccess public long tryWriteLock() { - long s, next; - return ((((s = state) & ABITS) == 0L && - STATE.compareAndSet(this, s, next = s + WBIT)) ? - next : 0L); + long s; + return (((s = state) & ABITS) == 0L) ? tryWriteLock(s) : 0L; } /** @@ -440,10 +454,13 @@ public class StampedLock implements java.io.Serializable { */ @ReservedStackAccess public long readLock() { - long s = state, next; // bypass acquireRead on common uncontended case - return ((whead == wtail && (s & ABITS) < RFULL && - STATE.compareAndSet(this, s, next = s + RUNIT)) ? - next : acquireRead(false, 0L)); + long s, next; + // bypass acquireRead on common uncontended case + return (whead == wtail + && ((s = state) & ABITS) < RFULL + && casState(s, next = s + RUNIT)) + ? next + : acquireRead(false, 0L); } /** @@ -457,7 +474,7 @@ public class StampedLock implements java.io.Serializable { long s, m, next; while ((m = (s = state) & ABITS) != WBIT) { if (m < RFULL) { - if (STATE.compareAndSet(this, s, next = s + RUNIT)) + if (casState(s, next = s + RUNIT)) return next; } else if ((next = tryIncReaderOverflow(s)) != 0L) @@ -487,7 +504,7 @@ public class StampedLock implements java.io.Serializable { if (!Thread.interrupted()) { if ((m = (s = state) & ABITS) != WBIT) { if (m < RFULL) { - if (STATE.compareAndSet(this, s, next = s + RUNIT)) + if (casState(s, next = s + RUNIT)) return next; } else if ((next = tryIncReaderOverflow(s)) != 0L) @@ -514,10 +531,15 @@ public class StampedLock implements java.io.Serializable { * before acquiring the lock */ @ReservedStackAccess - public long readLockInterruptibly() throws InterruptedException { - long next; - if (!Thread.interrupted() && - (next = acquireRead(true, 0L)) != INTERRUPTED) + public long readLockInterruptibly() throws InterruptedException { + long s, next; + if (!Thread.interrupted() + // bypass acquireRead on common uncontended case + && ((whead == wtail + && ((s = state) & ABITS) < RFULL + && casState(s, next = s + RUNIT)) + || + (next = acquireRead(true, 0L)) != INTERRUPTED)) return next; throw new InterruptedException(); } @@ -598,7 +620,7 @@ public class StampedLock implements java.io.Serializable { && (stamp & RBITS) > 0L && ((m = s & RBITS) > 0L)) { if (m < RFULL) { - if (STATE.compareAndSet(this, s, s - RUNIT)) { + if (casState(s, s - RUNIT)) { if (m == RUNIT && (h = whead) != null && h.status != 0) release(h); return; @@ -620,7 +642,7 @@ public class StampedLock implements java.io.Serializable { */ @ReservedStackAccess public void unlock(long stamp) { - if ((stamp & WBIT) != 0) + if ((stamp & WBIT) != 0L) unlockWrite(stamp); else unlockRead(stamp); @@ -644,7 +666,7 @@ public class StampedLock implements java.io.Serializable { if ((m = s & ABITS) == 0L) { if (a != 0L) break; - if (STATE.compareAndSet(this, s, next = s + WBIT)) + if ((next = tryWriteLock(s)) != 0L) return next; } else if (m == WBIT) { @@ -653,8 +675,10 @@ public class StampedLock implements java.io.Serializable { return stamp; } else if (m == RUNIT && a != 0L) { - if (STATE.compareAndSet(this, s, next = s - RUNIT + WBIT)) + if (casState(s, next = s - RUNIT + WBIT)) { + VarHandle.storeStoreFence(); return next; + } } else break; @@ -688,7 +712,7 @@ public class StampedLock implements java.io.Serializable { else if (a == 0L) { // optimistic read stamp if ((s & ABITS) < RFULL) { - if (STATE.compareAndSet(this, s, next = s + RUNIT)) + if (casState(s, next = s + RUNIT)) return next; } else if ((next = tryIncReaderOverflow(s)) != 0L) @@ -730,7 +754,7 @@ public class StampedLock implements java.io.Serializable { else if ((m = s & ABITS) == 0L) // invalid read stamp break; else if (m < RFULL) { - if (STATE.compareAndSet(this, s, next = s - RUNIT)) { + if (casState(s, next = s - RUNIT)) { if (m == RUNIT && (h = whead) != null && h.status != 0) release(h); return next & SBITS; @@ -771,7 +795,7 @@ public class StampedLock implements java.io.Serializable { long s, m; WNode h; while ((m = (s = state) & ABITS) != 0L && m < WBIT) { if (m < RFULL) { - if (STATE.compareAndSet(this, s, s - RUNIT)) { + if (casState(s, s - RUNIT)) { if (m == RUNIT && (h = whead) != null && h.status != 0) release(h); return true; @@ -940,7 +964,7 @@ public class StampedLock implements java.io.Serializable { long s, m; WNode h; while ((m = (s = state) & RBITS) > 0L) { if (m < RFULL) { - if (STATE.compareAndSet(this, s, s - RUNIT)) { + if (casState(s, s - RUNIT)) { if (m == RUNIT && (h = whead) != null && h.status != 0) release(h); return; @@ -971,7 +995,7 @@ public class StampedLock implements java.io.Serializable { private long tryIncReaderOverflow(long s) { // assert (s & ABITS) >= RFULL; if ((s & ABITS) == RFULL) { - if (STATE.compareAndSet(this, s, s | RBITS)) { + if (casState(s, s | RBITS)) { ++readerOverflow; STATE.setVolatile(this, s); return s; @@ -993,7 +1017,7 @@ public class StampedLock implements java.io.Serializable { private long tryDecReaderOverflow(long s) { // assert (s & ABITS) >= RFULL; if ((s & ABITS) == RFULL) { - if (STATE.compareAndSet(this, s, s | RBITS)) { + if (casState(s, s | RBITS)) { int r; long next; if ((r = readerOverflow) > 0) { readerOverflow = r - 1; @@ -1047,7 +1071,7 @@ public class StampedLock implements java.io.Serializable { for (int spins = -1;;) { // spin while enqueuing long m, s, ns; if ((m = (s = state) & ABITS) == 0L) { - if (STATE.compareAndSet(this, s, ns = s + WBIT)) + if ((ns = tryWriteLock(s)) != 0L) return ns; } else if (spins < 0) @@ -1082,7 +1106,7 @@ public class StampedLock implements java.io.Serializable { for (int k = spins; k > 0; --k) { // spin at head long s, ns; if (((s = state) & ABITS) == 0L) { - if (STATE.compareAndSet(this, s, ns = s + WBIT)) { + if ((ns = tryWriteLock(s)) != 0L) { whead = node; node.prev = null; if (wasInterrupted) @@ -1158,7 +1182,7 @@ public class StampedLock implements java.io.Serializable { if ((h = whead) == (p = wtail)) { for (long m, s, ns;;) { if ((m = (s = state) & ABITS) < RFULL ? - STATE.compareAndSet(this, s, ns = s + RUNIT) : + casState(s, ns = s + RUNIT) : (m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) { if (wasInterrupted) Thread.currentThread().interrupt(); @@ -1208,7 +1232,7 @@ public class StampedLock implements java.io.Serializable { long m, s, ns; do { if ((m = (s = state) & ABITS) < RFULL ? - STATE.compareAndSet(this, s, ns = s + RUNIT) : + casState(s, ns = s + RUNIT) : (m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) { if (wasInterrupted) @@ -1260,7 +1284,7 @@ public class StampedLock implements java.io.Serializable { for (int k = spins;;) { // spin at head long m, s, ns; if ((m = (s = state) & ABITS) < RFULL ? - STATE.compareAndSet(this, s, ns = s + RUNIT) : + casState(s, ns = s + RUNIT) : (m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) { WNode c; Thread w; whead = node; diff --git a/jdk/test/java/util/concurrent/tck/StampedLockTest.java b/jdk/test/java/util/concurrent/tck/StampedLockTest.java index cf4f271768b..d71d6546aca 100644 --- a/jdk/test/java/util/concurrent/tck/StampedLockTest.java +++ b/jdk/test/java/util/concurrent/tck/StampedLockTest.java @@ -32,11 +32,18 @@ * http://creativecommons.org/publicdomain/zero/1.0/ */ +import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.StampedLock; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; import junit.framework.Test; import junit.framework.TestSuite; @@ -1078,4 +1085,121 @@ public class StampedLockTest extends JSR166TestCase { assertThrows(IllegalMonitorStateException.class, actions); } + static long writeLockInterruptiblyUninterrupted(StampedLock sl) { + try { return sl.writeLockInterruptibly(); } + catch (InterruptedException ex) { throw new AssertionError(ex); } + } + + static long tryWriteLockUninterrupted(StampedLock sl, long time, TimeUnit unit) { + try { return sl.tryWriteLock(time, unit); } + catch (InterruptedException ex) { throw new AssertionError(ex); } + } + + static long readLockInterruptiblyUninterrupted(StampedLock sl) { + try { return sl.readLockInterruptibly(); } + catch (InterruptedException ex) { throw new AssertionError(ex); } + } + + static long tryReadLockUninterrupted(StampedLock sl, long time, TimeUnit unit) { + try { return sl.tryReadLock(time, unit); } + catch (InterruptedException ex) { throw new AssertionError(ex); } + } + + /** + * Invalid write stamps result in IllegalMonitorStateException + */ + public void testInvalidWriteStampsThrowIllegalMonitorStateException() { + List> writeLockers = new ArrayList<>(); + writeLockers.add((sl) -> sl.writeLock()); + writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl)); + writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS)); + writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS)); + + List> writeUnlockers = new ArrayList<>(); + writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp)); + writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite())); + writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock()); + writeUnlockers.add((sl, stamp) -> sl.unlock(stamp)); + + List> mutaters = new ArrayList<>(); + mutaters.add((sl) -> {}); + mutaters.add((sl) -> sl.readLock()); + for (Function writeLocker : writeLockers) + mutaters.add((sl) -> writeLocker.apply(sl)); + + for (Function writeLocker : writeLockers) + for (BiConsumer writeUnlocker : writeUnlockers) + for (Consumer mutater : mutaters) { + final StampedLock sl = new StampedLock(); + final long stamp = writeLocker.apply(sl); + assertTrue(stamp != 0L); + assertThrows(IllegalMonitorStateException.class, + () -> sl.unlockRead(stamp)); + writeUnlocker.accept(sl, stamp); + mutater.accept(sl); + assertThrows(IllegalMonitorStateException.class, + () -> sl.unlock(stamp), + () -> sl.unlockRead(stamp), + () -> sl.unlockWrite(stamp)); + } + } + + /** + * Invalid read stamps result in IllegalMonitorStateException + */ + public void testInvalidReadStampsThrowIllegalMonitorStateException() { + List> readLockers = new ArrayList<>(); + readLockers.add((sl) -> sl.readLock()); + readLockers.add((sl) -> readLockInterruptiblyUninterrupted(sl)); + readLockers.add((sl) -> tryReadLockUninterrupted(sl, Long.MIN_VALUE, DAYS)); + readLockers.add((sl) -> tryReadLockUninterrupted(sl, 0, DAYS)); + + List> readUnlockers = new ArrayList<>(); + readUnlockers.add((sl, stamp) -> sl.unlockRead(stamp)); + readUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockRead())); + readUnlockers.add((sl, stamp) -> sl.asReadLock().unlock()); + readUnlockers.add((sl, stamp) -> sl.unlock(stamp)); + + List> writeLockers = new ArrayList<>(); + writeLockers.add((sl) -> sl.writeLock()); + writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl)); + writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS)); + writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS)); + + List> writeUnlockers = new ArrayList<>(); + writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp)); + writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite())); + writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock()); + writeUnlockers.add((sl, stamp) -> sl.unlock(stamp)); + + + for (Function readLocker : readLockers) + for (BiConsumer readUnlocker : readUnlockers) + for (Function writeLocker : writeLockers) + for (BiConsumer writeUnlocker : writeUnlockers) { + final StampedLock sl = new StampedLock(); + final long stamp = readLocker.apply(sl); + assertTrue(stamp != 0L); + assertThrows(IllegalMonitorStateException.class, + () -> sl.unlockWrite(stamp)); + readUnlocker.accept(sl, stamp); + assertThrows(IllegalMonitorStateException.class, + () -> sl.unlock(stamp), + () -> sl.unlockRead(stamp), + () -> sl.unlockWrite(stamp)); + final long writeStamp = writeLocker.apply(sl); + assertTrue(writeStamp != 0L); + assertTrue(writeStamp != stamp); + assertThrows(IllegalMonitorStateException.class, + () -> sl.unlock(stamp), + () -> sl.unlockRead(stamp), + () -> sl.unlockWrite(stamp)); + writeUnlocker.accept(sl, writeStamp); + assertThrows(IllegalMonitorStateException.class, + () -> sl.unlock(stamp), + () -> sl.unlockRead(stamp), + () -> sl.unlockWrite(stamp)); + } + } + }