diff --git a/src/java.base/share/classes/java/util/stream/Collectors.java b/src/java.base/share/classes/java/util/stream/Collectors.java
index 26d98bf6d42..29832411c79 100644
--- a/src/java.base/share/classes/java/util/stream/Collectors.java
+++ b/src/java.base/share/classes/java/util/stream/Collectors.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2012, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -1884,6 +1884,102 @@ public final class Collectors {
(l, r) -> { l.combine(r); return l; }, CH_ID);
}
+ /**
+ * Returns a {@code Collector} that is a composite of two downstream collectors.
+ * Every element passed to the resulting collector is processed by both downstream
+ * collectors, then their results are merged using the specified merge function
+ * into the final result.
+ *
+ *
The resulting collector functions do the following:
+ *
+ *
+ * - supplier: creates a result container that contains result containers
+ * obtained by calling each collector's supplier
+ *
- accumulator: calls each collector's accumulator with its result container
+ * and the input element
+ *
- combiner: calls each collector's combiner with two result containers
+ *
- finisher: calls each collector's finisher with its result container,
+ * then calls the supplied merger and returns its result.
+ *
+ *
+ * The resulting collector is {@link Collector.Characteristics#UNORDERED} if both downstream
+ * collectors are unordered and {@link Collector.Characteristics#CONCURRENT} if both downstream
+ * collectors are concurrent.
+ *
+ * @param the type of the input elements
+ * @param the result type of the first collector
+ * @param the result type of the second collector
+ * @param the final result type
+ * @param downstream1 the first downstream collector
+ * @param downstream2 the second downstream collector
+ * @param merger the function which merges two results into the single one
+ * @return a {@code Collector} which aggregates the results of two supplied collectors.
+ * @since 12
+ */
+ public static
+ Collector teeing(Collector super T, ?, R1> downstream1,
+ Collector super T, ?, R2> downstream2,
+ BiFunction super R1, ? super R2, R> merger) {
+ return teeing0(downstream1, downstream2, merger);
+ }
+
+ private static
+ Collector teeing0(Collector super T, A1, R1> downstream1,
+ Collector super T, A2, R2> downstream2,
+ BiFunction super R1, ? super R2, R> merger) {
+ Objects.requireNonNull(downstream1, "downstream1");
+ Objects.requireNonNull(downstream2, "downstream2");
+ Objects.requireNonNull(merger, "merger");
+
+ Supplier c1Supplier = Objects.requireNonNull(downstream1.supplier(), "downstream1 supplier");
+ Supplier c2Supplier = Objects.requireNonNull(downstream2.supplier(), "downstream2 supplier");
+ BiConsumer c1Accumulator =
+ Objects.requireNonNull(downstream1.accumulator(), "downstream1 accumulator");
+ BiConsumer c2Accumulator =
+ Objects.requireNonNull(downstream2.accumulator(), "downstream2 accumulator");
+ BinaryOperator c1Combiner = Objects.requireNonNull(downstream1.combiner(), "downstream1 combiner");
+ BinaryOperator c2Combiner = Objects.requireNonNull(downstream2.combiner(), "downstream2 combiner");
+ Function c1Finisher = Objects.requireNonNull(downstream1.finisher(), "downstream1 finisher");
+ Function c2Finisher = Objects.requireNonNull(downstream2.finisher(), "downstream2 finisher");
+
+ Set characteristics;
+ Set c1Characteristics = downstream1.characteristics();
+ Set c2Characteristics = downstream2.characteristics();
+ if (CH_ID.containsAll(c1Characteristics) || CH_ID.containsAll(c2Characteristics)) {
+ characteristics = CH_NOID;
+ } else {
+ EnumSet c = EnumSet.noneOf(Collector.Characteristics.class);
+ c.addAll(c1Characteristics);
+ c.retainAll(c2Characteristics);
+ c.remove(Collector.Characteristics.IDENTITY_FINISH);
+ characteristics = Collections.unmodifiableSet(c);
+ }
+
+ class PairBox {
+ A1 left = c1Supplier.get();
+ A2 right = c2Supplier.get();
+
+ void add(T t) {
+ c1Accumulator.accept(left, t);
+ c2Accumulator.accept(right, t);
+ }
+
+ PairBox combine(PairBox other) {
+ left = c1Combiner.apply(left, other.left);
+ right = c2Combiner.apply(right, other.right);
+ return this;
+ }
+
+ R get() {
+ R1 r1 = c1Finisher.apply(left);
+ R2 r2 = c2Finisher.apply(right);
+ return merger.apply(r1, r2);
+ }
+ }
+
+ return new CollectorImpl<>(PairBox::new, PairBox::add, PairBox::combine, PairBox::get, characteristics);
+ }
+
/**
* Implementation class used by partitioningBy.
*/
diff --git a/test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java b/test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java
index d07b6eba4a7..4ce3916bcbf 100644
--- a/test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java
+++ b/test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2012, 2015, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -29,6 +29,7 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.IntSummaryStatistics;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -39,6 +40,7 @@ import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Predicate;
@@ -96,7 +98,7 @@ public class CollectorsTest extends OpTestCase {
@Override
void assertValue(R value, Supplier> source, boolean ordered) throws ReflectiveOperationException {
downstream.assertValue(value,
- () -> source.get().map(mapper::apply),
+ () -> source.get().map(mapper),
ordered);
}
}
@@ -114,7 +116,7 @@ public class CollectorsTest extends OpTestCase {
@Override
void assertValue(R value, Supplier> source, boolean ordered) throws ReflectiveOperationException {
downstream.assertValue(value,
- () -> source.get().flatMap(mapper::apply),
+ () -> source.get().flatMap(mapper),
ordered);
}
}
@@ -287,6 +289,27 @@ public class CollectorsTest extends OpTestCase {
}
}
+ static class TeeingAssertion extends CollectorAssertion {
+ private final Collector c1;
+ private final Collector c2;
+ private final BiFunction super R1, ? super R2, ? extends RR> finisher;
+
+ TeeingAssertion(Collector c1, Collector c2,
+ BiFunction super R1, ? super R2, ? extends RR> finisher) {
+ this.c1 = c1;
+ this.c2 = c2;
+ this.finisher = finisher;
+ }
+
+ @Override
+ void assertValue(RR value, Supplier> source, boolean ordered) {
+ R1 r1 = source.get().collect(c1);
+ R2 r2 = source.get().collect(c2);
+ RR expected = finisher.apply(r1, r2);
+ assertEquals(value, expected);
+ }
+ }
+
private ResultAsserter mapTabulationAsserter(boolean ordered) {
return (act, exp, ord, par) -> {
if (par && (!ordered || !ord)) {
@@ -746,4 +769,42 @@ public class CollectorsTest extends OpTestCase {
catch (UnsupportedOperationException ignored) { }
}
+ @Test(dataProvider = "StreamTestData", dataProviderClass = StreamTestDataProvider.class)
+ public void testTeeing(String name, TestData.OfRef data) throws ReflectiveOperationException {
+ Collector summing = Collectors.summingLong(Integer::valueOf);
+ Collector counting = Collectors.counting();
+ Collector min = collectingAndThen(Collectors.minBy(Comparator.naturalOrder()),
+ opt -> opt.orElse(Integer.MAX_VALUE));
+ Collector max = collectingAndThen(Collectors.maxBy(Comparator.naturalOrder()),
+ opt -> opt.orElse(Integer.MIN_VALUE));
+ Collector joining = mapping(String::valueOf, Collectors.joining(", ", "[", "]"));
+
+ Collector> sumAndCount = Collectors.teeing(summing, counting, Map::entry);
+ Collector> minAndMax = Collectors.teeing(min, max, Map::entry);
+ Collector averaging = Collectors.teeing(summing, counting,
+ (sum, count) -> ((double)sum) / count);
+ Collector summaryStatistics = Collectors.teeing(sumAndCount, minAndMax,
+ (sumCountEntry, minMaxEntry) -> new IntSummaryStatistics(
+ sumCountEntry.getValue(), minMaxEntry.getKey(),
+ minMaxEntry.getValue(), sumCountEntry.getKey()).toString());
+ Collector countAndContent = Collectors.teeing(counting, joining,
+ (count, content) -> count+": "+content);
+
+ assertCollect(data, sumAndCount, stream -> {
+ List list = stream.collect(toList());
+ return Map.entry(list.stream().mapToLong(Integer::intValue).sum(), (long) list.size());
+ });
+ assertCollect(data, averaging, stream -> stream.mapToInt(Integer::intValue).average().orElse(Double.NaN));
+ assertCollect(data, summaryStatistics,
+ stream -> stream.mapToInt(Integer::intValue).summaryStatistics().toString());
+ assertCollect(data, countAndContent, stream -> {
+ List list = stream.collect(toList());
+ return list.size()+": "+list;
+ });
+
+ Function classifier = i -> i % 3;
+ exerciseMapCollection(data, groupingBy(classifier, sumAndCount),
+ new GroupingByAssertion<>(classifier, Map.class,
+ new TeeingAssertion<>(summing, counting, Map::entry)));
+ }
}