8342967: Lambda deduplication fails with non-metafactory BSMs and mismatched local variables names

Reviewed-by: mcimadamore
This commit is contained in:
Aggelos Biboudis 2024-11-04 12:27:12 +00:00
parent b41d713ff4
commit 895a7b64f0
6 changed files with 203 additions and 114 deletions

View File

@ -217,7 +217,7 @@ public class LambdaToMethod extends TreeTranslator {
public int hashCode() { public int hashCode() {
int hashCode = this.hashCode; int hashCode = this.hashCode;
if (hashCode == 0) { if (hashCode == 0) {
this.hashCode = hashCode = TreeHasher.hash(tree, symbol.params()); this.hashCode = hashCode = TreeHasher.hash(types, tree, symbol.params());
} }
return hashCode; return hashCode;
} }
@ -226,7 +226,7 @@ public class LambdaToMethod extends TreeTranslator {
public boolean equals(Object o) { public boolean equals(Object o) {
return (o instanceof DedupedLambda dedupedLambda) return (o instanceof DedupedLambda dedupedLambda)
&& types.isSameType(symbol.asType(), dedupedLambda.symbol.asType()) && types.isSameType(symbol.asType(), dedupedLambda.symbol.asType())
&& new TreeDiffer(symbol.params(), dedupedLambda.symbol.params()).scan(tree, dedupedLambda.tree); && new TreeDiffer(types, symbol.params(), dedupedLambda.symbol.params()).scan(tree, dedupedLambda.tree);
} }
} }

View File

