From 895a7b64f01dec7248549b127875edcf006457cf Mon Sep 17 00:00:00 2001 From: Aggelos Biboudis Date: Mon, 4 Nov 2024 12:27:12 +0000 Subject: [PATCH] 8342967: Lambda deduplication fails with non-metafactory BSMs and mismatched local variables names Reviewed-by: mcimadamore --- .../sun/tools/javac/comp/LambdaToMethod.java | 4 +- .../sun/tools/javac/comp/TransPatterns.java | 2 +- .../com/sun/tools/javac/comp/TreeDiffer.java | 46 ++++-- .../com/sun/tools/javac/comp/TreeHasher.java | 28 +++- .../lambda/deduplication/Deduplication.java | 85 ++++++---- .../deduplication/DeduplicationTest.java | 152 +++++++++++------- 6 files changed, 203 insertions(+), 114 deletions(-) diff --git a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java index adfc3ceaa0d..8772e70dda3 100644 --- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java +++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java @@ -217,7 +217,7 @@ public class LambdaToMethod extends TreeTranslator { public int hashCode() { int hashCode = this.hashCode; if (hashCode == 0) { - this.hashCode = hashCode = TreeHasher.hash(tree, symbol.params()); + this.hashCode = hashCode = TreeHasher.hash(types, tree, symbol.params()); } return hashCode; } @@ -226,7 +226,7 @@ public class LambdaToMethod extends TreeTranslator { public boolean equals(Object o) { return (o instanceof DedupedLambda dedupedLambda) && types.isSameType(symbol.asType(), dedupedLambda.symbol.asType()) - && new TreeDiffer(symbol.params(), dedupedLambda.symbol.params()).scan(tree, dedupedLambda.tree); + && new TreeDiffer(types, symbol.params(), dedupedLambda.symbol.params()).scan(tree, dedupedLambda.tree); } } diff --git a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java index f1f4e73b5a6..0e3a2c0b1db 100644 --- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java +++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java @@ -993,7 +993,7 @@ public class TransPatterns extends TreeTranslator { !currentNullable && !previousCompletesNormally && !currentCompletesNormally && - new TreeDiffer(List.of(commonBinding), List.of(currentBinding)) + new TreeDiffer(types, List.of(commonBinding), List.of(currentBinding)) .scan(commonNestedExpression, currentNestedExpression)) { accummulator.add(c.head); } else { diff --git a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java index df14b1859e3..bbc12f1fe80 100644 --- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java +++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java @@ -28,6 +28,9 @@ package com.sun.tools.javac.comp; import com.sun.tools.javac.code.Flags; import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.TypeTag; +import com.sun.tools.javac.code.Types; +import com.sun.tools.javac.jvm.PoolConstant; import com.sun.tools.javac.tree.JCTree; import com.sun.tools.javac.tree.JCTree.JCAnnotatedType; import com.sun.tools.javac.tree.JCTree.JCAnnotation; @@ -107,10 +110,10 @@ import java.util.Objects; /** A visitor that compares two lambda bodies for structural equality. */ public class TreeDiffer extends TreeScanner { - - public TreeDiffer( - Collection symbols, Collection otherSymbols) { + public TreeDiffer(Types types, + Collection symbols, Collection otherSymbols) { this.equiv = equiv(symbols, otherSymbols); + this.types = types; } private static Map equiv( @@ -127,6 +130,7 @@ public class TreeDiffer extends TreeScanner { private JCTree parameter; private boolean result; private Map equiv = new HashMap<>(); + final Types types; public boolean scan(JCTree tree, JCTree parameter) { if (tree == null || parameter == null) { @@ -197,13 +201,24 @@ public class TreeDiffer extends TreeScanner { return; } } - result = tree.sym == that.sym; + result = scanSymbol(symbol, otherSymbol); + } + + private boolean scanSymbol(Symbol symbol, Symbol otherSymbol) { + if (symbol instanceof PoolConstant.Dynamic dms && otherSymbol instanceof PoolConstant.Dynamic other_dms) { + return dms.bsmKey(types).equals(other_dms.bsmKey(types)); + } + else { + return symbol == otherSymbol; + } } @Override public void visitSelect(JCFieldAccess tree) { JCFieldAccess that = (JCFieldAccess) parameter; - result = scan(tree.selected, that.selected) && tree.sym == that.sym; + + result = scan(tree.selected, that.selected) && + scanSymbol(tree.sym, that.sym); } @Override @@ -328,14 +343,7 @@ public class TreeDiffer extends TreeScanner { @Override public void visitClassDef(JCClassDecl tree) { - JCClassDecl that = (JCClassDecl) parameter; - result = - scan(tree.mods, that.mods) - && tree.name == that.name - && scan(tree.typarams, that.typarams) - && scan(tree.extending, that.extending) - && scan(tree.implementing, that.implementing) - && scan(tree.defs, that.defs); + result = false; } @Override @@ -667,14 +675,18 @@ public class TreeDiffer extends TreeScanner { JCVariableDecl that = (JCVariableDecl) parameter; result = scan(tree.mods, that.mods) - && tree.name == that.name && scan(tree.nameexpr, that.nameexpr) && scan(tree.vartype, that.vartype) && scan(tree.init, that.init); - if (!result) { - return; + + if (tree.sym.owner.type.hasTag(TypeTag.CLASS)) { + // field names are important! + result &= tree.name == that.name; + } + + if (result) { + equiv.put(tree.sym, that.sym); } - equiv.put(tree.sym, that.sym); } @Override diff --git a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java index a0eec4ca91d..1f7691c5183 100644 --- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java +++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java @@ -27,7 +27,10 @@ package com.sun.tools.javac.comp; import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Types; +import com.sun.tools.javac.jvm.PoolConstant; import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.JCTree.JCClassDecl; import com.sun.tools.javac.tree.JCTree.JCFieldAccess; import com.sun.tools.javac.tree.JCTree.JCIdent; import com.sun.tools.javac.tree.JCTree.JCLiteral; @@ -43,19 +46,21 @@ import java.util.Objects; public class TreeHasher extends TreeScanner { private final Map symbolHashes; + private final Types types; private int result = 17; - public TreeHasher(Map symbolHashes) { + public TreeHasher(Types types, Map symbolHashes) { this.symbolHashes = Objects.requireNonNull(symbolHashes); + this.types = types; } - public static int hash(JCTree tree, Collection symbols) { + public static int hash(Types types, JCTree tree, Collection symbols) { if (tree == null) { return 0; } Map symbolHashes = new HashMap<>(); symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size())); - TreeHasher hasher = new TreeHasher(symbolHashes); + TreeHasher hasher = new TreeHasher(types, symbolHashes); tree.accept(hasher); return hasher.result; } @@ -87,6 +92,11 @@ public class TreeHasher extends TreeScanner { super.visitLiteral(tree); } + @Override + public void visitClassDef(JCClassDecl tree) { + hash(tree.sym); + } + @Override public void visitIdent(JCIdent tree) { Symbol sym = tree.sym; @@ -97,15 +107,23 @@ public class TreeHasher extends TreeScanner { return; } } - hash(sym); + hashSymbol(sym); } @Override public void visitSelect(JCFieldAccess tree) { - hash(tree.sym); + hashSymbol(tree.sym); super.visitSelect(tree); } + private void hashSymbol(Symbol sym) { + if (sym instanceof PoolConstant.Dynamic dynamic) { + hash(dynamic.bsmKey(types)); + } else { + hash(sym); + } + } + @Override public void visitVarDef(JCVariableDecl tree) { symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size()); diff --git a/test/langtools/tools/javac/lambda/deduplication/Deduplication.java b/test/langtools/tools/javac/lambda/deduplication/Deduplication.java index 2b9a5e82b91..201cba66054 100644 --- a/test/langtools/tools/javac/lambda/deduplication/Deduplication.java +++ b/test/langtools/tools/javac/lambda/deduplication/Deduplication.java @@ -29,52 +29,54 @@ import java.util.function.Function; import java.util.function.Supplier; public class Deduplication { + void groupEquals(Object... xs) {} + void groupNotEquals(Object... xs) {} void group(Object... xs) {} void test() { - group( + groupEquals( (Runnable) () -> { ( (Runnable) () -> {} ).run(); }, (Runnable) () -> { ( (Runnable) () -> {} ).run(); } ); - group( + groupEquals( (Runnable) () -> { Deduplication.class.toString(); }, (Runnable) () -> { Deduplication.class.toString(); } ); - group( + groupEquals( (Runnable) () -> { Integer[].class.toString(); }, (Runnable) () -> { Integer[].class.toString(); } ); - group( + groupEquals( (Runnable) () -> { char.class.toString(); }, (Runnable) () -> { char.class.toString(); } ); - group( + groupEquals( (Runnable) () -> { Void.class.toString(); }, (Runnable) () -> { Void.class.toString(); } ); - group( + groupEquals( (Runnable) () -> { void.class.toString(); }, (Runnable) () -> { void.class.toString(); } ); - group((Function) x -> x.hashCode()); - group((Function) x -> x.hashCode()); + groupEquals((Function) x -> x.hashCode()); + groupEquals((Function) x -> x.hashCode()); { int x = 1; - group((Supplier) () -> x + 1); + groupEquals((Supplier) () -> x + 1); } { int x = 1; - group((Supplier) () -> x + 1); + groupEquals((Supplier) () -> x + 1); } - group( + groupEquals( (BiFunction) (x, y) -> x + ((y)), (BiFunction) (x, y) -> x + (y), (BiFunction) (x, y) -> x + y, @@ -85,29 +87,29 @@ public class Deduplication { (BiFunction) (x, y) -> ((x)) + (y), (BiFunction) (x, y) -> ((x)) + y); - group( + groupEquals( (Function) x -> x + (1 + 2 + 3), (Function) x -> x + 6); - group((Function) x -> x + 1, (Function) y -> y + 1); + groupEquals((Function) x -> x + 1, (Function) y -> y + 1); - group((Consumer) x -> this.f(), (Consumer) x -> this.f()); + groupEquals((Consumer) x -> this.f(), (Consumer) x -> this.f()); - group((Consumer) y -> this.g()); + groupEquals((Consumer) y -> this.g()); - group((Consumer) x -> f(), (Consumer) x -> f()); + groupEquals((Consumer) x -> f(), (Consumer) x -> f()); - group((Consumer) y -> g()); + groupEquals((Consumer) y -> g()); - group((Function) x -> this.i, (Function) x -> this.i); + groupEquals((Function) x -> this.i, (Function) x -> this.i); - group((Function) y -> this.j); + groupEquals((Function) y -> this.j); - group((Function) x -> i, (Function) x -> i); + groupEquals((Function) x -> i, (Function) x -> i); - group((Function) y -> j); + groupEquals((Function) y -> j); - group( + groupEquals( (Function) y -> { while (true) { @@ -123,7 +125,7 @@ public class Deduplication { return 42; }); - group( + groupEquals( (Function) x -> { int y = x; @@ -135,13 +137,13 @@ public class Deduplication { return y; }); - group( + groupEquals( (Function) x -> { int y = 0, z = x; return y; }); - group( + groupEquals( (Function) x -> { int y = 0, z = x; @@ -154,24 +156,41 @@ public class Deduplication { void f() {} { - group((Function) x -> this.i); - group((Consumer) x -> this.f()); - group((Function) x -> Deduplication.this.i); - group((Consumer) x -> Deduplication.this.f()); + groupEquals((Function) x -> this.i); + groupEquals((Consumer) x -> this.f()); + groupEquals((Function) x -> Deduplication.this.i); + groupEquals((Consumer) x -> Deduplication.this.f()); } } - group((Function) x -> switch (x) { default: yield x; }, + groupEquals((Function) x -> switch (x) { default: yield x; }, (Function) x -> switch (x) { default: yield x; }); - group((Function) x -> x instanceof Integer i ? i : -1, + groupEquals((Function) x -> x instanceof Integer i ? i : -1, (Function) x -> x instanceof Integer i ? i : -1); - group((Function) x -> x instanceof R(var i1, var i2) ? i1 : -1, + groupEquals((Function) x -> x instanceof R(var i1, var i2) ? i1 : -1, (Function) x -> x instanceof R(var i1, var i2) ? i1 : -1 ); - group((Function) x -> x instanceof R(Integer i1, int i2) ? i2 : -1, + groupEquals((Function) x -> x instanceof R(Integer i1, int i2) ? i2 : -1, (Function) x -> x instanceof R(Integer i1, int i2) ? i2 : -1 ); + + groupEquals((Function) x -> x instanceof int i2 ? i2 : -1, + (Function) x -> x instanceof int i2 ? i2 : -1); + + groupEquals((Function) x -> switch (x) { case String s -> s.length(); default -> -1; }, + (Function) x -> switch (x) { case String s -> s.length(); default -> -1; }); + + groupEquals((Function) x -> { + int y1 = -1; + return y1; + }, + (Function) x -> { + int y2 = -1; + return y2; + }); + + groupNotEquals((Function) x -> {class C {} new C(); return 42; }, (Function) x -> {class C {} new C(); return 42; }); } void f() {} diff --git a/test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java b/test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java index 8d948c55e7d..b9261632f61 100644 --- a/test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java +++ b/test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java @@ -48,9 +48,11 @@ import java.lang.classfile.*; import java.lang.classfile.attribute.BootstrapMethodsAttribute; import java.lang.classfile.constantpool.MethodHandleEntry; import com.sun.tools.javac.api.ClientCodeWrapper.Trusted; +import com.sun.tools.javac.api.JavacTaskImpl; import com.sun.tools.javac.api.JavacTool; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Symbol.MethodSymbol; +import com.sun.tools.javac.code.Types; import com.sun.tools.javac.comp.TreeDiffer; import com.sun.tools.javac.comp.TreeHasher; import com.sun.tools.javac.file.JavacFileManager; @@ -64,6 +66,8 @@ import com.sun.tools.javac.tree.JCTree.Tag; import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.util.Context; import com.sun.tools.javac.util.JCDiagnostic; +import jdk.internal.classfile.impl.BootstrapMethodEntryImpl; + import java.io.InputStream; import java.nio.file.Path; import java.nio.file.Paths; @@ -103,8 +107,11 @@ public class DeduplicationTest { "-source", System.getProperty("java.specification.version")), null, fileManager.getJavaFileObjects(file)); + + Context context = ((JavacTaskImpl)task).getContext(); + Types types = Types.instance(context); Map dedupedLambdas = new LinkedHashMap<>(); - task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas)); + task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas, types)); Iterable generated = task.generate(); if (!diagnosticListener.unexpected.isEmpty()) { throw new AssertionError( @@ -142,15 +149,19 @@ public class DeduplicationTest { try (InputStream input = output.openInputStream()) { cm = ClassFile.of().parse(input.readAllBytes()); } - if (cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$R")) { + if (cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$R") || + cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$1C") || + cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$2C")) { continue; } BootstrapMethodsAttribute bsm = cm.findAttribute(Attributes.bootstrapMethods()).orElseThrow(); for (BootstrapMethodEntry b : bsm.bootstrapMethods()) { - bootstrapMethodNames.add( - ((MethodHandleEntry)b.arguments().get(1)) - .reference() - .name().stringValue()); + if (((BootstrapMethodEntryImpl) b).bootstrapMethod().asSymbol().methodName().equals("metafactory")) { + bootstrapMethodNames.add( + ((MethodHandleEntry) b.arguments().get(1)) + .reference() + .name().stringValue()); + } } } Set deduplicatedNames = @@ -249,9 +260,11 @@ public class DeduplicationTest { * deduplicated to. */ private final Map dedupedLambdas; + private final Types types; - public TreeDiffHashTaskListener(Map dedupedLambdas) { + public TreeDiffHashTaskListener(Map dedupedLambdas, Types types) { this.dedupedLambdas = dedupedLambdas; + this.types = types; } @Override @@ -262,31 +275,26 @@ public class DeduplicationTest { // Scan the compilation for calls to a varargs method named 'group', whose arguments // are a group of lambdas that are equivalent to each other, but distinct from all // lambdas in the compilation unit outside of that group. - List> lambdaGroups = new ArrayList<>(); + List> lambdaEqualsGroups = new ArrayList<>(); + List> lambdaNotEqualsGroups = new ArrayList<>(); + new TreeScanner() { @Override public void visitApply(JCMethodInvocation tree) { - if (tree.getMethodSelect().getTag() == Tag.IDENT - && ((JCIdent) tree.getMethodSelect()) - .getName() - .contentEquals("group")) { - List xs = new ArrayList<>(); - for (JCExpression arg : tree.getArguments()) { - if (arg instanceof JCTypeCast) { - arg = ((JCTypeCast) arg).getExpression(); - } - xs.add((JCLambda) arg); - } - lambdaGroups.add(xs); + if (isMethodWithName(tree, "groupEquals")) { + addToGroup(tree, lambdaEqualsGroups); + } else if (isMethodWithName(tree, "groupNotEquals")) { + addToGroup(tree, lambdaNotEqualsGroups); } super.visitApply(tree); } }.scan((JCCompilationUnit) e.getCompilationUnit()); - for (int i = 0; i < lambdaGroups.size(); i++) { - List curr = lambdaGroups.get(i); - JCLambda first = null; + + for (int i = 0; i < lambdaEqualsGroups.size(); i++) { + List curr = lambdaEqualsGroups.get(i); // Assert that all pairwise combinations of lambdas in the group are equal, and // hash to the same value. + JCLambda first = null; for (JCLambda lhs : curr) { if (first == null) { first = lhs; @@ -294,18 +302,20 @@ public class DeduplicationTest { dedupedLambdas.put(lhs, first); } for (JCLambda rhs : curr) { - if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs)) - .scan(lhs.body, rhs.body)) { - throw new AssertionError( - String.format( - "expected lambdas to be equal\n%s\n%s", lhs, rhs)); - } - if (TreeHasher.hash(lhs, paramSymbols(lhs)) - != TreeHasher.hash(rhs, paramSymbols(rhs))) { - throw new AssertionError( - String.format( - "expected lambdas to hash to the same value\n%s\n%s", - lhs, rhs)); + if (rhs != lhs) { + if (!new TreeDiffer(types, paramSymbols(lhs), paramSymbols(rhs)) + .scan(lhs.body, rhs.body)) { + throw new AssertionError( + String.format( + "expected lambdas to be equal\n%s\n%s", lhs, rhs)); + } + if (TreeHasher.hash(types, lhs, paramSymbols(lhs)) + != TreeHasher.hash(types, rhs, paramSymbols(rhs))) { + throw new AssertionError( + String.format( + "expected lambdas to hash to the same value\n%s\n%s", + lhs, rhs)); + } } } } @@ -313,31 +323,61 @@ public class DeduplicationTest { // or hash to the same value as lambda outside the group. // (Note that the hash collisions won't result in correctness problems but could // regress performs, and do not currently occurr for any of the test inputs.) - for (int j = 0; j < lambdaGroups.size(); j++) { - if (i == j) { - continue; - } - for (JCLambda lhs : curr) { - for (JCLambda rhs : lambdaGroups.get(j)) { - if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs)) - .scan(lhs.body, rhs.body)) { - throw new AssertionError( - String.format( - "expected lambdas to not be equal\n%s\n%s", - lhs, rhs)); - } - if (TreeHasher.hash(lhs, paramSymbols(lhs)) - == TreeHasher.hash(rhs, paramSymbols(rhs))) { - throw new AssertionError( - String.format( - "expected lambdas to hash to different values\n%s\n%s", - lhs, rhs)); - } + assertNotEqualsWithinGroup(lambdaEqualsGroups, i, curr, types); + } + lambdaEqualsGroups.clear(); + + // Assert that no lambdas in a not-equals group are equal to any lambdas inside that group, + // or hash to the same value as lambda inside the group. + for (int i = 0; i < lambdaNotEqualsGroups.size(); i++) { + List curr = lambdaNotEqualsGroups.get(i); + + assertNotEqualsWithinGroup(lambdaNotEqualsGroups, i, curr, types); + } + lambdaNotEqualsGroups.clear(); + } + + private void assertNotEqualsWithinGroup(List> lambdaNotEqualsGroups, int i, List curr, Types types) { + for (int j = 0; j < lambdaNotEqualsGroups.size(); j++) { + if (i == j) { + continue; + } + for (JCLambda lhs : curr) { + for (JCLambda rhs : lambdaNotEqualsGroups.get(j)) { + if (new TreeDiffer(types, paramSymbols(lhs), paramSymbols(rhs)) + .scan(lhs.body, rhs.body)) { + throw new AssertionError( + String.format( + "expected lambdas to not be equal\n%s\n%s", + lhs, rhs)); + } + if (TreeHasher.hash(types, lhs, paramSymbols(lhs)) + == TreeHasher.hash(types, rhs, paramSymbols(rhs))) { + throw new AssertionError( + String.format( + "expected lambdas to hash to different values\n%s\n%s", + lhs, rhs)); } } } } - lambdaGroups.clear(); + } + + private boolean isMethodWithName(JCMethodInvocation tree, String markerMethodName) { + return tree.getMethodSelect().getTag() == Tag.IDENT && ((JCIdent) tree.getMethodSelect()) + .getName() + .contentEquals(markerMethodName); + } + + private void addToGroup(JCMethodInvocation tree, List> groupToAdd) { + List xs = new ArrayList<>(); + for (JCExpression arg : tree.getArguments()) { + if (arg instanceof JCTypeCast) { + arg = ((JCTypeCast) arg).getExpression(); + } + xs.add((JCLambda) arg); + } + groupToAdd.add(xs); } } }