358 lines
16 KiB
Java
358 lines
16 KiB
Java
|
/*
 * Copyright (c) 2018, Google LLC. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
|
||
|
|
||
|
/**
 * @test 8200301
 * @summary deduplicate lambda methods with the same body, target type, and captured state
 * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
 *     jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
 *     jdk.compiler/com.sun.tools.javac.file jdk.compiler/com.sun.tools.javac.main
 *     jdk.compiler/com.sun.tools.javac.tree jdk.compiler/com.sun.tools.javac.util
 * @run main DeduplicationTest
 */
|
||
|
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

import com.sun.source.util.JavacTask;
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskEvent.Kind;
import com.sun.source.util.TaskListener;
import com.sun.tools.classfile.Attribute;
import com.sun.tools.classfile.BootstrapMethods_attribute;
import com.sun.tools.classfile.BootstrapMethods_attribute.BootstrapMethodSpecifier;
import com.sun.tools.classfile.ClassFile;
import com.sun.tools.classfile.ConstantPool.CONSTANT_MethodHandle_info;
import com.sun.tools.javac.api.ClientCodeWrapper.Trusted;
import com.sun.tools.javac.api.JavacTool;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.comp.TreeDiffer;
import com.sun.tools.javac.comp.TreeHasher;
import com.sun.tools.javac.file.JavacFileManager;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCExpression;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCLambda;
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
import com.sun.tools.javac.tree.JCTree.JCTypeCast;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.tree.JCTree.Tag;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.JCDiagnostic;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.BiFunction;
import javax.tools.Diagnostic;
import javax.tools.DiagnosticListener;
import javax.tools.JavaFileObject;
|
||
|
|
||
|
public class DeduplicationTest {
|
||
|
|
||
|
public static void main(String[] args) throws Exception {
|
||
|
JavacFileManager fileManager = new JavacFileManager(new Context(), false, UTF_8);
|
||
|
JavacTool javacTool = JavacTool.create();
|
||
|
Listener diagnosticListener = new Listener();
|
||
|
Path testSrc = Paths.get(System.getProperty("test.src"));
|
||
|
Path file = testSrc.resolve("Deduplication.java");
|
||
|
JavacTask task =
|
||
|
javacTool.getTask(
|
||
|
null,
|
||
|
null,
|
||
|
diagnosticListener,
|
||
|
Arrays.asList(
|
||
|
"-d",
|
||
|
".",
|
||
|
"-XDdebug.dumpLambdaToMethodDeduplication",
|
||
|
"-XDdebug.dumpLambdaToMethodStats"),
|
||
|
null,
|
||
|
fileManager.getJavaFileObjects(file));
|
||
|
Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>();
|
||
|
task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas));
|
||
|
Iterable<? extends JavaFileObject> generated = task.generate();
|
||
|
if (!diagnosticListener.unexpected.isEmpty()) {
|
||
|
throw new AssertionError(
|
||
|
diagnosticListener
|
||
|
.unexpected
|
||
|
.stream()
|
||
|
.map(
|
||
|
d ->
|
||
|
String.format(
|
||
|
"%s: %s",
|
||
|
d.getCode(), d.getMessage(Locale.getDefault())))
|
||
|
.collect(joining(", ", "unexpected diagnostics: ", "")));
|
||
|
}
|
||
|
|
||
|
// Assert that each group of lambdas was deduplicated.
|
||
|
Map<JCLambda, JCLambda> actual = diagnosticListener.deduplicationTargets();
|
||
|
dedupedLambdas.forEach(
|
||
|
(k, v) -> {
|
||
|
if (!actual.containsKey(k)) {
|
||
|
throw new AssertionError("expected " + k + " to be deduplicated");
|
||
|
}
|
||
|
if (!v.equals(actual.get(k))) {
|
||
|
throw new AssertionError(
|
||
|
String.format(
|
||
|
"expected %s to be deduplicated to:\n %s\nwas: %s",
|
||
|
k, v, actual.get(v)));
|
||
|
}
|
||
|
});
|
||
|
|
||
|
// Assert that the output contains only the canonical lambdas, and not the deduplicated
|
||
|
// lambdas.
|
||
|
Set<String> bootstrapMethodNames = new TreeSet<>();
|
||
|
for (JavaFileObject output : generated) {
|
||
|
ClassFile cf = ClassFile.read(output.openInputStream());
|
||
|
BootstrapMethods_attribute bsm =
|
||
|
(BootstrapMethods_attribute) cf.getAttribute(Attribute.BootstrapMethods);
|
||
|
for (BootstrapMethodSpecifier b : bsm.bootstrap_method_specifiers) {
|
||
|
bootstrapMethodNames.add(
|
||
|
((CONSTANT_MethodHandle_info)
|
||
|
cf.constant_pool.get(b.bootstrap_arguments[1]))
|
||
|
.getCPRefInfo()
|
||
|
.getNameAndTypeInfo()
|
||
|
.getName());
|
||
|
}
|
||
|
}
|
||
|
Set<String> deduplicatedNames =
|
||
|
diagnosticListener
|
||
|
.expectedLambdaMethods()
|
||
|
.stream()
|
||
|
.map(s -> s.getSimpleName().toString())
|
||
|
.sorted()
|
||
|
.collect(toSet());
|
||
|
if (!deduplicatedNames.equals(bootstrapMethodNames)) {
|
||
|
throw new AssertionError(
|
||
|
String.format(
|
||
|
"expected deduplicated methods: %s, but saw: %s",
|
||
|
deduplicatedNames, bootstrapMethodNames));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* Returns a symbol comparator that treats symbols that correspond to the same parameter of each
|
||
|
* of the given lambdas as equal.
|
||
|
*/
|
||
|
private static BiFunction<Symbol, Symbol, Boolean> paramsEqual(JCLambda lhs, JCLambda rhs) {
|
||
|
return (x, y) -> {
|
||
|
Integer idx = paramIndex(lhs, x);
|
||
|
if (idx != null && idx != -1) {
|
||
|
if (Objects.equals(idx, paramIndex(rhs, y))) {
|
||
|
return true;
|
||
|
}
|
||
|
}
|
||
|
return null;
|
||
|
};
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* Returns the index of the given symbol as a parameter of the given lambda, or else {@code -1}
|
||
|
* if is not a parameter.
|
||
|
*/
|
||
|
private static Integer paramIndex(JCLambda lambda, Symbol sym) {
|
||
|
if (sym != null) {
|
||
|
int idx = 0;
|
||
|
for (JCVariableDecl param : lambda.params) {
|
||
|
if (sym == param.sym) {
|
||
|
return idx;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return null;
|
||
|
}
|
||
|
|
||
|
/** A diagnostic listener that records debug messages related to lambda desugaring. */
|
||
|
@Trusted
|
||
|
static class Listener implements DiagnosticListener<JavaFileObject> {
|
||
|
|
||
|
/** A map from method symbols to lambda trees for desugared lambdas. */
|
||
|
final Map<MethodSymbol, JCLambda> lambdaMethodSymbolsToTrees = new LinkedHashMap<>();
|
||
|
|
||
|
/**
|
||
|
* A map from lambda trees that were deduplicated to the method symbol of the canonical
|
||
|
* lambda implementation method they were deduplicated to.
|
||
|
*/
|
||
|
final Map<JCLambda, MethodSymbol> deduped = new LinkedHashMap<>();
|
||
|
|
||
|
final List<Diagnostic<? extends JavaFileObject>> unexpected = new ArrayList<>();
|
||
|
|
||
|
@Override
|
||
|
public void report(Diagnostic<? extends JavaFileObject> diagnostic) {
|
||
|
JCDiagnostic d = (JCDiagnostic) diagnostic;
|
||
|
switch (d.getCode()) {
|
||
|
case "compiler.note.lambda.stat":
|
||
|
lambdaMethodSymbolsToTrees.put(
|
||
|
(MethodSymbol) d.getArgs()[1],
|
||
|
(JCLambda) d.getDiagnosticPosition().getTree());
|
||
|
break;
|
||
|
case "compiler.note.verbose.l2m.deduplicate":
|
||
|
deduped.put(
|
||
|
(JCLambda) d.getDiagnosticPosition().getTree(),
|
||
|
(MethodSymbol) d.getArgs()[0]);
|
||
|
break;
|
||
|
default:
|
||
|
unexpected.add(diagnostic);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/** Returns expected lambda implementation method symbols. */
|
||
|
Set<MethodSymbol> expectedLambdaMethods() {
|
||
|
return lambdaMethodSymbolsToTrees
|
||
|
.entrySet()
|
||
|
.stream()
|
||
|
.filter(e -> !deduped.containsKey(e.getValue()))
|
||
|
.map(Map.Entry::getKey)
|
||
|
.collect(toSet());
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* Returns a mapping from deduplicated lambda trees to the tree of the canonical lambda they
|
||
|
* were deduplicated to.
|
||
|
*/
|
||
|
Map<JCLambda, JCLambda> deduplicationTargets() {
|
||
|
return deduped.entrySet()
|
||
|
.stream()
|
||
|
.collect(
|
||
|
toMap(
|
||
|
Map.Entry::getKey,
|
||
|
e -> lambdaMethodSymbolsToTrees.get(e.getValue()),
|
||
|
(a, b) -> {
|
||
|
throw new AssertionError();
|
||
|
},
|
||
|
LinkedHashMap::new));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* A task listener that tests {@link TreeDiffer} and {@link TreeHasher} on all lambda trees in a
|
||
|
* compilation, post-analysis.
|
||
|
*/
|
||
|
private static class TreeDiffHashTaskListener implements TaskListener {
|
||
|
|
||
|
/**
|
||
|
* A map from deduplicated lambdas to the canonical lambda they are expected to be
|
||
|
* deduplicated to.
|
||
|
*/
|
||
|
private final Map<JCLambda, JCLambda> dedupedLambdas;
|
||
|
|
||
|
public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas) {
|
||
|
this.dedupedLambdas = dedupedLambdas;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void finished(TaskEvent e) {
|
||
|
if (e.getKind() != Kind.ANALYZE) {
|
||
|
return;
|
||
|
}
|
||
|
// Scan the compilation for calls to a varargs method named 'group', whose arguments
|
||
|
// are a group of lambdas that are equivalent to each other, but distinct from all
|
||
|
// lambdas in the compilation unit outside of that group.
|
||
|
List<List<JCLambda>> lambdaGroups = new ArrayList<>();
|
||
|
new TreeScanner() {
|
||
|
@Override
|
||
|
public void visitApply(JCMethodInvocation tree) {
|
||
|
if (tree.getMethodSelect().getTag() == Tag.IDENT
|
||
|
&& ((JCIdent) tree.getMethodSelect())
|
||
|
.getName()
|
||
|
.contentEquals("group")) {
|
||
|
List<JCLambda> xs = new ArrayList<>();
|
||
|
for (JCExpression arg : tree.getArguments()) {
|
||
|
if (arg instanceof JCTypeCast) {
|
||
|
arg = ((JCTypeCast) arg).getExpression();
|
||
|
}
|
||
|
xs.add((JCLambda) arg);
|
||
|
}
|
||
|
lambdaGroups.add(xs);
|
||
|
}
|
||
|
super.visitApply(tree);
|
||
|
}
|
||
|
}.scan((JCCompilationUnit) e.getCompilationUnit());
|
||
|
for (int i = 0; i < lambdaGroups.size(); i++) {
|
||
|
List<JCLambda> curr = lambdaGroups.get(i);
|
||
|
JCLambda first = null;
|
||
|
// Assert that all pairwise combinations of lambdas in the group are equal, and
|
||
|
// hash to the same value.
|
||
|
for (JCLambda lhs : curr) {
|
||
|
if (first == null) {
|
||
|
first = lhs;
|
||
|
} else {
|
||
|
dedupedLambdas.put(lhs, first);
|
||
|
}
|
||
|
for (JCLambda rhs : curr) {
|
||
|
if (!new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) {
|
||
|
throw new AssertionError(
|
||
|
String.format(
|
||
|
"expected lambdas to be equal\n%s\n%s", lhs, rhs));
|
||
|
}
|
||
|
if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym))
|
||
|
!= TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) {
|
||
|
throw new AssertionError(
|
||
|
String.format(
|
||
|
"expected lambdas to hash to the same value\n%s\n%s",
|
||
|
lhs, rhs));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
// Assert that no lambdas in a group are equal to any lambdas outside that group,
|
||
|
// or hash to the same value as lambda outside the group.
|
||
|
// (Note that the hash collisions won't result in correctness problems but could
|
||
|
// regress performs, and do not currently occurr for any of the test inputs.)
|
||
|
for (int j = 0; j < lambdaGroups.size(); j++) {
|
||
|
if (i == j) {
|
||
|
continue;
|
||
|
}
|
||
|
for (JCLambda lhs : curr) {
|
||
|
for (JCLambda rhs : lambdaGroups.get(j)) {
|
||
|
if (new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) {
|
||
|
throw new AssertionError(
|
||
|
String.format(
|
||
|
"expected lambdas to not be equal\n%s\n%s",
|
||
|
lhs, rhs));
|
||
|
}
|
||
|
if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym))
|
||
|
== TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) {
|
||
|
throw new AssertionError(
|
||
|
String.format(
|
||
|
"expected lambdas to hash to different values\n%s\n%s",
|
||
|
lhs, rhs));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
lambdaGroups.clear();
|
||
|
}
|
||
|
}
|
||
|
}
|