/* * Copyright (c) 2018, Google LLC. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*/ /** * @test 8200301 * @summary deduplicate lambda methods with the same body, target type, and captured state * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api * jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp * jdk.compiler/com.sun.tools.javac.file jdk.compiler/com.sun.tools.javac.main * jdk.compiler/com.sun.tools.javac.tree jdk.compiler/com.sun.tools.javac.util * @run main DeduplicationTest */ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toSet; import com.sun.source.util.JavacTask; import com.sun.source.util.TaskEvent; import com.sun.source.util.TaskEvent.Kind; import com.sun.source.util.TaskListener; import com.sun.tools.classfile.Attribute; import com.sun.tools.classfile.BootstrapMethods_attribute; import com.sun.tools.classfile.BootstrapMethods_attribute.BootstrapMethodSpecifier; import com.sun.tools.classfile.ClassFile; import com.sun.tools.classfile.ConstantPool.CONSTANT_MethodHandle_info; import com.sun.tools.javac.api.ClientCodeWrapper.Trusted; import com.sun.tools.javac.api.JavacTool; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Symbol.MethodSymbol; import com.sun.tools.javac.comp.TreeDiffer; import com.sun.tools.javac.comp.TreeHasher; import com.sun.tools.javac.file.JavacFileManager; import com.sun.tools.javac.tree.JCTree.JCCompilationUnit; import com.sun.tools.javac.tree.JCTree.JCExpression; import com.sun.tools.javac.tree.JCTree.JCIdent; import com.sun.tools.javac.tree.JCTree.JCLambda; import com.sun.tools.javac.tree.JCTree.JCMethodInvocation; import com.sun.tools.javac.tree.JCTree.JCTypeCast; import com.sun.tools.javac.tree.JCTree.JCVariableDecl; import com.sun.tools.javac.tree.JCTree.Tag; import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.util.Context; import 
com.sun.tools.javac.util.JCDiagnostic;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.BiFunction;
import javax.tools.Diagnostic;
import javax.tools.DiagnosticListener;
import javax.tools.JavaFileObject;

/**
 * Compiles {@code Deduplication.java} from the {@code test.src} directory with the lambda
 * deduplication debug options enabled.
 *
 * <p>The test checks that:
 *
 * <ul>
 *   <li>every lambda grouped by a call to {@code group(...)} is deduplicated to the first lambda of
 *       its group;
 *   <li>the generated class files reference only the canonical (non-deduplicated) lambda
 *       implementation methods.
 * </ul>
 */
public class DeduplicationTest {

