676 lines
23 KiB
Java
Raw Normal View History

/*
* Copyright (c) 2010, 2020, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @key randomness
*
* @summary converted from VM Testbase jit/escape/LockElision/MatMul.
* VM Testbase keywords: [jit, quick]
* VM Testbase readme:
* DESCRIPTION
* The test multiplies 2 matrices, first, by directly calculating matrix product
* elements, and second, by calculating them in parallel in different threads.
* The results are then compared.
* The test, in addition to required locks, introduces locks on local variables or
* variables not escaping from the executing thread, and nests them manifoldly.
* In case of a buggy compiler, during lock elimination some code required by
* the calculation may be eliminated as well, or the code may be overoptimized in
* some other way, causing a difference in the execution results.
* The test takes the parameters -dim, which specifies the dimension of the
* matrices, and -threadCount, which specifies the number of worker threads.
*
* @library /vmTestbase
* /test/lib
* @run main/othervm jit.escape.LockElision.MatMul.MatMul -dim 30 -threadCount 10
*/
package jit.escape.LockElision.MatMul;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import nsk.share.Consts;
import nsk.share.Log;
import nsk.share.Pair;
import nsk.share.test.StressOptions;
import vm.share.options.Option;
import vm.share.options.OptionSupport;
import vm.share.options.Options;
import jdk.test.lib.Utils;
public class MatMul {
// Side length of the square matrices that run() generates and multiplies.
// NOTE(review): presumably filled in from the command line by
// OptionSupport.setup(...) in main(); the option framework is not visible here.
@Option(name = "dim", description = "dimension of matrices")
int dim;
// Passed to the Log constructor in run().
@Option(name = "verbose", default_value = "false",
description = "verbose mode")
boolean verbose;
// Base number of worker threads; run() multiplies it by
// stressOptions.getThreadsFactor() before calling parallelMul().
@Option(name = "threadCount", description = "thread count")
int threadCount;
// Stress options; only getThreadsFactor() is read in this file.
@Options
StressOptions stressOptions = new StressOptions();
// Created at the start of run(); used by the instance methods for reporting.
private Log log;
/**
 * Entry point: builds the test object, lets the option framework configure it
 * from {@code args}, runs the test and exits with the JCK-style status code.
 *
 * @param args command-line arguments handed to {@code OptionSupport.setup}
 */
public static void main(String[] args) {
    final MatMul instance = new MatMul();
    OptionSupport.setup(instance, args);
    final int status = instance.run();
    System.exit(Consts.JCK_STATUS_BASE + status);
}
/**
 * Runs the test: multiplies two random {@code dim x dim} matrices serially and
 * in parallel, and compares the two products.
 *
 * @return {@code Consts.TEST_PASSED} when both products are equal and all
 *         counters are in their expected state, {@code Consts.TEST_FAILED}
 *         otherwise
 */
public int run() {
    log = new Log(System.out, verbose);
    log.display("Parallel matrix multiplication test");
    Matrix a = Matrix.randomMatrix(dim);
    Matrix b = Matrix.randomMatrix(dim);
    long t1, t2;

    // Reference result computed on the current thread.
    t1 = System.currentTimeMillis();
    Matrix serialResult = serialMul(a, b);
    t2 = System.currentTimeMillis();
    log.display("serial time: " + (t2 - t1) + "ms");

    try {
        t1 = System.currentTimeMillis();
        Matrix parallelResult = parallelMul(a, b,
                threadCount * stressOptions.getThreadsFactor());
        t2 = System.currentTimeMillis();
        log.display("parallel time: " + (t2 - t1) + "ms");

        if (!serialResult.equals(parallelResult)) {
            log.complain("a = \n" + a);
            log.complain("b = \n" + b);
            log.complain("serial: a * b = \n" + serialResult);
            // Label the second dump correctly so the two results can be told apart.
            log.complain("parallel: a * b = \n" + parallelResult);
            return Consts.TEST_FAILED;
        }
        return Consts.TEST_PASSED;
    } catch (CounterIncorrectStateException e) {
        // A counter that missed increments means a required lock was not honoured.
        log.complain("incorrect state of counter " + e.counter.name);
        log.complain("expected = " + e.counter.expected);
        log.complain("actual " + e.counter.state());
        return Consts.TEST_FAILED;
    }
}
/**
 * Sums the element-wise products of two sequences over their common prefix.
 *
 * @param one first sequence
 * @param two second sequence
 * @return sum of {@code one.get(i) * two.get(i)} for every index {@code i}
 *         below the smaller of the two sizes
 */
public static int convolution(Seq<Integer> one, Seq<Integer> two) {
    final int length = Math.min(one.size(), two.size());
    int acc = 0;
    int idx = 0;
    while (idx < length) {
        acc += one.get(idx) * two.get(idx);
        idx++;
    }
    return acc;
}
/**
* Calculates the chunked convolution of two sequences, i.e. the sum of
* {@code one.get(i) * two.get(i)} for {@code i} starting at {@code from} and
* staying below the smallest of {@code one.size()}, {@code two.size()} and
* {@code to + 1}.
* <p/>
* This is a special version of this method:
* <pre>{@code
* public static int chunkedConvolution(Seq<Integer> one, Seq<Integer> two, int from, int to) {
* int res = 0;
* int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1);
* for (int i = from; i < upperBound; i++) {
* res += one.get(i) * two.get(i);
* }
* return res;
* }}</pre>
* <p/>
* that tries to fool the Lock Elision optimization:
* Most lock objects in these lines are really thread local, so the related synchronized blocks (dummy blocks) can be removed.
* But several synchronized blocks (all those protected by Counter instances) are really necessary, and removing them would
* produce an incorrect result.
* <p/>
* The min, multiply and sum steps go through {@code ThreadLocals} helpers, which
* update their hash as a side effect of every call.
*
* @param one - first sequence
* @param two - second sequence
* @param from - lower bound of sum (inclusive)
* @param to - upper bound of sum (inclusive)
* @param local - ThreadLocals instance that will be used for calculations; it is also the parent of a helper created here
* @param bCounter - Counter instance, incremented once per call; needed to perform checks
* @return the sum of the element products over the chunk
*/
public static int chunkedConvolutionWithDummy(Seq<Integer> one,
Seq<Integer> two, int from, int to, ThreadLocals local,
Counter bCounter) {
// conv_local1 and conv_local2 chain up to the caller's local; conv_local3 has no parent.
ThreadLocals conv_local1 = new ThreadLocals(local, "conv_local1");
ThreadLocals conv_local2 = new ThreadLocals(conv_local1, "conv_local2");
ThreadLocals conv_local3 = new ThreadLocals(null, "conv_local3");
int res = 0;
synchronized (local) {
local.updateHash();
int upperBound = 0;
synchronized (conv_local1) {
// first half of the bound: min(one.size(), two.size())
upperBound = local.min(one.size(), two.size());
synchronized (two) {
// second half; together equivalent to:
//int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1) :
upperBound = conv_local1.min(upperBound, to + 1);
// increment performed under the counter's own monitor, once per invocation
synchronized (bCounter) {
bCounter.inc();
}
}
for (int i = from; i < upperBound; i++) {
synchronized (conv_local2) {
conv_local1.updateHash();
int prod = 0;
synchronized (one) {
int t = conv_local2.mult(one.get(i), two.get(i));
// the product is copied into prod inside a lock on conv_local3
synchronized (conv_local3) {
prod = t;
}
// equivalent to:
//res += one.get(i) * two.get(i)
res = conv_local3.sum(res, prod);
}
}
}
}
return res;
}
}
/**
 * Checks that two matrices can be multiplied by this test: both must be
 * non-null and have the same dimension. A complaint is logged on failure.
 *
 * @param a first operand
 * @param b second operand
 * @return {@code true} when the product can be computed, {@code false} otherwise
 */
public boolean productCheck(Matrix a, Matrix b) {
    final boolean operandMissing = (a == null) || (b == null);
    if (operandMissing) {
        log.complain("null matrix!");
        return false;
    }
    final boolean sameDimension = (a.dim == b.dim);
    if (!sameDimension) {
        log.complain("matrices dimension are differs");
        return false;
    }
    return true;
}
/**
 * Multiplies two matrices on the calling thread.
 *
 * @param a first operand
 * @param b second operand
 * @return product of matrices a and b
 * @throws IllegalArgumentException when {@link #productCheck} rejects the operands
 */
public Matrix serialMul(Matrix a, Matrix b) {
    final boolean operandsOk = productCheck(a, b);
    if (!operandsOk) {
        throw new IllegalArgumentException();
    }
    final Matrix product = Matrix.zeroMatrix(a.dim);
    int rowIdx = 0;
    while (rowIdx < a.dim) {
        int colIdx = 0;
        while (colIdx < a.dim) {
            // element (rowIdx, colIdx) = row of a times column of b
            final int cell = convolution(a.row(rowIdx), b.column(colIdx));
            product.set(rowIdx, colIdx, cell);
            colIdx++;
        }
        rowIdx++;
    }
    return product;
}
/**
 * Parallel multiplication of matrices.
 * <p/>
 * The index range {@code [0, dim - 1]} is split into {@code threadCount} parts;
 * each worker adds, for every cell of the result, the partial convolution over
 * its own part.
 * <p/>
 * This is a special version of this method:
 * <pre>{@code
 * public Matrix parallelMul1(final Matrix a, final Matrix b, int threadCount) {
 *     if (!productCheck(a, b)) {
 *         throw new IllegalArgumentException();
 *     }
 *     final int dim = a.dim;
 *     final Matrix result = Matrix.zeroMatrix(dim);
 *
 *     ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
 *     final CountDownLatch latch = new CountDownLatch(threadCount);
 *     List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1), threadCount);
 *     for (final Pair<Integer, Integer> part : parts) {
 *         threadPool.submit(new Runnable() {
 *             public void run() {
 *                 for (int i = 0; i < dim; i++) {
 *                     for (int j = 0; j < dim; j++) {
 *                         synchronized (result) {
 *                             int from = part.first;
 *                             int to = part.second;
 *                             result.add(i, j, chunkedConvolution(a.row(i), b.column(j), from, to));
 *                         }
 *                     }
 *                 }
 *                 latch.countDown();
 *             }
 *         });
 *     }
 *
 *     try {
 *         latch.await();
 *     } catch (InterruptedException e) {
 *         e.printStackTrace();
 *     }
 *     threadPool.shutdown();
 *     return result;
 * }}</pre>
 * The additional (dummy) synchronized blocks in the real body are there to fool
 * the Lock Elision optimization: their lock objects are really thread local, so
 * those blocks can be removed. But several synchronized blocks (that are nested
 * in dummy blocks) are really necessary, and removing them would produce an
 * incorrect result.
 * <p/>
 * After the workers finish, every counter is compared with its expected value:
 * {@code threadCount} for the per-worker counters and
 * {@code threadCount * dim * dim} for the per-cell counters.
 *
 * @param a first operand
 * @param b second operand
 * @param threadCount number of threads that will be used for calculations
 * @return product of matrices a and b
 * @throws IllegalArgumentException when {@link #productCheck} rejects the operands
 * @throws CounterIncorrectStateException when a counter does not hold its expected value
 */
public Matrix parallelMul(final Matrix a, final Matrix b, int threadCount)
        throws CounterIncorrectStateException {
    if (!productCheck(a, b)) {
        throw new IllegalArgumentException();
    }
    final int dim = a.dim;
    final Matrix result = Matrix.zeroMatrix(dim);
    ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
    final CountDownLatch latch = new CountDownLatch(threadCount);
    List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1),
            threadCount);

    // Incremented once per worker.
    final Counter lCounter1 = new Counter(threadCount, "lCounter1");
    final Counter lCounter2 = new Counter(threadCount, "lCounter2");
    final Counter lCounter3 = new Counter(threadCount, "lCounter3");
    // Incremented once per worker and result cell.
    final Counter bCounter1 = new Counter(threadCount * dim * dim,
            "bCounter1");
    final Counter bCounter2 = new Counter(threadCount * dim * dim,
            "bCounter2");
    final Counter bCounter3 = new Counter(threadCount * dim * dim,
            "bCounter3");
    final Counter[] counters = {lCounter1, lCounter2, lCounter3,
            bCounter1, bCounter2, bCounter3};

    // One chain local1 <- local2 <- local3 per part, created before any worker starts.
    final Map<Pair<Integer, Integer>, ThreadLocals> locals1
            = CollectionsUtils.newHashMap();
    final Map<Pair<Integer, Integer>, ThreadLocals> locals2
            = CollectionsUtils.newHashMap();
    final Map<Pair<Integer, Integer>, ThreadLocals> locals3
            = CollectionsUtils.newHashMap();
    for (final Pair<Integer, Integer> part : parts) {
        ThreadLocals local1 = new ThreadLocals(null,
                "locals1[" + part + "]");
        ThreadLocals local2 = new ThreadLocals(local1,
                "locals2[" + part + "]");
        ThreadLocals local3 = new ThreadLocals(local2,
                "locals3[" + part + "]");
        locals1.put(part, local1);
        locals2.put(part, local2);
        locals3.put(part, local3);
    }

    for (final Pair<Integer, Integer> part : parts) {
        threadPool.submit(new Runnable() {
            @Override
            public void run() {
                ThreadLocals local1 = locals1.get(part);
                ThreadLocals local2 = locals2.get(part);
                ThreadLocals local3 = locals3.get(part);
                // local4 is looked up in locals3 as well, so it is the same object as local3.
                ThreadLocals local4 = locals3.get(part);
                synchronized (local1) {
                    local1.updateHash();
                    synchronized (lCounter1) {
                        lCounter1.inc();
                    }
                    synchronized (lCounter3) {
                        synchronized (local2) {
                            local2.updateHash();
                            lCounter3.inc();
                        }
                    }
                    synchronized (new Object()) {
                        synchronized (lCounter2) {
                            lCounter2.inc();
                        }
                        for (int i = 0; i < dim; i++) {
                            for (int j = 0; j < dim; j++) {
                                synchronized (bCounter1) {
                                    synchronized (new Object()) {
                                        bCounter1.inc();
                                    }
                                }
                                synchronized (local3) {
                                    local3.updateHash();
                                    synchronized (bCounter2) {
                                        bCounter2.inc();
                                    }
                                    // result is shared by all workers, so this lock is required
                                    synchronized (result) {
                                        local1.updateHash();
                                        synchronized (local2) {
                                            local2.updateHash();
                                            int from = part.first;
                                            int to = part.second;
                                            result.add(i, j,
                                                    chunkedConvolutionWithDummy(
                                                            a.row(i),
                                                            b.column(j),
                                                            from, to,
                                                            local4,
                                                            bCounter3));
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                latch.countDown();
            }
        });
    }

    try {
        latch.await();
    } catch (InterruptedException e) {
        e.printStackTrace();
        // Do not swallow the interruption: restore the flag for the caller.
        Thread.currentThread().interrupt();
    }
    threadPool.shutdown();

    for (final Pair<Integer, Integer> part : parts) {
        log.display(
                "hash for " + part + " = " + locals1.get(part).getHash());
    }
    for (Counter counter : counters) {
        if (!counter.check()) {
            throw new CounterIncorrectStateException(counter);
        }
    }
    return result;
}
/**
 * Splits an interval into consecutive parts.
 * <p/>
 * Parts are carved off from the upper end: with {@code k} parts still to
 * produce, the last of them receives {@code size / k} elements (integer
 * division) of the remaining interval, and the first part receives whatever
 * is left. When there are more parts than elements, some parts are empty and
 * are encoded with a lower bound one greater than the upper bound.
 *
 * @param interval - pair that encodes the (inclusive) bounds of the interval
 * @param partCount - count of parts, must be positive
 * @return list of pairs that encode the (inclusive) bounds of the parts, in
 *         ascending order; its size is {@code partCount}
 * @throws IllegalArgumentException when {@code partCount} is not positive
 */
public static List<Pair<Integer, Integer>> splitInterval(
        Pair<Integer, Integer> interval, int partCount) {
    // Negative counts used to recurse without end; reject them together with zero.
    if (partCount <= 0) {
        throw new IllegalArgumentException(
                "partCount must be positive, got " + partCount);
    }
    if (partCount == 1) {
        return CollectionsUtils.asList(interval);
    }

    final int lower = interval.first;
    int upper = interval.second;
    // Parts are produced from the last one down to the first one, then reversed.
    List<Pair<Integer, Integer>> parts
            = new ArrayList<Pair<Integer, Integer>>(partCount);
    for (int remaining = partCount; remaining > 1; remaining--) {
        int partSize = (upper - lower + 1) / remaining;
        Pair<Integer, Integer> lastPart = Pair.of(upper - partSize + 1, upper);
        parts.add(lastPart);
        upper -= partSize;
    }
    Pair<Integer, Integer> firstPart = Pair.of(lower, upper);
    parts.add(firstPart);
    Collections.reverse(parts);
    return parts;
}
/**
 * Named counter that starts at zero and knows the value it is expected to
 * reach. It performs no synchronization of its own.
 */
public static class Counter {
    public final String name;
    public final int expected;
    private int state;

    /**
     * @param expected value the counter should hold when checked
     * @param name label used when reporting the counter
     */
    public Counter(int expected, String name) {
        this.name = name;
        this.expected = expected;
    }

    /** Adds one to the current value. */
    public void inc() {
        this.state += 1;
    }

    /** @return the current value */
    public int state() {
        return this.state;
    }

    /** @return {@code true} when the current value equals the expected one */
    public boolean check() {
        return this.expected == this.state;
    }
}
/**
 * Thrown when a {@link Counter} does not hold its expected value after the
 * parallel computation. The offending counter is available in {@link #counter}.
 */
private static class CounterIncorrectStateException extends Exception {
    public final Counter counter;

    /**
     * @param counter the counter whose state differs from its expected value
     */
    public CounterIncorrectStateException(Counter counter) {
        super(describe(counter));
        this.counter = counter;
    }

    /** Builds the exception message; tolerates a null counter. */
    private static String describe(Counter counter) {
        if (counter == null) {
            return "incorrect state of counter";
        }
        return "incorrect state of counter " + counter.name
                + ": expected = " + counter.expected
                + ", actual = " + counter.state();
    }
}
/**
 * Read-only, index-addressable sequence of elements.
 *
 * @param <E> element type
 */
private static abstract class Seq<E> implements Iterable<E> {
    /**
     * Returns an iterator that walks the indices {@code 0 .. size() - 1} in
     * order. The iterator does not support removal.
     */
    @Override
    public Iterator<E> iterator() {
        return new Iterator<E>() {
            private int position = 0;

            @Override
            public boolean hasNext() {
                return position < size();
            }

            @Override
            public E next() {
                // Honour the Iterator contract instead of reading past the end.
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                return get(position++);
            }

            @Override
            public void remove() {
                // The sequence is read-only; say so rather than silently doing nothing.
                throw new UnsupportedOperationException("remove");
            }
        };
    }

    /**
     * @param i index of the element
     * @return the element at index {@code i}
     */
    public abstract E get(int i);

    /** @return the number of elements in the sequence */
    public abstract int size();
}
/**
 * Small factory helpers for the collections used by this test. Every list
 * returned is a fresh, mutable {@code ArrayList}.
 */
private static class CollectionsUtils {
    /** @return a new empty hash map */
    public static <K, V> Map<K, V> newHashMap() {
        Map<K, V> empty = new HashMap<K, V>();
        return empty;
    }

    /** @return a new empty list */
    public static <E> List<E> newArrayList() {
        List<E> empty = new ArrayList<E>();
        return empty;
    }

    /**
     * @param collection elements to copy
     * @return a new list holding the elements of {@code collection}
     */
    public static <E> List<E> newArrayList(Collection<E> collection) {
        List<E> copy = new ArrayList<E>(collection);
        return copy;
    }

    /**
     * @param e the single element
     * @return a new list containing only {@code e}
     */
    public static <E> List<E> asList(E e) {
        List<E> single = new ArrayList<E>();
        single.add(e);
        return single;
    }

    /**
     * @param init leading elements (left unmodified)
     * @param last element to place at the end
     * @return a new list holding the elements of {@code init} followed by {@code last}
     */
    public static <E> List<E> append(List<E> init, E last) {
        List<E> extended = new ArrayList<E>(init);
        extended.add(last);
        return extended;
    }
}
/**
 * Square matrix of ints with {@code dim} rows and columns, stored row by row
 * in a flat array.
 */
private static class Matrix {
    /** Source of the coefficients produced by {@link #randomMatrix(int)}. */
    private static final Random random = Utils.getRandomInstance();

    public final int dim;
    private final int[] coeffs;

    private Matrix(int dim) {
        this.dim = dim;
        this.coeffs = new int[dim * dim];
    }

    /** Stores {@code value} at row {@code i}, column {@code j}. */
    public void set(int i, int j, int value) {
        coeffs[i * dim + j] = value;
    }

    /** Adds {@code value} to the element at row {@code i}, column {@code j}. */
    public void add(int i, int j, int value) {
        coeffs[i * dim + j] += value;
    }

    /** @return the element at row {@code i}, column {@code j} */
    public int get(int i, int j) {
        return coeffs[i * dim + j];
    }

    /**
     * @param i row index
     * @return a live view of row {@code i}; it reads through to this matrix
     */
    public Seq<Integer> row(final int i) {
        return new Seq<Integer>() {
            @Override
            public Integer get(int j) {
                return Matrix.this.get(i, j);
            }

            @Override
            public int size() {
                return Matrix.this.dim;
            }
        };
    }

    /**
     * @param j column index
     * @return a live view of column {@code j}; it reads through to this matrix
     */
    public Seq<Integer> column(final int j) {
        return new Seq<Integer>() {
            @Override
            public Integer get(int i) {
                return Matrix.this.get(i, j);
            }

            @Override
            public int size() {
                return Matrix.this.dim;
            }
        };
    }

    /** Renders one line per row, with elements separated by two tabs. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j < dim; j++) {
                builder.append((j == 0) ? "" : "\t\t");
                builder.append(get(i, j));
            }
            builder.append("\n");
        }
        return builder.toString();
    }

    /** Two matrices are equal when they have the same dimension and elements. */
    @Override
    public boolean equals(Object other) {
        if (!(other instanceof Matrix)) {
            return false;
        }
        Matrix b = (Matrix) other;
        if (b.dim != this.dim) {
            return false;
        }
        return Arrays.equals(this.coeffs, b.coeffs);
    }

    /**
     * Consistent with {@link #equals(Object)}: equal matrices hold equal
     * coefficient arrays and therefore produce equal hash codes.
     */
    @Override
    public int hashCode() {
        return Arrays.hashCode(coeffs);
    }

    /**
     * @param dim dimension of the matrix
     * @return a matrix whose elements are drawn from {@code random.nextInt(50)},
     *         filled row by row
     */
    public static Matrix randomMatrix(int dim) {
        Matrix result = new Matrix(dim);
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j < dim; j++) {
                result.set(i, j, random.nextInt(50));
            }
        }
        return result;
    }

    /**
     * @param dim dimension of the matrix
     * @return a matrix with every element equal to zero
     */
    public static Matrix zeroMatrix(int dim) {
        // A freshly allocated int[] is already zero-filled, so no explicit fill is needed.
        return new Matrix(dim);
    }
}
/**
 * Arithmetic helper that folds every operation into a running hash and pushes
 * that hash up its parent chain. All instances of this class will be used in
 * thread local context.
 */
private static class ThreadLocals {
    private static final int HASH_BOUND = 424242;

    public final String name;
    private ThreadLocals parent;
    private int hash = 42;

    /**
     * @param parent instance that also receives every hash update, or {@code null}
     * @param name label of this instance
     */
    public ThreadLocals(ThreadLocals parent, String name) {
        this.name = name;
        this.parent = parent;
    }

    /** @return the smaller of {@code a} and {@code b}; updates the hash */
    public int min(int a, int b) {
        final int seed = a + b + 1;
        updateHash(seed);
        return Math.min(a, b);
    }

    /** @return {@code a * b}; updates the hash */
    public int mult(int a, int b) {
        final int seed = a + b + 2;
        updateHash(seed);
        return a * b;
    }

    /** @return {@code a + b}; updates the hash */
    public int sum(int a, int b) {
        final int seed = a + b + 3;
        updateHash(seed);
        return a + b;
    }

    /** Updates the hash with the fixed value 42 and returns the new hash. */
    public int updateHash() {
        return updateHash(42);
    }

    /**
     * Mixes {@code data} into the hash. When there is a parent, the fresh hash
     * is fed into the parent and the parent's answer becomes this hash.
     *
     * @param data value to mix in
     * @return the new hash of this instance
     */
    public int updateHash(int data) {
        this.hash = (this.hash + data) % HASH_BOUND;
        if (this.parent == null) {
            return this.hash;
        }
        final int fromParent = this.parent.updateHash(this.hash);
        this.hash = fromParent % HASH_BOUND;
        return this.hash;
    }

    /** @return the current hash */
    public int getHash() {
        return this.hash;
    }
}
}