8161608: StampedLock should use storeStoreFence when acquiring write lock

Reviewed-by: martin, psandoz, plevart
This commit is contained in:
Doug Lea 2016-07-26 10:02:05 -07:00
parent 825edd9274
commit 2545e51a0c
2 changed files with 183 additions and 35 deletions

View File

@ -256,8 +256,12 @@ public class StampedLock implements java.io.Serializable {
* method validate()) requires stricter ordering rules than apply
* to normal volatile reads (of "state"). To force orderings of
* reads before a validation and the validation itself in those
* cases where this is not already forced, we use
* VarHandle.acquireFence.
* cases where this is not already forced, we use acquireFence.
* Unlike in that paper, we allow writers to use plain writes.
* One would not expect reorderings of such writes with the lock
* acquisition CAS because there is a "control dependency", but it
* is theoretically possible, so we additionally add a
* storeStoreFence after lock acquisition CAS.
*
* The memory layout keeps lock state and queue pointers together
* (normally on the same cache line). This usually works well for
@ -355,6 +359,20 @@ public class StampedLock implements java.io.Serializable {
state = ORIGIN;
}
private boolean casState(long expectedValue, long newValue) {
return STATE.compareAndSet(this, expectedValue, newValue);
}
private long tryWriteLock(long s) {
// assert (s & ABITS) == 0L;
long next;
if (casState(s, next = s | WBIT)) {
VarHandle.storeStoreFence();
return next;
}
return 0L;
}
/**
* Exclusively acquires the lock, blocking if necessary
* until available.
@ -363,10 +381,8 @@ public class StampedLock implements java.io.Serializable {
*/
@ReservedStackAccess
public long writeLock() {
long s, next; // bypass acquireWrite in fully unlocked case only
return ((((s = state) & ABITS) == 0L &&
STATE.compareAndSet(this, s, next = s + WBIT)) ?
next : acquireWrite(false, 0L));
long next;
return ((next = tryWriteLock()) != 0L) ? next : acquireWrite(false, 0L);
}
/**
@ -377,10 +393,8 @@ public class StampedLock implements java.io.Serializable {
*/
@ReservedStackAccess
public long tryWriteLock() {
long s, next;
return ((((s = state) & ABITS) == 0L &&
STATE.compareAndSet(this, s, next = s + WBIT)) ?
next : 0L);
long s;
return (((s = state) & ABITS) == 0L) ? tryWriteLock(s) : 0L;
}
/**
@ -440,10 +454,13 @@ public class StampedLock implements java.io.Serializable {
*/
@ReservedStackAccess
public long readLock() {
long s = state, next; // bypass acquireRead on common uncontended case
return ((whead == wtail && (s & ABITS) < RFULL &&
STATE.compareAndSet(this, s, next = s + RUNIT)) ?
next : acquireRead(false, 0L));
long s, next;
// bypass acquireRead on common uncontended case
return (whead == wtail
&& ((s = state) & ABITS) < RFULL
&& casState(s, next = s + RUNIT))
? next
: acquireRead(false, 0L);
}
/**
@ -457,7 +474,7 @@ public class StampedLock implements java.io.Serializable {
long s, m, next;
while ((m = (s = state) & ABITS) != WBIT) {
if (m < RFULL) {
if (STATE.compareAndSet(this, s, next = s + RUNIT))
if (casState(s, next = s + RUNIT))
return next;
}
else if ((next = tryIncReaderOverflow(s)) != 0L)
@ -487,7 +504,7 @@ public class StampedLock implements java.io.Serializable {
if (!Thread.interrupted()) {
if ((m = (s = state) & ABITS) != WBIT) {
if (m < RFULL) {
if (STATE.compareAndSet(this, s, next = s + RUNIT))
if (casState(s, next = s + RUNIT))
return next;
}
else if ((next = tryIncReaderOverflow(s)) != 0L)
@ -514,10 +531,15 @@ public class StampedLock implements java.io.Serializable {
* before acquiring the lock
*/
@ReservedStackAccess
public long readLockInterruptibly() throws InterruptedException {
long next;
if (!Thread.interrupted() &&
(next = acquireRead(true, 0L)) != INTERRUPTED)
public long readLockInterruptibly() throws InterruptedException {
long s, next;
if (!Thread.interrupted()
// bypass acquireRead on common uncontended case
&& ((whead == wtail
&& ((s = state) & ABITS) < RFULL
&& casState(s, next = s + RUNIT))
||
(next = acquireRead(true, 0L)) != INTERRUPTED))
return next;
throw new InterruptedException();
}
@ -598,7 +620,7 @@ public class StampedLock implements java.io.Serializable {
&& (stamp & RBITS) > 0L
&& ((m = s & RBITS) > 0L)) {
if (m < RFULL) {
if (STATE.compareAndSet(this, s, s - RUNIT)) {
if (casState(s, s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return;
@ -620,7 +642,7 @@ public class StampedLock implements java.io.Serializable {
*/
@ReservedStackAccess
public void unlock(long stamp) {
if ((stamp & WBIT) != 0)
if ((stamp & WBIT) != 0L)
unlockWrite(stamp);
else
unlockRead(stamp);
@ -644,7 +666,7 @@ public class StampedLock implements java.io.Serializable {
if ((m = s & ABITS) == 0L) {
if (a != 0L)
break;
if (STATE.compareAndSet(this, s, next = s + WBIT))
if ((next = tryWriteLock(s)) != 0L)
return next;
}
else if (m == WBIT) {
@ -653,8 +675,10 @@ public class StampedLock implements java.io.Serializable {
return stamp;
}
else if (m == RUNIT && a != 0L) {
if (STATE.compareAndSet(this, s, next = s - RUNIT + WBIT))
if (casState(s, next = s - RUNIT + WBIT)) {
VarHandle.storeStoreFence();
return next;
}
}
else
break;
@ -688,7 +712,7 @@ public class StampedLock implements java.io.Serializable {
else if (a == 0L) {
// optimistic read stamp
if ((s & ABITS) < RFULL) {
if (STATE.compareAndSet(this, s, next = s + RUNIT))
if (casState(s, next = s + RUNIT))
return next;
}
else if ((next = tryIncReaderOverflow(s)) != 0L)
@ -730,7 +754,7 @@ public class StampedLock implements java.io.Serializable {
else if ((m = s & ABITS) == 0L) // invalid read stamp
break;
else if (m < RFULL) {
if (STATE.compareAndSet(this, s, next = s - RUNIT)) {
if (casState(s, next = s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return next & SBITS;
@ -771,7 +795,7 @@ public class StampedLock implements java.io.Serializable {
long s, m; WNode h;
while ((m = (s = state) & ABITS) != 0L && m < WBIT) {
if (m < RFULL) {
if (STATE.compareAndSet(this, s, s - RUNIT)) {
if (casState(s, s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return true;
@ -940,7 +964,7 @@ public class StampedLock implements java.io.Serializable {
long s, m; WNode h;
while ((m = (s = state) & RBITS) > 0L) {
if (m < RFULL) {
if (STATE.compareAndSet(this, s, s - RUNIT)) {
if (casState(s, s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return;
@ -971,7 +995,7 @@ public class StampedLock implements java.io.Serializable {
private long tryIncReaderOverflow(long s) {
// assert (s & ABITS) >= RFULL;
if ((s & ABITS) == RFULL) {
if (STATE.compareAndSet(this, s, s | RBITS)) {
if (casState(s, s | RBITS)) {
++readerOverflow;
STATE.setVolatile(this, s);
return s;
@ -993,7 +1017,7 @@ public class StampedLock implements java.io.Serializable {
private long tryDecReaderOverflow(long s) {
// assert (s & ABITS) >= RFULL;
if ((s & ABITS) == RFULL) {
if (STATE.compareAndSet(this, s, s | RBITS)) {
if (casState(s, s | RBITS)) {
int r; long next;
if ((r = readerOverflow) > 0) {
readerOverflow = r - 1;
@ -1047,7 +1071,7 @@ public class StampedLock implements java.io.Serializable {
for (int spins = -1;;) { // spin while enqueuing
long m, s, ns;
if ((m = (s = state) & ABITS) == 0L) {
if (STATE.compareAndSet(this, s, ns = s + WBIT))
if ((ns = tryWriteLock(s)) != 0L)
return ns;
}
else if (spins < 0)
@ -1082,7 +1106,7 @@ public class StampedLock implements java.io.Serializable {
for (int k = spins; k > 0; --k) { // spin at head
long s, ns;
if (((s = state) & ABITS) == 0L) {
if (STATE.compareAndSet(this, s, ns = s + WBIT)) {
if ((ns = tryWriteLock(s)) != 0L) {
whead = node;
node.prev = null;
if (wasInterrupted)
@ -1158,7 +1182,7 @@ public class StampedLock implements java.io.Serializable {
if ((h = whead) == (p = wtail)) {
for (long m, s, ns;;) {
if ((m = (s = state) & ABITS) < RFULL ?
STATE.compareAndSet(this, s, ns = s + RUNIT) :
casState(s, ns = s + RUNIT) :
(m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) {
if (wasInterrupted)
Thread.currentThread().interrupt();
@ -1208,7 +1232,7 @@ public class StampedLock implements java.io.Serializable {
long m, s, ns;
do {
if ((m = (s = state) & ABITS) < RFULL ?
STATE.compareAndSet(this, s, ns = s + RUNIT) :
casState(s, ns = s + RUNIT) :
(m < WBIT &&
(ns = tryIncReaderOverflow(s)) != 0L)) {
if (wasInterrupted)
@ -1260,7 +1284,7 @@ public class StampedLock implements java.io.Serializable {
for (int k = spins;;) { // spin at head
long m, s, ns;
if ((m = (s = state) & ABITS) < RFULL ?
STATE.compareAndSet(this, s, ns = s + RUNIT) :
casState(s, ns = s + RUNIT) :
(m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) {
WNode c; Thread w;
whead = node;

View File

@ -32,11 +32,18 @@
* http://creativecommons.org/publicdomain/zero/1.0/
*/
import static java.util.concurrent.TimeUnit.DAYS;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.StampedLock;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import junit.framework.Test;
import junit.framework.TestSuite;
@ -1078,4 +1085,121 @@ public class StampedLockTest extends JSR166TestCase {
assertThrows(IllegalMonitorStateException.class, actions);
}
static long writeLockInterruptiblyUninterrupted(StampedLock sl) {
try { return sl.writeLockInterruptibly(); }
catch (InterruptedException ex) { throw new AssertionError(ex); }
}
static long tryWriteLockUninterrupted(StampedLock sl, long time, TimeUnit unit) {
try { return sl.tryWriteLock(time, unit); }
catch (InterruptedException ex) { throw new AssertionError(ex); }
}
static long readLockInterruptiblyUninterrupted(StampedLock sl) {
try { return sl.readLockInterruptibly(); }
catch (InterruptedException ex) { throw new AssertionError(ex); }
}
static long tryReadLockUninterrupted(StampedLock sl, long time, TimeUnit unit) {
try { return sl.tryReadLock(time, unit); }
catch (InterruptedException ex) { throw new AssertionError(ex); }
}
/**
* Invalid write stamps result in IllegalMonitorStateException
*/
public void testInvalidWriteStampsThrowIllegalMonitorStateException() {
List<Function<StampedLock, Long>> writeLockers = new ArrayList<>();
writeLockers.add((sl) -> sl.writeLock());
writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl));
writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS));
List<BiConsumer<StampedLock, Long>> writeUnlockers = new ArrayList<>();
writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp));
writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite()));
writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock());
writeUnlockers.add((sl, stamp) -> sl.unlock(stamp));
List<Consumer<StampedLock>> mutaters = new ArrayList<>();
mutaters.add((sl) -> {});
mutaters.add((sl) -> sl.readLock());
for (Function<StampedLock, Long> writeLocker : writeLockers)
mutaters.add((sl) -> writeLocker.apply(sl));
for (Function<StampedLock, Long> writeLocker : writeLockers)
for (BiConsumer<StampedLock, Long> writeUnlocker : writeUnlockers)
for (Consumer<StampedLock> mutater : mutaters) {
final StampedLock sl = new StampedLock();
final long stamp = writeLocker.apply(sl);
assertTrue(stamp != 0L);
assertThrows(IllegalMonitorStateException.class,
() -> sl.unlockRead(stamp));
writeUnlocker.accept(sl, stamp);
mutater.accept(sl);
assertThrows(IllegalMonitorStateException.class,
() -> sl.unlock(stamp),
() -> sl.unlockRead(stamp),
() -> sl.unlockWrite(stamp));
}
}
/**
* Invalid read stamps result in IllegalMonitorStateException
*/
public void testInvalidReadStampsThrowIllegalMonitorStateException() {
List<Function<StampedLock, Long>> readLockers = new ArrayList<>();
readLockers.add((sl) -> sl.readLock());
readLockers.add((sl) -> readLockInterruptiblyUninterrupted(sl));
readLockers.add((sl) -> tryReadLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
readLockers.add((sl) -> tryReadLockUninterrupted(sl, 0, DAYS));
List<BiConsumer<StampedLock, Long>> readUnlockers = new ArrayList<>();
readUnlockers.add((sl, stamp) -> sl.unlockRead(stamp));
readUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockRead()));
readUnlockers.add((sl, stamp) -> sl.asReadLock().unlock());
readUnlockers.add((sl, stamp) -> sl.unlock(stamp));
List<Function<StampedLock, Long>> writeLockers = new ArrayList<>();
writeLockers.add((sl) -> sl.writeLock());
writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl));
writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS));
List<BiConsumer<StampedLock, Long>> writeUnlockers = new ArrayList<>();
writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp));
writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite()));
writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock());
writeUnlockers.add((sl, stamp) -> sl.unlock(stamp));
for (Function<StampedLock, Long> readLocker : readLockers)
for (BiConsumer<StampedLock, Long> readUnlocker : readUnlockers)
for (Function<StampedLock, Long> writeLocker : writeLockers)
for (BiConsumer<StampedLock, Long> writeUnlocker : writeUnlockers) {
final StampedLock sl = new StampedLock();
final long stamp = readLocker.apply(sl);
assertTrue(stamp != 0L);
assertThrows(IllegalMonitorStateException.class,
() -> sl.unlockWrite(stamp));
readUnlocker.accept(sl, stamp);
assertThrows(IllegalMonitorStateException.class,
() -> sl.unlock(stamp),
() -> sl.unlockRead(stamp),
() -> sl.unlockWrite(stamp));
final long writeStamp = writeLocker.apply(sl);
assertTrue(writeStamp != 0L);
assertTrue(writeStamp != stamp);
assertThrows(IllegalMonitorStateException.class,
() -> sl.unlock(stamp),
() -> sl.unlockRead(stamp),
() -> sl.unlockWrite(stamp));
writeUnlocker.accept(sl, writeStamp);
assertThrows(IllegalMonitorStateException.class,
() -> sl.unlock(stamp),
() -> sl.unlockRead(stamp),
() -> sl.unlockWrite(stamp));
}
}
}