From 98cb81b38120e1d06534289de25f73b3b181e161 Mon Sep 17 00:00:00 2001 From: Peter Levart Date: Wed, 21 Apr 2021 10:32:03 +0000 Subject: [PATCH] 8265237: String.join and StringJoiner can be improved further Reviewed-by: rriggs, redestad --- .../share/classes/java/lang/String.java | 72 +++++++++++--- .../share/classes/java/lang/System.java | 4 + .../share/classes/java/util/StringJoiner.java | 43 ++++---- .../jdk/internal/access/JavaLangAccess.java | 5 + .../StringJoinerOomUtf16Test.java | 97 +++++++++++++++++++ .../java/util/StringJoinerBenchmark.java | 16 ++- 6 files changed, 197 insertions(+), 40 deletions(-) create mode 100644 test/jdk/java/util/StringJoiner/StringJoinerOomUtf16Test.java diff --git a/src/java.base/share/classes/java/lang/String.java b/src/java.base/share/classes/java/lang/String.java index 8f81bb4a0da..e5def118c9a 100644 --- a/src/java.base/share/classes/java/lang/String.java +++ b/src/java.base/share/classes/java/lang/String.java @@ -43,7 +43,6 @@ import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.Spliterator; -import java.util.StringJoiner; import java.util.function.Function; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; @@ -51,6 +50,8 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; import java.util.stream.StreamSupport; + +import jdk.internal.vm.annotation.ForceInline; import jdk.internal.vm.annotation.IntrinsicCandidate; import jdk.internal.vm.annotation.Stable; import sun.nio.cs.ArrayDecoder; @@ -3218,14 +3219,58 @@ public final class String * @since 1.8 */ public static String join(CharSequence delimiter, CharSequence... elements) { - Objects.requireNonNull(delimiter); - Objects.requireNonNull(elements); - // Number of elements not likely worth Arrays.stream overhead. 
- StringJoiner joiner = new StringJoiner(delimiter); - for (CharSequence cs: elements) { - joiner.add(cs); + var delim = delimiter.toString(); + var elems = new String[elements.length]; + for (int i = 0; i < elements.length; i++) { + elems[i] = String.valueOf(elements[i]); } - return joiner.toString(); + return join("", "", delim, elems, elems.length); + } + + /** + * Designated join routine. + * + * @param prefix the non-null prefix + * @param suffix the non-null suffix + * @param delimiter the non-null delimiter + * @param elements the non-null array of non-null elements + * @param size the number of elements in the array (<= elements.length) + * @return the joined string + */ + @ForceInline + static String join(String prefix, String suffix, String delimiter, String[] elements, int size) { + int icoder = prefix.coder() | suffix.coder() | delimiter.coder(); + long len = (long) prefix.length() + suffix.length() + (long) Math.max(0, size - 1) * delimiter.length(); + // assert len > 0L; // max: (long) Integer.MAX_VALUE << 32 + // following loop will add max: (long) Integer.MAX_VALUE * Integer.MAX_VALUE to len + // so len can overflow at most once + for (int i = 0; i < size; i++) { + var el = elements[i]; + len += el.length(); + icoder |= el.coder(); + } + byte coder = (byte) icoder; + // long len overflow check, char -> byte length, int len overflow check + if (len < 0L || (len <<= coder) != (int) len) { + throw new OutOfMemoryError("Requested string length exceeds VM limit"); + } + byte[] value = StringConcatHelper.newArray(len); + + int off = 0; + prefix.getBytes(value, off, coder); off += prefix.length(); + if (size > 0) { + var el = elements[0]; + el.getBytes(value, off, coder); off += el.length(); + for (int i = 1; i < size; i++) { + delimiter.getBytes(value, off, coder); off += delimiter.length(); + el = elements[i]; + el.getBytes(value, off, coder); off += el.length(); + } + } + suffix.getBytes(value, off, coder); + // assert off + suffix.length() ==
value.length >> coder; + + return new String(value, coder); } /** @@ -3266,11 +3311,16 @@ public final class String Iterable elements) { Objects.requireNonNull(delimiter); Objects.requireNonNull(elements); - StringJoiner joiner = new StringJoiner(delimiter); + var delim = delimiter.toString(); + var elems = new String[8]; + int size = 0; for (CharSequence cs: elements) { - joiner.add(cs); + if (size >= elems.length) { + elems = Arrays.copyOf(elems, elems.length << 1); + } + elems[size++] = String.valueOf(cs); } - return joiner.toString(); + return join("", "", delim, elems, size); } /** diff --git a/src/java.base/share/classes/java/lang/System.java b/src/java.base/share/classes/java/lang/System.java index 706077f803a..524071ed671 100644 --- a/src/java.base/share/classes/java/lang/System.java +++ b/src/java.base/share/classes/java/lang/System.java @@ -2308,6 +2308,10 @@ public final class System { return StringConcatHelper.mix(lengthCoder, constant); } + public String join(String prefix, String suffix, String delimiter, String[] elements, int size) { + return String.join(prefix, suffix, delimiter, elements, size); + } + public Object classData(Class c) { return c.getClassData(); } diff --git a/src/java.base/share/classes/java/util/StringJoiner.java b/src/java.base/share/classes/java/util/StringJoiner.java index 86ba41df14a..f8127d9b707 100644 --- a/src/java.base/share/classes/java/util/StringJoiner.java +++ b/src/java.base/share/classes/java/util/StringJoiner.java @@ -24,6 +24,9 @@ */ package java.util; +import jdk.internal.access.JavaLangAccess; +import jdk.internal.access.SharedSecrets; + /** * {@code StringJoiner} is used to construct a sequence of characters separated * by a delimiter and optionally starting with a supplied prefix @@ -63,6 +66,8 @@ package java.util; * @since 1.8 */ public final class StringJoiner { + private static final String[] EMPTY_STRING_ARRAY = new String[0]; + private final String prefix; private final String delimiter; private final 
String suffix; @@ -158,27 +163,15 @@ public final class StringJoiner { */ @Override public String toString() { - final String[] elts = this.elts; - if (elts == null && emptyValue != null) { - return emptyValue; - } final int size = this.size; - final int addLen = prefix.length() + suffix.length(); + var elts = this.elts; if (size == 0) { - if (addLen == 0) { - return ""; + if (emptyValue != null) { + return emptyValue; } - return prefix + suffix; + elts = EMPTY_STRING_ARRAY; } - final String delimiter = this.delimiter; - StringBuilder sb = new StringBuilder(len + addLen).append(prefix); - if (size > 0) { - sb.append(elts[0]); - for (int i = 1; i < size; i++) { - sb.append(delimiter).append(elts[i]); - } - } - return sb.append(suffix).toString(); + return JLA.join(prefix, suffix, delimiter, elts, size); } /** @@ -233,7 +226,7 @@ public final class StringJoiner { */ public StringJoiner merge(StringJoiner other) { Objects.requireNonNull(other); - if (other.elts == null) { + if (other.size == 0) { return this; } other.compactElts(); @@ -241,15 +234,11 @@ public final class StringJoiner { } private void compactElts() { - if (size > 1) { - StringBuilder sb = new StringBuilder(len).append(elts[0]); - int i = 1; - do { - sb.append(delimiter).append(elts[i]); - elts[i] = null; - } while (++i < size); + int sz = size; + if (sz > 1) { + elts[0] = JLA.join("", "", delimiter, elts, sz); + Arrays.fill(elts, 1, sz, null); size = 1; - elts[0] = sb.toString(); } } @@ -267,4 +256,6 @@ public final class StringJoiner { return (size == 0 && emptyValue != null) ? 
emptyValue.length() : len + prefix.length() + suffix.length(); } + + private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess(); } diff --git a/src/java.base/share/classes/jdk/internal/access/JavaLangAccess.java b/src/java.base/share/classes/jdk/internal/access/JavaLangAccess.java index 4cac7626e5b..60f8a5f6719 100644 --- a/src/java.base/share/classes/jdk/internal/access/JavaLangAccess.java +++ b/src/java.base/share/classes/jdk/internal/access/JavaLangAccess.java @@ -367,6 +367,11 @@ public interface JavaLangAccess { */ long stringConcatMix(long lengthCoder, String constant); + /** + * Join strings + */ + String join(String prefix, String suffix, String delimiter, String[] elements, int size); + /* * Get the class data associated with the given class. * @param c the class diff --git a/test/jdk/java/util/StringJoiner/StringJoinerOomUtf16Test.java b/test/jdk/java/util/StringJoiner/StringJoinerOomUtf16Test.java new file mode 100644 index 00000000000..b45e66b4d8b --- /dev/null +++ b/test/jdk/java/util/StringJoiner/StringJoinerOomUtf16Test.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 
+ * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +/** + * @test + * @bug 8265237 + * @summary tests StringJoiner OOME when joining sub-max-length Strings + * @modules java.base/jdk.internal.util + * @requires vm.bits == "64" & os.maxMemory > 4G + * @run testng/othervm -Xmx4g -XX:+CompactStrings StringJoinerOomUtf16Test + */ + +import org.testng.annotations.Test; + +import static jdk.internal.util.ArraysSupport.SOFT_MAX_ARRAY_LENGTH; +import static org.testng.Assert.fail; + +import java.util.StringJoiner; + + +@Test(groups = {"unit","string","util","libs"}) +public class StringJoinerOomUtf16Test { + + // the sum of lengths of the following two strings is way less than + // SOFT_MAX_ARRAY_LENGTH, but the byte[] array holding the UTF16 representation + // would need to be bigger than Integer.MAX_VALUE... + private static final String HALF_MAX_LATIN1_STRING = + "*".repeat(SOFT_MAX_ARRAY_LENGTH >> 1); + private static final String OVERFLOW_UTF16_STRING = + "\u017D".repeat(((Integer.MAX_VALUE - SOFT_MAX_ARRAY_LENGTH) >> 1) + 1); + + public void OOM1() { + try { + new StringJoiner("") + .add(HALF_MAX_LATIN1_STRING) + .add(OVERFLOW_UTF16_STRING) + .toString(); + fail("Should have thrown OutOfMemoryError"); + } catch (OutOfMemoryError ex) { + System.out.println("Expected: " + ex); + } + } + + public void OOM2() { + try { + new StringJoiner(HALF_MAX_LATIN1_STRING) + .add("") + .add(OVERFLOW_UTF16_STRING) + .toString(); + fail("Should have thrown OutOfMemoryError"); + } catch (OutOfMemoryError ex) { + System.out.println("Expected: " + ex); + } + } + + public void OOM3() { + try { + new StringJoiner(OVERFLOW_UTF16_STRING) + .add("") + .add(HALF_MAX_LATIN1_STRING) + .toString(); + fail("Should have thrown OutOfMemoryError"); + } catch (OutOfMemoryError ex) { + System.out.println("Expected: " + ex); + } + } + + public void OOM4() { + try { + new 
StringJoiner("", HALF_MAX_LATIN1_STRING, OVERFLOW_UTF16_STRING) + .toString(); + fail("Should have thrown OutOfMemoryError"); + } catch (OutOfMemoryError ex) { + System.out.println("Expected: " + ex); + } + } +} + diff --git a/test/micro/org/openjdk/bench/java/util/StringJoinerBenchmark.java b/test/micro/org/openjdk/bench/java/util/StringJoinerBenchmark.java index b576571be60..6df0045ef32 100644 --- a/test/micro/org/openjdk/bench/java/util/StringJoinerBenchmark.java +++ b/test/micro/org/openjdk/bench/java/util/StringJoinerBenchmark.java @@ -25,12 +25,14 @@ package org.openjdk.bench.java.util; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; import org.openjdk.jmh.annotations.Mode; import org.openjdk.jmh.annotations.OutputTimeUnit; import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; import java.util.StringJoiner; import java.util.concurrent.ThreadLocalRandom; @@ -41,9 +43,17 @@ import java.util.concurrent.TimeUnit; */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) -@Fork(jvmArgsAppend = {"-Xms2g", "-Xmx2g"}) +@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Fork(value = 3, jvmArgsAppend = {"-Xms1g", "-Xmx1g"}) public class StringJoinerBenchmark { + @Benchmark + public String join(Data data) { + String[] stringArray = data.stringArray; + return String.join(",", stringArray); + } + @Benchmark public String stringJoiner(Data data) { String[] stringArray = data.stringArray; @@ -56,10 +66,10 @@ public class StringJoinerBenchmark { @Param({"latin", "cyrillic"}) private String mode; - @Param({"8", "32"}) + @Param({"1", "8", "32", "128"}) private int length; - 
@Param({"5", "10"}) + @Param({"5", "20"}) private int count; private String[] stringArray;