@ -993,7 +993,7 @@ public class TransPatterns extends TreeTranslator {
!currentNullable && !currentNullable &&
!previousCompletesNormally && !previousCompletesNormally &&
!currentCompletesNormally && !currentCompletesNormally &&
new TreeDiffer(List.of(commonBinding), List.of(currentBinding)) new TreeDiffer(types, List.of(commonBinding), List.of(currentBinding))
.scan(commonNestedExpression, currentNestedExpression)) { .scan(commonNestedExpression, currentNestedExpression)) {
accummulator.add(c.head); accummulator.add(c.head);
} else { } else {

View File

@ -28,6 +28,9 @@ package com.sun.tools.javac.comp;
import com.sun.tools.javac.code.Flags; import com.sun.tools.javac.code.Flags;
import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.TypeTag;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.jvm.PoolConstant;
import com.sun.tools.javac.tree.JCTree; import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCAnnotatedType; import com.sun.tools.javac.tree.JCTree.JCAnnotatedType;
import com.sun.tools.javac.tree.JCTree.JCAnnotation; import com.sun.tools.javac.tree.JCTree.JCAnnotation;
@ -107,10 +110,10 @@ import java.util.Objects;
/** A visitor that compares two lambda bodies for structural equality. */ /** A visitor that compares two lambda bodies for structural equality. */
public class TreeDiffer extends TreeScanner { public class TreeDiffer extends TreeScanner {
public TreeDiffer(Types types,
public TreeDiffer(
Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) { Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
this.equiv = equiv(symbols, otherSymbols); this.equiv = equiv(symbols, otherSymbols);
this.types = types;
} }
private static Map<Symbol, Symbol> equiv( private static Map<Symbol, Symbol> equiv(
@ -127,6 +130,7 @@ public class TreeDiffer extends TreeScanner {
private JCTree parameter; private JCTree parameter;
private boolean result; private boolean result;
private Map<Symbol, Symbol> equiv = new HashMap<>(); private Map<Symbol, Symbol> equiv = new HashMap<>();
final Types types;
public boolean scan(JCTree tree, JCTree parameter) { public boolean scan(JCTree tree, JCTree parameter) {
if (tree == null || parameter == null) { if (tree == null || parameter == null) {
@ -197,13 +201,24 @@ public class TreeDiffer extends TreeScanner {
return; return;
} }
} }
result = tree.sym == that.sym; result = scanSymbol(symbol, otherSymbol);
}
private boolean scanSymbol(Symbol symbol, Symbol otherSymbol) {
if (symbol instanceof PoolConstant.Dynamic dms && otherSymbol instanceof PoolConstant.Dynamic other_dms) {
return dms.bsmKey(types).equals(other_dms.bsmKey(types));
}
else {
return symbol == otherSymbol;
}
} }
@Override @Override
public void visitSelect(JCFieldAccess tree) { public void visitSelect(JCFieldAccess tree) {
JCFieldAccess that = (JCFieldAccess) parameter; JCFieldAccess that = (JCFieldAccess) parameter;
result = scan(tree.selected, that.selected) && tree.sym == that.sym;
result = scan(tree.selected, that.selected) &&
scanSymbol(tree.sym, that.sym);
} }
@Override @Override
@ -328,14 +343,7 @@ public class TreeDiffer extends TreeScanner {
@Override @Override
public void visitClassDef(JCClassDecl tree) { public void visitClassDef(JCClassDecl tree) {
JCClassDecl that = (JCClassDecl) parameter; result = false;
result =
scan(tree.mods, that.mods)
&& tree.name == that.name
&& scan(tree.typarams, that.typarams)
&& scan(tree.extending, that.extending)
&& scan(tree.implementing, that.implementing)
&& scan(tree.defs, that.defs);
} }
@Override @Override
@ -667,15 +675,19 @@ public class TreeDiffer extends TreeScanner {
JCVariableDecl that = (JCVariableDecl) parameter; JCVariableDecl that = (JCVariableDecl) parameter;
result = result =
scan(tree.mods, that.mods) scan(tree.mods, that.mods)
&& tree.name == that.name
&& scan(tree.nameexpr, that.nameexpr) && scan(tree.nameexpr, that.nameexpr)
&& scan(tree.vartype, that.vartype) && scan(tree.vartype, that.vartype)
&& scan(tree.init, that.init); && scan(tree.init, that.init);
if (!result) {
return; if (tree.sym.owner.type.hasTag(TypeTag.CLASS)) {
// field names are important!
result &= tree.name == that.name;
} }
if (result) {
equiv.put(tree.sym, that.sym); equiv.put(tree.sym, that.sym);
} }
}
@Override @Override
public void visitWhileLoop(JCWhileLoop tree) { public void visitWhileLoop(JCWhileLoop tree) {

View File

@ -27,7 +27,10 @@
package com.sun.tools.javac.comp; package com.sun.tools.javac.comp;
import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.jvm.PoolConstant;
import com.sun.tools.javac.tree.JCTree; import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCClassDecl;
import com.sun.tools.javac.tree.JCTree.JCFieldAccess; import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
import com.sun.tools.javac.tree.JCTree.JCIdent; import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCLiteral; import com.sun.tools.javac.tree.JCTree.JCLiteral;
@ -43,19 +46,21 @@ import java.util.Objects;
public class TreeHasher extends TreeScanner { public class TreeHasher extends TreeScanner {
private final Map<Symbol, Integer> symbolHashes; private final Map<Symbol, Integer> symbolHashes;
private final Types types;
private int result = 17; private int result = 17;
public TreeHasher(Map<Symbol, Integer> symbolHashes) { public TreeHasher(Types types, Map<Symbol, Integer> symbolHashes) {
this.symbolHashes = Objects.requireNonNull(symbolHashes); this.symbolHashes = Objects.requireNonNull(symbolHashes);
this.types = types;
} }
public static int hash(JCTree tree, Collection<? extends Symbol> symbols) { public static int hash(Types types, JCTree tree, Collection<? extends Symbol> symbols) {
if (tree == null) { if (tree == null) {
return 0; return 0;
} }
Map<Symbol, Integer> symbolHashes = new HashMap<>(); Map<Symbol, Integer> symbolHashes = new HashMap<>();
symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size())); symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size()));
TreeHasher hasher = new TreeHasher(symbolHashes); TreeHasher hasher = new TreeHasher(types, symbolHashes);
tree.accept(hasher); tree.accept(hasher);
return hasher.result; return hasher.result;
} }
@ -87,6 +92,11 @@ public class TreeHasher extends TreeScanner {
super.visitLiteral(tree); super.visitLiteral(tree);
} }
@Override
public void visitClassDef(JCClassDecl tree) {
hash(tree.sym);
}
@Override @Override
public void visitIdent(JCIdent tree) { public void visitIdent(JCIdent tree) {
Symbol sym = tree.sym; Symbol sym = tree.sym;
@ -97,15 +107,23 @@ public class TreeHasher extends TreeScanner {
return; return;
} }
} }
hash(sym); hashSymbol(sym);
} }
@Override @Override
public void visitSelect(JCFieldAccess tree) { public void visitSelect(JCFieldAccess tree) {
hash(tree.sym); hashSymbol(tree.sym);
super.visitSelect(tree); super.visitSelect(tree);
} }
private void hashSymbol(Symbol sym) {
if (sym instanceof PoolConstant.Dynamic dynamic) {
hash(dynamic.bsmKey(types));
} else {
hash(sym);
}
}
@Override @Override
public void visitVarDef(JCVariableDecl tree) { public void visitVarDef(JCVariableDecl tree) {
symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size()); symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size());

View File

@ -29,52 +29,54 @@ import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
public class Deduplication { public class Deduplication {
void groupEquals(Object... xs) {}
void groupNotEquals(Object... xs) {}
void group(Object... xs) {} void group(Object... xs) {}
void test() { void test() {
group( groupEquals(
(Runnable) () -> { ( (Runnable) () -> {} ).run(); }, (Runnable) () -> { ( (Runnable) () -> {} ).run(); },
(Runnable) () -> { ( (Runnable) () -> {} ).run(); } (Runnable) () -> { ( (Runnable) () -> {} ).run(); }
); );
group( groupEquals(
(Runnable) () -> { Deduplication.class.toString(); }, (Runnable) () -> { Deduplication.class.toString(); },
(Runnable) () -> { Deduplication.class.toString(); } (Runnable) () -> { Deduplication.class.toString(); }
); );
group( groupEquals(
(Runnable) () -> { Integer[].class.toString(); }, (Runnable) () -> { Integer[].class.toString(); },
(Runnable) () -> { Integer[].class.toString(); } (Runnable) () -> { Integer[].class.toString(); }
); );
group( groupEquals(
(Runnable) () -> { char.class.toString(); }, (Runnable) () -> { char.class.toString(); },
(Runnable) () -> { char.class.toString(); } (Runnable) () -> { char.class.toString(); }
); );
group( groupEquals(
(Runnable) () -> { Void.class.toString(); }, (Runnable) () -> { Void.class.toString(); },
(Runnable) () -> { Void.class.toString(); } (Runnable) () -> { Void.class.toString(); }
); );
group( groupEquals(
(Runnable) () -> { void.class.toString(); }, (Runnable) () -> { void.class.toString(); },
(Runnable) () -> { void.class.toString(); } (Runnable) () -> { void.class.toString(); }
); );
group((Function<String, Integer>) x -> x.hashCode()); groupEquals((Function<String, Integer>) x -> x.hashCode());
group((Function<Object, Integer>) x -> x.hashCode()); groupEquals((Function<Object, Integer>) x -> x.hashCode());
{ {
int x = 1; int x = 1;
group((Supplier<Integer>) () -> x + 1); groupEquals((Supplier<Integer>) () -> x + 1);
} }
{ {
int x = 1; int x = 1;
group((Supplier<Integer>) () -> x + 1); groupEquals((Supplier<Integer>) () -> x + 1);
} }
group( groupEquals(
(BiFunction<Integer, Integer, ?>) (x, y) -> x + ((y)), (BiFunction<Integer, Integer, ?>) (x, y) -> x + ((y)),
(BiFunction<Integer, Integer, ?>) (x, y) -> x + (y), (BiFunction<Integer, Integer, ?>) (x, y) -> x + (y),
(BiFunction<Integer, Integer, ?>) (x, y) -> x + y, (BiFunction<Integer, Integer, ?>) (x, y) -> x + y,
@ -85,29 +87,29 @@ public class Deduplication {
(BiFunction<Integer, Integer, ?>) (x, y) -> ((x)) + (y), (BiFunction<Integer, Integer, ?>) (x, y) -> ((x)) + (y),
(BiFunction<Integer, Integer, ?>) (x, y) -> ((x)) + y); (BiFunction<Integer, Integer, ?>) (x, y) -> ((x)) + y);
group( groupEquals(
(Function<Integer, Integer>) x -> x + (1 + 2 + 3), (Function<Integer, Integer>) x -> x + (1 + 2 + 3),
(Function<Integer, Integer>) x -> x + 6); (Function<Integer, Integer>) x -> x + 6);
group((Function<Integer, Integer>) x -> x + 1, (Function<Integer, Integer>) y -> y + 1); groupEquals((Function<Integer, Integer>) x -> x + 1, (Function<Integer, Integer>) y -> y + 1);
group((Consumer<Integer>) x -> this.f(), (Consumer<Integer>) x -> this.f()); groupEquals((Consumer<Integer>) x -> this.f(), (Consumer<Integer>) x -> this.f());
group((Consumer<Integer>) y -> this.g()); groupEquals((Consumer<Integer>) y -> this.g());
group((Consumer<Integer>) x -> f(), (Consumer<Integer>) x -> f()); groupEquals((Consumer<Integer>) x -> f(), (Consumer<Integer>) x -> f());
group((Consumer<Integer>) y -> g()); groupEquals((Consumer<Integer>) y -> g());
group((Function<Integer, Integer>) x -> this.i, (Function<Integer, Integer>) x -> this.i); groupEquals((Function<Integer, Integer>) x -> this.i, (Function<Integer, Integer>) x -> this.i);
group((Function<Integer, Integer>) y -> this.j); groupEquals((Function<Integer, Integer>) y -> this.j);
group((Function<Integer, Integer>) x -> i, (Function<Integer, Integer>) x -> i); groupEquals((Function<Integer, Integer>) x -> i, (Function<Integer, Integer>) x -> i);
group((Function<Integer, Integer>) y -> j); groupEquals((Function<Integer, Integer>) y -> j);
group( groupEquals(
(Function<Integer, Integer>) (Function<Integer, Integer>)
y -> { y -> {
while (true) { while (true) {
@ -123,7 +125,7 @@ public class Deduplication {
return 42; return 42;
}); });
group( groupEquals(
(Function<Integer, Integer>) (Function<Integer, Integer>)
x -> { x -> {
int y = x; int y = x;
@ -135,13 +137,13 @@ public class Deduplication {
return y; return y;
}); });
group( groupEquals(
(Function<Integer, Integer>) (Function<Integer, Integer>)
x -> { x -> {
int y = 0, z = x; int y = 0, z = x;
return y; return y;
}); });
group( groupEquals(
(Function<Integer, Integer>) (Function<Integer, Integer>)
x -> { x -> {
int y = 0, z = x; int y = 0, z = x;
@ -154,24 +156,41 @@ public class Deduplication {
void f() {} void f() {}
{ {
group((Function<Integer, Integer>) x -> this.i); groupEquals((Function<Integer, Integer>) x -> this.i);
group((Consumer<Integer>) x -> this.f()); groupEquals((Consumer<Integer>) x -> this.f());
group((Function<Integer, Integer>) x -> Deduplication.this.i); groupEquals((Function<Integer, Integer>) x -> Deduplication.this.i);
group((Consumer<Integer>) x -> Deduplication.this.f()); groupEquals((Consumer<Integer>) x -> Deduplication.this.f());
} }
} }
group((Function<Integer, Integer>) x -> switch (x) { default: yield x; }, groupEquals((Function<Integer, Integer>) x -> switch (x) { default: yield x; },
(Function<Integer, Integer>) x -> switch (x) { default: yield x; }); (Function<Integer, Integer>) x -> switch (x) { default: yield x; });
group((Function<Object, Integer>) x -> x instanceof Integer i ? i : -1, groupEquals((Function<Object, Integer>) x -> x instanceof Integer i ? i : -1,
(Function<Object, Integer>) x -> x instanceof Integer i ? i : -1); (Function<Object, Integer>) x -> x instanceof Integer i ? i : -1);
group((Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1, groupEquals((Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1,
(Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1 ); (Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1 );
group((Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1, groupEquals((Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1,
(Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1 ); (Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1 );
groupEquals((Function<Object, Integer>) x -> x instanceof int i2 ? i2 : -1,
(Function<Object, Integer>) x -> x instanceof int i2 ? i2 : -1);
groupEquals((Function<Object, Integer>) x -> switch (x) { case String s -> s.length(); default -> -1; },
(Function<Object, Integer>) x -> switch (x) { case String s -> s.length(); default -> -1; });
groupEquals((Function<Object, Integer>) x -> {
int y1 = -1;
return y1;
},
(Function<Object, Integer>) x -> {
int y2 = -1;
return y2;
});
groupNotEquals((Function<Object, Integer>) x -> {class C {} new C(); return 42; }, (Function<Object, Integer>) x -> {class C {} new C(); return 42; });
} }
void f() {} void f() {}

View File

@ -48,9 +48,11 @@ import java.lang.classfile.*;
import java.lang.classfile.attribute.BootstrapMethodsAttribute; import java.lang.classfile.attribute.BootstrapMethodsAttribute;
import java.lang.classfile.constantpool.MethodHandleEntry; import java.lang.classfile.constantpool.MethodHandleEntry;
import com.sun.tools.javac.api.ClientCodeWrapper.Trusted; import com.sun.tools.javac.api.ClientCodeWrapper.Trusted;
import com.sun.tools.javac.api.JavacTaskImpl;
import com.sun.tools.javac.api.JavacTool; import com.sun.tools.javac.api.JavacTool;
import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symbol.MethodSymbol; import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.comp.TreeDiffer; import com.sun.tools.javac.comp.TreeDiffer;
import com.sun.tools.javac.comp.TreeHasher; import com.sun.tools.javac.comp.TreeHasher;
import com.sun.tools.javac.file.JavacFileManager; import com.sun.tools.javac.file.JavacFileManager;
@ -64,6 +66,8 @@ import com.sun.tools.javac.tree.JCTree.Tag;
import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context; import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.JCDiagnostic; import com.sun.tools.javac.util.JCDiagnostic;
import jdk.internal.classfile.impl.BootstrapMethodEntryImpl;
import java.io.InputStream; import java.io.InputStream;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
@ -103,8 +107,11 @@ public class DeduplicationTest {
"-source", System.getProperty("java.specification.version")), "-source", System.getProperty("java.specification.version")),
null, null,
fileManager.getJavaFileObjects(file)); fileManager.getJavaFileObjects(file));
Context context = ((JavacTaskImpl)task).getContext();
Types types = Types.instance(context);
Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>(); Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>();
task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas)); task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas, types));
Iterable<? extends JavaFileObject> generated = task.generate(); Iterable<? extends JavaFileObject> generated = task.generate();
if (!diagnosticListener.unexpected.isEmpty()) { if (!diagnosticListener.unexpected.isEmpty()) {
throw new AssertionError( throw new AssertionError(
@ -142,17 +149,21 @@ public class DeduplicationTest {
try (InputStream input = output.openInputStream()) { try (InputStream input = output.openInputStream()) {
cm = ClassFile.of().parse(input.readAllBytes()); cm = ClassFile.of().parse(input.readAllBytes());
} }
if (cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$R")) { if (cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$R") ||
cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$1C") ||
cm.thisClass().asInternalName().equals("com/sun/tools/javac/comp/Deduplication$2C")) {
continue; continue;
} }
BootstrapMethodsAttribute bsm = cm.findAttribute(Attributes.bootstrapMethods()).orElseThrow(); BootstrapMethodsAttribute bsm = cm.findAttribute(Attributes.bootstrapMethods()).orElseThrow();
for (BootstrapMethodEntry b : bsm.bootstrapMethods()) { for (BootstrapMethodEntry b : bsm.bootstrapMethods()) {
if (((BootstrapMethodEntryImpl) b).bootstrapMethod().asSymbol().methodName().equals("metafactory")) {
bootstrapMethodNames.add( bootstrapMethodNames.add(
((MethodHandleEntry) b.arguments().get(1)) ((MethodHandleEntry) b.arguments().get(1))
.reference() .reference()
.name().stringValue()); .name().stringValue());
} }
} }
}
Set<String> deduplicatedNames = Set<String> deduplicatedNames =
diagnosticListener diagnosticListener
.expectedLambdaMethods() .expectedLambdaMethods()
@ -249,9 +260,11 @@ public class DeduplicationTest {
* deduplicated to. * deduplicated to.
*/ */
private final Map<JCLambda, JCLambda> dedupedLambdas; private final Map<JCLambda, JCLambda> dedupedLambdas;
private final Types types;
public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas) { public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas, Types types) {
this.dedupedLambdas = dedupedLambdas; this.dedupedLambdas = dedupedLambdas;
this.types = types;
} }
@Override @Override
@ -262,31 +275,26 @@ public class DeduplicationTest {
// Scan the compilation for calls to a varargs method named 'group', whose arguments // Scan the compilation for calls to a varargs method named 'group', whose arguments
// are a group of lambdas that are equivalent to each other, but distinct from all // are a group of lambdas that are equivalent to each other, but distinct from all
// lambdas in the compilation unit outside of that group. // lambdas in the compilation unit outside of that group.
List<List<JCLambda>> lambdaGroups = new ArrayList<>(); List<List<JCLambda>> lambdaEqualsGroups = new ArrayList<>();
List<List<JCLambda>> lambdaNotEqualsGroups = new ArrayList<>();
new TreeScanner() { new TreeScanner() {
@Override @Override
public void visitApply(JCMethodInvocation tree) { public void visitApply(JCMethodInvocation tree) {
if (tree.getMethodSelect().getTag() == Tag.IDENT if (isMethodWithName(tree, "groupEquals")) {
&& ((JCIdent) tree.getMethodSelect()) addToGroup(tree, lambdaEqualsGroups);
.getName() } else if (isMethodWithName(tree, "groupNotEquals")) {
.contentEquals("group")) { addToGroup(tree, lambdaNotEqualsGroups);
List<JCLambda> xs = new ArrayList<>();
for (JCExpression arg : tree.getArguments()) {
if (arg instanceof JCTypeCast) {
arg = ((JCTypeCast) arg).getExpression();
}
xs.add((JCLambda) arg);
}
lambdaGroups.add(xs);
} }
super.visitApply(tree); super.visitApply(tree);
} }
}.scan((JCCompilationUnit) e.getCompilationUnit()); }.scan((JCCompilationUnit) e.getCompilationUnit());
for (int i = 0; i < lambdaGroups.size(); i++) {
List<JCLambda> curr = lambdaGroups.get(i); for (int i = 0; i < lambdaEqualsGroups.size(); i++) {
JCLambda first = null; List<JCLambda> curr = lambdaEqualsGroups.get(i);
// Assert that all pairwise combinations of lambdas in the group are equal, and // Assert that all pairwise combinations of lambdas in the group are equal, and
// hash to the same value. // hash to the same value.
JCLambda first = null;
for (JCLambda lhs : curr) { for (JCLambda lhs : curr) {
if (first == null) { if (first == null) {
first = lhs; first = lhs;
@ -294,14 +302,15 @@ public class DeduplicationTest {
dedupedLambdas.put(lhs, first); dedupedLambdas.put(lhs, first);
} }
for (JCLambda rhs : curr) { for (JCLambda rhs : curr) {
if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs)) if (rhs != lhs) {
if (!new TreeDiffer(types, paramSymbols(lhs), paramSymbols(rhs))
.scan(lhs.body, rhs.body)) { .scan(lhs.body, rhs.body)) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to be equal\n%s\n%s", lhs, rhs)); "expected lambdas to be equal\n%s\n%s", lhs, rhs));
} }
if (TreeHasher.hash(lhs, paramSymbols(lhs)) if (TreeHasher.hash(types, lhs, paramSymbols(lhs))
!= TreeHasher.hash(rhs, paramSymbols(rhs))) { != TreeHasher.hash(types, rhs, paramSymbols(rhs))) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to hash to the same value\n%s\n%s", "expected lambdas to hash to the same value\n%s\n%s",
@ -309,25 +318,41 @@ public class DeduplicationTest {
} }
} }
} }
}
// Assert that no lambdas in a group are equal to any lambdas outside that group, // Assert that no lambdas in a group are equal to any lambdas outside that group,
// or hash to the same value as lambda outside the group. // or hash to the same value as lambda outside the group.
// (Note that the hash collisions won't result in correctness problems but could // (Note that the hash collisions won't result in correctness problems but could
// regress performs, and do not currently occurr for any of the test inputs.) // regress performs, and do not currently occurr for any of the test inputs.)
for (int j = 0; j < lambdaGroups.size(); j++) { assertNotEqualsWithinGroup(lambdaEqualsGroups, i, curr, types);
}
lambdaEqualsGroups.clear();
// Assert that no lambdas in a not-equals group are equal to any lambdas inside that group,
// or hash to the same value as lambda inside the group.
for (int i = 0; i < lambdaNotEqualsGroups.size(); i++) {
List<JCLambda> curr = lambdaNotEqualsGroups.get(i);
assertNotEqualsWithinGroup(lambdaNotEqualsGroups, i, curr, types);
}
lambdaNotEqualsGroups.clear();
}
private void assertNotEqualsWithinGroup(List<List<JCLambda>> lambdaNotEqualsGroups, int i, List<JCLambda> curr, Types types) {
for (int j = 0; j < lambdaNotEqualsGroups.size(); j++) {
if (i == j) { if (i == j) {
continue; continue;
} }
for (JCLambda lhs : curr) { for (JCLambda lhs : curr) {
for (JCLambda rhs : lambdaGroups.get(j)) { for (JCLambda rhs : lambdaNotEqualsGroups.get(j)) {
if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs)) if (new TreeDiffer(types, paramSymbols(lhs), paramSymbols(rhs))
.scan(lhs.body, rhs.body)) { .scan(lhs.body, rhs.body)) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to not be equal\n%s\n%s", "expected lambdas to not be equal\n%s\n%s",
lhs, rhs)); lhs, rhs));
} }
if (TreeHasher.hash(lhs, paramSymbols(lhs)) if (TreeHasher.hash(types, lhs, paramSymbols(lhs))
== TreeHasher.hash(rhs, paramSymbols(rhs))) { == TreeHasher.hash(types, rhs, paramSymbols(rhs))) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to hash to different values\n%s\n%s", "expected lambdas to hash to different values\n%s\n%s",
@ -337,7 +362,22 @@ public class DeduplicationTest {
} }
} }
} }
lambdaGroups.clear();
private boolean isMethodWithName(JCMethodInvocation tree, String markerMethodName) {
return tree.getMethodSelect().getTag() == Tag.IDENT && ((JCIdent) tree.getMethodSelect())
.getName()
.contentEquals(markerMethodName);
}
private void addToGroup(JCMethodInvocation tree, List<List<JCLambda>> groupToAdd) {
List<JCLambda> xs = new ArrayList<>();
for (JCExpression arg : tree.getArguments()) {
if (arg instanceof JCTypeCast) {
arg = ((JCTypeCast) arg).getExpression();
}
xs.add((JCLambda) arg);
}
groupToAdd.add(xs);
} }
} }
} }