    public static void main(String[] args) throws Exception {
        JavacFileManager fileManager = new JavacFileManager(new Context(), false, UTF_8);
        JavacTool javacTool = JavacTool.create();
        Listener diagnosticListener = new Listener();
        Path testSrc = Paths.get(System.getProperty("test.src"));
        Path file = testSrc.resolve("Deduplication.java");
        JavacTask task =
                javacTool.getTask(
                        null,
                        null,
                        diagnosticListener,
                        Arrays.asList(
                                "-d",
                                ".",
                                "-XDdebug.dumpLambdaToMethodDeduplication",
                                "-XDdebug.dumpLambdaToMethodStats"),
                        null,
                        fileManager.getJavaFileObjects(file));
        Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>();
        task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas));
        Iterable<? extends JavaFileObject> generated = task.generate();
        if (!diagnosticListener.unexpected.isEmpty()) {
            throw new AssertionError(
                    diagnosticListener
                            .unexpected
                            .stream()
                            .map(
                                    d ->
                                            String.format(
                                                    "%s: %s",
                                                    d.getCode(),
                                                    d.getMessage(Locale.getDefault())))
                            .collect(joining(", ", "unexpected diagnostics: ", "")));
        }

        // Assert that each group of lambdas was deduplicated.
        Map<JCLambda, JCLambda> actual = diagnosticListener.deduplicationTargets();
        dedupedLambdas.forEach(
                (k, v) -> {
                    if (!actual.containsKey(k)) {
                        throw new AssertionError("expected " + k + " to be deduplicated");
                    }
                    if (!v.equals(actual.get(k))) {
                        // Report the target k was actually deduplicated to (previously this
                        // looked up v, which printed the wrong value).
                        throw new AssertionError(
                                String.format(
                                        "expected %s to be deduplicated to:\n %s\nwas: %s",
                                        k, v, actual.get(k)));
                    }
                });

        // Assert that the output contains only the canonical lambdas, and not the deduplicated
        // lambdas.
        Set<String> bootstrapMethodNames = new TreeSet<>();
        for (JavaFileObject output : generated) {
            ClassFile cf;
            try (java.io.InputStream in = output.openInputStream()) {
                cf = ClassFile.read(in);
            }
            BootstrapMethods_attribute bsm =
                    (BootstrapMethods_attribute) cf.getAttribute(Attribute.BootstrapMethods);
            if (bsm == null) {
                // A class with no BootstrapMethods attribute has no bootstrap entries to check.
                continue;
            }
            for (BootstrapMethodSpecifier b : bsm.bootstrap_method_specifiers) {
                // NOTE(review): assumes every bootstrap method in the output is a lambda
                // bootstrap whose second static argument is the implementation method handle.
                bootstrapMethodNames.add(
                        ((CONSTANT_MethodHandle_info) cf.constant_pool.get(b.bootstrap_arguments[1]))
                                .getCPRefInfo()
                                .getNameAndTypeInfo()
                                .getName());
            }
        }
        // A TreeSet keeps the failure message deterministic; Set.equals ignores ordering anyway.
        Set<String> expectedNames = new TreeSet<>();
        for (MethodSymbol sym : diagnosticListener.expectedLambdaMethods()) {
            expectedNames.add(sym.getSimpleName().toString());
        }
        if (!expectedNames.equals(bootstrapMethodNames)) {
            throw new AssertionError(
                    String.format(
                            "expected deduplicated methods: %s, but saw: %s",
                            expectedNames, bootstrapMethodNames));
        }
    }

    /**
     * Returns a symbol comparator that treats symbols that correspond to the same parameter of each
     * of the given lambdas as equal.
     *
     * <p>The comparator returns {@code true} when {@code x} is a parameter of {@code lhs} at the same
     * position as {@code y} is in {@code rhs}. Otherwise it returns {@code null}, presumably
     * meaning that {@link TreeDiffer} should fall back to its default symbol comparison (verify
     * against TreeDiffer).
     */
    private static BiFunction<Symbol, Symbol, Boolean> paramsEqual(JCLambda lhs, JCLambda rhs) {
        return (x, y) -> {
            Integer idx = paramIndex(lhs, x);
            if (idx != null && Objects.equals(idx, paramIndex(rhs, y))) {
                return true;
            }
            return null;
        };
    }

    /**
     * Returns the zero-based position of the given symbol among the parameters of the given lambda.
     * Returns {@code null} if the symbol is {@code null} or is not one of the lambda's parameters.
     */
    private static Integer paramIndex(JCLambda lambda, Symbol sym) {
        if (sym == null) {
            return null;
        }
        int idx = 0;
        for (JCVariableDecl param : lambda.params) {
            if (sym == param.sym) {
                return idx;
            }
            // Without this increment every parameter would be reported as index 0.
            idx++;
        }
        return null;
    }

    /** A diagnostic listener that records debug messages related to lambda desugaring. */
    @Trusted
    static class Listener implements DiagnosticListener<JavaFileObject> {

        /** A map from method symbols to lambda trees for desugared lambdas. */
        final Map<MethodSymbol, JCLambda> lambdaMethodSymbolsToTrees = new LinkedHashMap<>();

        /**
         * A map from lambda trees that were deduplicated to the method symbol of the canonical
         * lambda implementation method they were deduplicated to.
         */
        final Map<JCLambda, MethodSymbol> deduped = new LinkedHashMap<>();

        /** Every diagnostic whose code is not one of the two lambda debug notes handled below. */
        final List<Diagnostic<? extends JavaFileObject>> unexpected = new ArrayList<>();

        @Override
        public void report(Diagnostic<? extends JavaFileObject> diagnostic) {
            JCDiagnostic d = (JCDiagnostic) diagnostic;
            switch (d.getCode()) {
                case "compiler.note.lambda.stat":
                    // args[1] is read as the lambda's implementation method symbol.
                    lambdaMethodSymbolsToTrees.put(
                            (MethodSymbol) d.getArgs()[1],
                            (JCLambda) d.getDiagnosticPosition().getTree());
                    break;
                case "compiler.note.verbose.l2m.deduplicate":
                    // args[0] is read as the canonical method this lambda was folded into.
                    deduped.put(
                            (JCLambda) d.getDiagnosticPosition().getTree(),
                            (MethodSymbol) d.getArgs()[0]);
                    break;
                default:
                    unexpected.add(diagnostic);
            }
        }

        /**
         * Returns the implementation method symbols of the lambdas that were not deduplicated. These
         * are the methods expected to be referenced by the generated class files.
         */
        Set<MethodSymbol> expectedLambdaMethods() {
            return lambdaMethodSymbolsToTrees
                    .entrySet()
                    .stream()
                    .filter(e -> !deduped.containsKey(e.getValue()))
                    .map(Map.Entry::getKey)
                    .collect(toSet());
        }

        /**
         * Returns a mapping from deduplicated lambda trees to the tree of the canonical lambda they
         * were deduplicated to.
         */
        Map<JCLambda, JCLambda> deduplicationTargets() {
            return deduped.entrySet()
                    .stream()
                    .collect(
                            toMap(
                                    Map.Entry::getKey,
                                    e -> lambdaMethodSymbolsToTrees.get(e.getValue()),
                                    (a, b) -> {
                                        // Keys come from a Map, so duplicates cannot occur.
                                        throw new AssertionError();
                                    },
                                    LinkedHashMap::new));
        }
    }

    /**
     * A task listener that tests {@link TreeDiffer} and {@link TreeHasher} on all lambda trees in a
     * compilation, post-analysis.
     */
    private static class TreeDiffHashTaskListener implements TaskListener {

        /**
         * A map from deduplicated lambdas to the canonical lambda they are expected to be
         * deduplicated to.
         */
        private final Map<JCLambda, JCLambda> dedupedLambdas;

        public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas) {
            this.dedupedLambdas = dedupedLambdas;
        }

        @Override
        public void finished(TaskEvent e) {
            if (e.getKind() != Kind.ANALYZE) {
                return;
            }
            // Each call to a varargs method named 'group' lists a group of lambdas that are
            // equivalent to each other, but distinct from all lambdas in the compilation unit
            // outside of that group.
            List<List<JCLambda>> lambdaGroups =
                    findLambdaGroups((JCCompilationUnit) e.getCompilationUnit());
            for (int i = 0; i < lambdaGroups.size(); i++) {
                List<JCLambda> curr = lambdaGroups.get(i);
                JCLambda first = null;
                // Assert that all pairwise combinations of lambdas in the group are equal, and
                // hash to the same value. Every lambda after the first is expected to be
                // deduplicated to the first.
                for (JCLambda lhs : curr) {
                    if (first == null) {
                        first = lhs;
                    } else {
                        dedupedLambdas.put(lhs, first);
                    }
                    for (JCLambda rhs : curr) {
                        if (!bodiesEqual(lhs, rhs)) {
                            throw new AssertionError(
                                    String.format(
                                            "expected lambdas to be equal\n%s\n%s", lhs, rhs));
                        }
                        if (hash(lhs) != hash(rhs)) {
                            throw new AssertionError(
                                    String.format(
                                            "expected lambdas to hash to the same value\n%s\n%s",
                                            lhs, rhs));
                        }
                    }
                }
                // Assert that no lambda in a group is equal to, or hashes to the same value as,
                // any lambda outside that group. (Hash collisions would not cause correctness
                // problems, but they could regress performance, and they do not currently occur
                // for any of the test inputs.)
                for (int j = 0; j < lambdaGroups.size(); j++) {
                    if (i == j) {
                        continue;
                    }
                    for (JCLambda lhs : curr) {
                        for (JCLambda rhs : lambdaGroups.get(j)) {
                            if (bodiesEqual(lhs, rhs)) {
                                throw new AssertionError(
                                        String.format(
                                                "expected lambdas to not be equal\n%s\n%s",
                                                lhs, rhs));
                            }
                            if (hash(lhs) == hash(rhs)) {
                                throw new AssertionError(
                                        String.format(
                                                "expected lambdas to hash to different values\n%s\n%s",
                                                lhs, rhs));
                            }
                        }
                    }
                }
            }
        }

        /**
         * Returns the arguments of every call to a method named {@code group} in the given
         * compilation unit, one list per call, in scan order. A type cast around an argument is
         * stripped, and the remaining expression is cast to a lambda.
         */
        private static List<List<JCLambda>> findLambdaGroups(JCCompilationUnit unit) {
            List<List<JCLambda>> lambdaGroups = new ArrayList<>();
            new TreeScanner() {
                @Override
                public void visitApply(JCMethodInvocation tree) {
                    if (tree.getMethodSelect().getTag() == Tag.IDENT
                            && ((JCIdent) tree.getMethodSelect())
                                    .getName()
                                    .contentEquals("group")) {
                        List<JCLambda> xs = new ArrayList<>();
                        for (JCExpression arg : tree.getArguments()) {
                            if (arg instanceof JCTypeCast) {
                                arg = ((JCTypeCast) arg).getExpression();
                            }
                            xs.add((JCLambda) arg);
                        }
                        lambdaGroups.add(xs);
                    }
                    super.visitApply(tree);
                }
            }.scan(unit);
            return lambdaGroups;
        }

        /** Compares two lambda bodies with {@link TreeDiffer}, matching parameters by position. */
        private static boolean bodiesEqual(JCLambda lhs, JCLambda rhs) {
            return new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body);
        }

        /** Hashes a lambda with {@link TreeHasher}, keying its parameters by position. */
        private static int hash(JCLambda lambda) {
            return TreeHasher.hash(lambda, sym -> paramIndex(lambda, sym));
        }
    }
}