diff --git a/make/modules/java.base/Lib.gmk b/make/modules/java.base/Lib.gmk index 30261bab6e1..924cb8aae26 100644 --- a/make/modules/java.base/Lib.gmk +++ b/make/modules/java.base/Lib.gmk @@ -245,7 +245,7 @@ ifeq ($(call isTargetOs, linux)+$(call isTargetCpu, x86_64)+$(INCLUDE_COMPILER2) TOOLCHAIN := TOOLCHAIN_LINK_CXX, \ OPTIMIZATION := HIGH, \ CFLAGS := $(CFLAGS_JDKLIB), \ - CXXFLAGS := $(CXXFLAGS_JDKLIB), \ + CXXFLAGS := $(CXXFLAGS_JDKLIB) -std=c++17, \ LDFLAGS := $(LDFLAGS_JDKLIB) \ $(call SET_SHARED_LIBRARY_ORIGIN), \ LIBS := $(LIBCXX), \ diff --git a/src/hotspot/cpu/aarch64/matcher_aarch64.hpp b/src/hotspot/cpu/aarch64/matcher_aarch64.hpp index aa22459d217..08bff22d7d0 100644 --- a/src/hotspot/cpu/aarch64/matcher_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/matcher_aarch64.hpp @@ -193,4 +193,9 @@ } } + // Is SIMD sort supported for this CPU? + static bool supports_simd_sort(BasicType bt) { + return false; + } + #endif // CPU_AARCH64_MATCHER_AARCH64_HPP diff --git a/src/hotspot/cpu/arm/matcher_arm.hpp b/src/hotspot/cpu/arm/matcher_arm.hpp index 1e8f7683e76..eb26cbcbd7a 100644 --- a/src/hotspot/cpu/arm/matcher_arm.hpp +++ b/src/hotspot/cpu/arm/matcher_arm.hpp @@ -186,4 +186,9 @@ } } + // Is SIMD sort supported for this CPU? + static bool supports_simd_sort(BasicType bt) { + return false; + } + #endif // CPU_ARM_MATCHER_ARM_HPP diff --git a/src/hotspot/cpu/ppc/matcher_ppc.hpp b/src/hotspot/cpu/ppc/matcher_ppc.hpp index 44d1a3cd305..b195ba4eeb2 100644 --- a/src/hotspot/cpu/ppc/matcher_ppc.hpp +++ b/src/hotspot/cpu/ppc/matcher_ppc.hpp @@ -195,4 +195,9 @@ } } + // Is SIMD sort supported for this CPU? + static bool supports_simd_sort(BasicType bt) { + return false; + } + #endif // CPU_PPC_MATCHER_PPC_HPP diff --git a/src/hotspot/cpu/riscv/matcher_riscv.hpp b/src/hotspot/cpu/riscv/matcher_riscv.hpp index 1da8f003122..08914d4d834 100644 --- a/src/hotspot/cpu/riscv/matcher_riscv.hpp +++ b/src/hotspot/cpu/riscv/matcher_riscv.hpp @@ -192,4 +192,9 @@ } } + // Is SIMD sort supported for this CPU? + static bool supports_simd_sort(BasicType bt) { + return false; + } + #endif // CPU_RISCV_MATCHER_RISCV_HPP diff --git a/src/hotspot/cpu/s390/matcher_s390.hpp b/src/hotspot/cpu/s390/matcher_s390.hpp index e68dca34c3b..450ea35a6cb 100644 --- a/src/hotspot/cpu/s390/matcher_s390.hpp +++ b/src/hotspot/cpu/s390/matcher_s390.hpp @@ -184,4 +184,9 @@ } } + // Is SIMD sort supported for this CPU? + static bool supports_simd_sort(BasicType bt) { + return false; + } + #endif // CPU_S390_MATCHER_S390_HPP diff --git a/src/hotspot/cpu/x86/matcher_x86.hpp b/src/hotspot/cpu/x86/matcher_x86.hpp index bc249c0f33a..de844c4be9f 100644 --- a/src/hotspot/cpu/x86/matcher_x86.hpp +++ b/src/hotspot/cpu/x86/matcher_x86.hpp @@ -248,4 +248,17 @@ } } + // Is SIMD sort supported for this CPU? 
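+  // AVX-512 DQ machines support the intrinsic for every element type; the
+  // AVX2 fallback only handles 32-bit element types (T_INT/T_FLOAT), so
+  // double-word types (T_LONG/T_DOUBLE) return false there and the intrinsic
+  // is skipped in favor of the regular Java sort.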
+ static bool supports_simd_sort(BasicType bt) { + if (VM_Version::supports_avx512dq()) { + return true; + } + else if (VM_Version::supports_avx2() && !is_double_word_type(bt)) { + return true; + } + else { + return false; + } + } + #endif // CPU_X86_MATCHER_X86_HPP diff --git a/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp b/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp index c73e0759b57..85f3bbf6109 100644 --- a/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp +++ b/src/hotspot/cpu/x86/stubGenerator_x86_64.cpp @@ -4193,22 +4193,23 @@ void StubGenerator::generate_compiler_stubs() { = CAST_FROM_FN_PTR(address, SharedRuntime::montgomery_square); } - // Load x86_64_sort library on supported hardware to enable avx512 sort and partition intrinsics - if (VM_Version::is_intel() && VM_Version::supports_avx512dq()) { + // Load x86_64_sort library on supported hardware to enable SIMD sort and partition intrinsics + + if (VM_Version::is_intel() && (VM_Version::supports_avx512dq() || VM_Version::supports_avx2())) { void *libsimdsort = nullptr; char ebuf_[1024]; char dll_name_simd_sort[JVM_MAXPATHLEN]; if (os::dll_locate_lib(dll_name_simd_sort, sizeof(dll_name_simd_sort), Arguments::get_dll_dir(), "simdsort")) { libsimdsort = os::dll_load(dll_name_simd_sort, ebuf_, sizeof ebuf_); } - // Get addresses for avx512 sort and partition routines + // Get addresses for SIMD sort and partition routines if (libsimdsort != nullptr) { log_info(library)("Loaded library %s, handle " INTPTR_FORMAT, JNI_LIB_PREFIX "simdsort" JNI_LIB_SUFFIX, p2i(libsimdsort)); - snprintf(ebuf_, sizeof(ebuf_), "avx512_sort"); + snprintf(ebuf_, sizeof(ebuf_), VM_Version::supports_avx512dq() ? "avx512_sort" : "avx2_sort"); StubRoutines::_array_sort = (address)os::dll_lookup(libsimdsort, ebuf_); - snprintf(ebuf_, sizeof(ebuf_), "avx512_partition"); + snprintf(ebuf_, sizeof(ebuf_), VM_Version::supports_avx512dq() ? 
"avx512_partition" : "avx2_partition"); StubRoutines::_array_partition = (address)os::dll_lookup(libsimdsort, ebuf_); } } diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index 1517e456e82..b536e535d2c 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -858,7 +858,7 @@ void VM_Version::get_processor_features() { // Check if processor has Intel Ecore if (FLAG_IS_DEFAULT(EnableX86ECoreOpts) && is_intel() && cpu_family() == 6 && - (_model == 0x97 || _model == 0xAC || _model == 0xAF)) { + (_model == 0x97 || _model == 0xAA || _model == 0xAC || _model == 0xAF)) { FLAG_SET_DEFAULT(EnableX86ECoreOpts, true); } diff --git a/src/hotspot/share/opto/library_call.cpp b/src/hotspot/share/opto/library_call.cpp index 6e104f09b9d..07504199f81 100644 --- a/src/hotspot/share/opto/library_call.cpp +++ b/src/hotspot/share/opto/library_call.cpp @@ -5387,6 +5387,10 @@ bool LibraryCallKit::inline_array_partition() { const TypeInstPtr* elem_klass = gvn().type(elementType)->isa_instptr(); ciType* elem_type = elem_klass->const_oop()->as_instance()->java_mirror_type(); BasicType bt = elem_type->basic_type(); + // Disable the intrinsic if the CPU does not support SIMD sort + if (!Matcher::supports_simd_sort(bt)) { + return false; + } address stubAddr = nullptr; stubAddr = StubRoutines::select_array_partition_function(); // stub not loaded @@ -5440,6 +5444,10 @@ bool LibraryCallKit::inline_array_sort() { const TypeInstPtr* elem_klass = gvn().type(elementType)->isa_instptr(); ciType* elem_type = elem_klass->const_oop()->as_instance()->java_mirror_type(); BasicType bt = elem_type->basic_type(); + // Disable the intrinsic if the CPU does not support SIMD sort + if (!Matcher::supports_simd_sort(bt)) { + return false; + } address stubAddr = nullptr; stubAddr = StubRoutines::select_arraysort_function(); //stub not loaded diff --git a/src/java.base/linux/native/libsimdsort/avx2-32bit-qsort.hpp b/src/java.base/linux/native/libsimdsort/avx2-32bit-qsort.hpp new file mode 100644 index 00000000000..9310b0098d8 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx2-32bit-qsort.hpp @@ -0,0 +1,367 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef AVX2_QSORT_32BIT +#define AVX2_QSORT_32BIT + +#include "avx2-emu-funcs.hpp" +#include "xss-common-qsort.h" + +/* + * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) { + const typename vtype::opmask_t oxAA = _mm256_set_epi32( + 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const typename vtype::opmask_t oxCC = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); + const typename vtype::opmask_t oxF0 = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); + + const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge( + ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); + ymm = cmp_merge( + ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + return ymm; +} + +struct avx2_32bit_swizzle_ops; + +template <> +struct avx2_vector { + using type_t = int32_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 8; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 256; +#endif + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() { return X86_SIMD_SORT_MAX_INT32; } + static type_t type_min() { return X86_SIMD_SORT_MIN_INT32; } + static reg_t zmm_max() { + return _mm256_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
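+    // Build a load mask covering the first num_to_read lanes: the low
+    // num_to_read bits are expanded into per-lane 0xFFFFFFFF/0 values via the
+    // LUT-backed convert_int_to_avx2_mask, since AVX2 has no AVX-512 style
+    // k-mask registers.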
+ static opmask_t get_partial_loadmask(uint64_t num_to_read) { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, + int v8) { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) { + return _mm256_xor_si256(x, y); + } + static opmask_t ge(reg_t x, reg_t y) { + opmask_t equal = eq(x, y); + opmask_t greater = _mm256_cmpgt_epi32(x, y); + return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(equal), + _mm256_castsi256_ps(greater))); + } + static opmask_t gt(reg_t x, reg_t y) { return _mm256_cmpgt_epi32(x, y); } + static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmpeq_epi32(x, y); } + template + static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, + void const *base) { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static reg_t loadu(void const *mem) { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) { return _mm256_max_epi32(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { + return avx2_emu_mask_compressstoreu32(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm256_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { + reg_t dst = _mm256_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { + return _mm256_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) { return _mm256_min_epi32(x, y); } + static reg_t permutexvar(__m256i idx, reg_t ymm) { + return _mm256_permutevar8x32_epi32(ymm, idx); + // return avx2_emu_permutexvar_epi32(idx, ymm); + } + static reg_t permutevar(reg_t ymm, __m256i idx) { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) { return _mm256_set1_epi32(v); } + template + static reg_t shuffle(reg_t ymm) { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) { return v; } + static __m256i cast_to(reg_t v) { return v; } + static int double_compressstore(type_t *left_addr, type_t *right_addr, + opmask_t k, reg_t reg) { + return avx2_double_compressstore32(left_addr, right_addr, k, + reg); + } +}; + +template <> +struct avx2_vector { + using type_t = float; + using reg_t = __m256; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 8; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 256; +#endif + static 
constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() { return X86_SIMD_SORT_INFINITYF; } + static type_t type_min() { return -X86_SIMD_SORT_INFINITYF; } + static reg_t zmm_max() { return _mm256_set1_ps(type_max()); } + + static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, + int v8) { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static reg_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm256_maskload_ps((const float *)mem, mask); + } + static opmask_t ge(reg_t x, reg_t y) { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t gt(reg_t x, reg_t y) { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GT_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) { + return convert_avx2_mask_to_int(mask); + } + template + static opmask_t fpclass(reg_t x) { + if constexpr (type == (0x01 | 0x80)) { + return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); + } else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + template + static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, + void const *base) { + return _mm256_mask_i32gather_ps(src, base, index, + _mm256_castsi256_ps(mask), scale); + ; + } + template + static reg_t i64gather(__m256i index, void const *base) { + return _mm256_i32gather_ps((float *)base, index, scale); + } + static reg_t loadu(void const *mem) { + return _mm256_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) { return _mm256_max_ps(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { + return avx2_emu_mask_compressstoreu32(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { + reg_t dst = _mm256_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { + return _mm256_blendv_ps(x, y, _mm256_castsi256_ps(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { + return _mm256_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) { return _mm256_min_ps(x, y); } + static reg_t permutexvar(__m256i idx, reg_t ymm) { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) { return _mm256_set1_ps(v); } + template + static reg_t shuffle(reg_t ymm) { + return _mm256_castsi256_ps( + _mm256_shuffle_epi32(_mm256_castps_si256(ymm), mask)); + } + static void storeu(void *mem, reg_t x) { + _mm256_storeu_ps((float *)mem, x); + } + static reg_t sort_vec(reg_t x) { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) { return _mm256_castsi256_ps(v); } + static __m256i cast_to(reg_t v) { return _mm256_castps_si256(v); } + static int double_compressstore(type_t *left_addr, type_t *right_addr, + opmask_t k, 
reg_t reg) { + return avx2_double_compressstore32(left_addr, right_addr, k, + reg); + } +}; + +struct avx2_32bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n( + typename vtype::reg_t reg) { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b10110001); + v = _mm256_castps_si256(vf); + } else if constexpr (scale == 4) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b01001110); + v = _mm256_castps_si256(vf); + } else if constexpr (scale == 8) { + v = _mm256_permute2x128_si256(v, v, 0b00000001); + } else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n( + typename vtype::reg_t reg) { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + return swap_n(reg); + } else if constexpr (scale == 4) { + constexpr uint64_t mask = 0b00011011; + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, mask); + v = _mm256_castps_si256(vf); + } else if constexpr (scale == 8) { + return vtype::reverse(reg); + } else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n( + typename vtype::reg_t reg, typename vtype::reg_t other) { + __m256i v1 = vtype::cast_to(reg); + __m256i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm256_blend_epi32(v1, v2, 0b01010101); + } else if constexpr (scale == 4) { + v1 = _mm256_blend_epi32(v1, v2, 0b00110011); + } else if constexpr (scale == 8) { + v1 = _mm256_blend_epi32(v1, v2, 0b00001111); + } else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + +#endif // AVX2_QSORT_32BIT diff --git a/src/java.base/linux/native/libsimdsort/avx2-emu-funcs.hpp b/src/java.base/linux/native/libsimdsort/avx2-emu-funcs.hpp new file mode 100644 index 00000000000..611f3f419bd --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx2-emu-funcs.hpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef AVX2_EMU_FUNCS +#define AVX2_EMU_FUNCS + +#include +#include + +#include "xss-common-qsort.h" + +constexpr auto avx2_mask_helper_lut32 = [] { + std::array, 256> lut{}; + for (int64_t i = 0; i <= 0xFF; i++) { + std::array entry{}; + for (int j = 0; j < 8; j++) { + if (((i >> j) & 1) == 1) + entry[j] = 0xFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + +constexpr auto avx2_compressstore_lut32_gen = [] { + std::array, 256>, 2> lutPair{}; + auto &permLut = lutPair[0]; + auto &leftLut = lutPair[1]; + for (int64_t i = 0; i <= 0xFF; i++) { + std::array indices{}; + std::array leftEntry = {0, 0, 0, 0, 0, 0, 0, 0}; + int right = 7; + int left = 0; + for (int j = 0; j < 8; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = j; + right--; + } else { + indices[left] = j; + leftEntry[left] = 0xFFFFFFFF; + left++; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return lutPair; +}(); + +constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; +constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; + + +X86_SIMD_SORT_INLINE +__m256i convert_int_to_avx2_mask(int32_t m) { + return _mm256_loadu_si256( + (const __m256i *)avx2_mask_helper_lut32[m].data()); +} + +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int(__m256i m) { + return _mm256_movemask_ps(_mm256_castsi256_ps(m)); +} + +// Emulators for intrinsics missing from AVX2 compared to AVX512 +template +T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) { + using vtype = avx2_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = + vtype::max(x, vtype::template shuffle(x)); + reg_t inter2 = vtype::max( + inter1, vtype::template shuffle(inter1)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter2); + return std::max(arr[0], arr[7]); +} + +template +T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { + using vtype = avx2_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = + vtype::min(x, vtype::template shuffle(x)); + reg_t inter2 = vtype::min( + inter1, vtype::template shuffle(inter1)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter2); + return std::min(arr[0], arr[7]); +} + +template +void avx2_emu_mask_compressstoreu32(void *base_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { + using vtype = avx2_vector; + + T *leftStore = (T *)base_addr; + + int32_t shortMask = convert_avx2_mask_to_int(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); + const __m256i &left = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); +} + + +template +int avx2_double_compressstore32(void *left_addr, void *right_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { + using vtype = avx2_vector; + + T *leftStore = (T *)left_addr; + T *rightStore = (T *)right_addr; + + int32_t shortMask = convert_avx2_mask_to_int(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::storeu(leftStore, temp); + vtype::storeu(rightStore, temp); + + return _mm_popcnt_u32(shortMask); +} + + +template +typename avx2_vector::reg_t 
avx2_emu_max(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) { + using vtype = avx2_vector; + typename vtype::opmask_t nlt = vtype::gt(x, y); + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y), + _mm256_castsi256_pd(x), + _mm256_castsi256_pd(nlt))); +} + +template +typename avx2_vector::reg_t avx2_emu_min(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) { + using vtype = avx2_vector; + typename vtype::opmask_t nlt = vtype::gt(x, y); + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), + _mm256_castsi256_pd(y), + _mm256_castsi256_pd(nlt))); +} + +#endif diff --git a/src/java.base/linux/native/libsimdsort/avx2-linux-qsort.cpp b/src/java.base/linux/native/libsimdsort/avx2-linux-qsort.cpp new file mode 100644 index 00000000000..628d65077c7 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/avx2-linux-qsort.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 Intel Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +#include "simdsort-support.hpp" +#ifdef __SIMDSORT_SUPPORTED_LINUX + +#pragma GCC target("avx2") +#include "avx2-32bit-qsort.hpp" +#include "classfile_constants.h" + + +#define DLL_PUBLIC __attribute__((visibility("default"))) +#define INSERTION_SORT_THRESHOLD_32BIT 16 + +extern "C" { + + DLL_PUBLIC void avx2_sort(void *array, int elem_type, int32_t from_index, int32_t to_index) { + switch(elem_type) { + case JVM_T_INT: + avx2_fast_sort((int32_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); + break; + case JVM_T_FLOAT: + avx2_fast_sort((float*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); + break; + default: + assert(false, "Unexpected type"); + } + } + + DLL_PUBLIC void avx2_partition(void *array, int elem_type, int32_t from_index, int32_t to_index, int32_t *pivot_indices, int32_t index_pivot1, int32_t index_pivot2) { + switch(elem_type) { + case JVM_T_INT: + avx2_fast_partition((int32_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + break; + case JVM_T_FLOAT: + avx2_fast_partition((float*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + break; + default: + assert(false, "Unexpected type"); + } + } + +} + +#endif \ No newline at end of file diff --git a/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp b/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp index 4fbe9b97450..25ad265d865 100644 --- a/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp +++ b/src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp @@ -28,7 +28,7 @@ #ifndef AVX512_QSORT_32BIT #define AVX512_QSORT_32BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" /* * Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic @@ -43,130 +43,204 @@ #define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 #define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm); + +struct avx512_32bit_swizzle_ops; + template <> struct zmm_vector { using type_t = int32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; + using reg_t = __m512i; + using halfreg_t = __m256i; using opmask_t = __mmask16; static const uint8_t numlanes = 16; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 512; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_32bit_swizzle_ops; static type_t type_max() { return X86_SIMD_SORT_MAX_INT32; } static type_t type_min() { return X86_SIMD_SORT_MIN_INT32; } - static zmm_t zmm_max() { return _mm512_set1_epi32(type_max()); } + static reg_t zmm_max() { return _mm512_set1_epi32(type_max()); } static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); } - static opmask_t ge(zmm_t x, zmm_t y) { + + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } - static opmask_t gt(zmm_t x, zmm_t y) { + + static opmask_t gt(reg_t x, reg_t y) { return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_GT); } + + static opmask_t get_partial_loadmask(uint64_t num_to_read) { + return ((0x1ull << num_to_read) - 0x1ull); + } template - static ymm_t i64gather(__m512i index, void const *base) { + static halfreg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi32(index, base, scale); } - static zmm_t merge(ymm_t y1, ymm_t y2) { - zmm_t z1 = _mm512_castsi256_si512(y1); + static reg_t 
merge(halfreg_t y1, halfreg_t y2) { + reg_t z1 = _mm512_castsi256_si512(y1); return _mm512_inserti32x8(z1, y2, 1); } - static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_epi32(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi32(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_epi32(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi32(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_epi32(x, y); } - static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_epi32(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) { + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epi32(x, y); } + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epi32(x, y); } + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_epi32(idx, zmm); } - static type_t reducemax(zmm_t v) { return _mm512_reduce_max_epi32(v); } - static type_t reducemin(zmm_t v) { return _mm512_reduce_min_epi32(v); } - static zmm_t set1(type_t v) { return _mm512_set1_epi32(v); } + static type_t reducemax(reg_t v) { return _mm512_reduce_max_epi32(v); } + static type_t reducemin(reg_t v) { return _mm512_reduce_min_epi32(v); } + static reg_t set1(type_t v) { return _mm512_set1_epi32(v); } template - static zmm_t shuffle(zmm_t zmm) { + static reg_t shuffle(reg_t zmm) { return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) { + static void storeu(void *mem, reg_t x) { return _mm512_storeu_si512(mem, x); } - static ymm_t max(ymm_t x, ymm_t y) { return _mm256_max_epi32(x, y); } - static ymm_t min(ymm_t x, ymm_t y) { return _mm256_min_epi32(x, y); } + static halfreg_t max(halfreg_t x, halfreg_t y) { + return _mm256_max_epi32(x, y); + } + static halfreg_t min(halfreg_t x, halfreg_t y) { + return _mm256_min_epi32(x, y); + } + static reg_t reverse(reg_t zmm) { + const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); + return permutexvar(rev_index, zmm); + } + static reg_t sort_vec(reg_t x) { + return sort_zmm_32bit>(x); + } + static reg_t cast_from(__m512i v) { return v; } + static __m512i cast_to(reg_t v) { return v; } + static int double_compressstore(type_t *left_addr, type_t *right_addr, + opmask_t k, reg_t reg) { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { using type_t = float; - using zmm_t = __m512; - using ymm_t = __m256; + using reg_t = __m512; + using halfreg_t = __m256; using opmask_t = __mmask16; static const uint8_t numlanes = 16; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 512; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_32bit_swizzle_ops; static type_t type_max() { return X86_SIMD_SORT_INFINITYF; } static type_t type_min() { return -X86_SIMD_SORT_INFINITYF; } - static zmm_t zmm_max() { return 
_mm512_set1_ps(type_max()); } + static reg_t zmm_max() { return _mm512_set1_ps(type_max()); } static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); } - static opmask_t ge(zmm_t x, zmm_t y) { + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); } - static opmask_t gt(zmm_t x, zmm_t y) { + static opmask_t gt(reg_t x, reg_t y) { return _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) { return mask; } + template + static opmask_t fpclass(reg_t x) { + return _mm512_fpclass_ps_mask(x, type); + } template - static ymm_t i64gather(__m512i index, void const *base) { + static halfreg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_ps(index, base, scale); } - static zmm_t merge(ymm_t y1, ymm_t y2) { - zmm_t z1 = _mm512_castsi512_ps( + static reg_t merge(halfreg_t y1, halfreg_t y2) { + reg_t z1 = _mm512_castsi512_ps( _mm512_castsi256_si512(_mm256_castps_si256(y1))); return _mm512_insertf32x8(z1, y2, 1); } - static zmm_t loadu(void const *mem) { return _mm512_loadu_ps(mem); } - static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_ps(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { + static reg_t loadu(void const *mem) { return _mm512_loadu_ps(mem); } + static reg_t max(reg_t x, reg_t y) { return _mm512_max_ps(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_ps(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { + static reg_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm512_maskz_loadu_ps(mask, mem); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_ps(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_ps(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_ps(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_ps(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) { + static reg_t min(reg_t x, reg_t y) { return _mm512_min_ps(x, y); } + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_ps(idx, zmm); } - static type_t reducemax(zmm_t v) { return _mm512_reduce_max_ps(v); } - static type_t reducemin(zmm_t v) { return _mm512_reduce_min_ps(v); } - static zmm_t set1(type_t v) { return _mm512_set1_ps(v); } + static type_t reducemax(reg_t v) { return _mm512_reduce_max_ps(v); } + static type_t reducemin(reg_t v) { return _mm512_reduce_min_ps(v); } + static reg_t set1(type_t v) { return _mm512_set1_ps(v); } template - static zmm_t shuffle(zmm_t zmm) { + static reg_t shuffle(reg_t zmm) { return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) { return _mm512_storeu_ps(mem, x); } + static void storeu(void *mem, reg_t x) { return _mm512_storeu_ps(mem, x); } - static ymm_t max(ymm_t x, ymm_t y) { return _mm256_max_ps(x, y); } - static ymm_t min(ymm_t x, ymm_t y) { return _mm256_min_ps(x, y); } + static halfreg_t max(halfreg_t x, halfreg_t y) { + return _mm256_max_ps(x, y); + } + static halfreg_t min(halfreg_t x, halfreg_t y) { + return _mm256_min_ps(x, y); + } + static reg_t reverse(reg_t zmm) 
{ + const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); + return permutexvar(rev_index, zmm); + } + static reg_t sort_vec(reg_t x) { + return sort_zmm_32bit>(x); + } + static reg_t cast_from(__m512i v) { return _mm512_castsi512_ps(v); } + static __m512i cast_to(reg_t v) { return _mm512_castps_si512(v); } + static int double_compressstore(type_t *left_addr, type_t *right_addr, + opmask_t k, reg_t reg) { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_32bit(zmm_t zmm) { +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm) { zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAAAA); zmm = cmp_merge( @@ -193,249 +267,71 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_32bit(zmm_t zmm) { return zmm; } -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_32bit(zmm_t zmm) { - // 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc .. - zmm = cmp_merge( - zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_7), zmm), - 0xFF00); - // 2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc .. - zmm = cmp_merge( - zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm), - 0xF0F0); - // 3) half_cleaner[4] - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xCCCC); - // 3) half_cleaner[1] - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAAAA); - return zmm; -} +struct avx512_32bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n( + typename vtype::reg_t reg) { + __m512i v = vtype::cast_to(reg); -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_32bit(zmm_t *zmm1, - zmm_t *zmm2) { - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - *zmm2 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), *zmm2); - zmm_t zmm3 = vtype::min(*zmm1, *zmm2); - zmm_t zmm4 = vtype::max(*zmm1, *zmm2); - // 2) Recursive half cleaner for each - *zmm1 = bitonic_merge_zmm_32bit(zmm3); - *zmm2 = bitonic_merge_zmm_32bit(zmm4); -} + if constexpr (scale == 2) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001); + } else if constexpr (scale == 4) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110); + } else if constexpr (scale == 8) { + v = _mm512_shuffle_i64x2(v, v, 0b10110001); + } else if constexpr (scale == 16) { + v = _mm512_shuffle_i64x2(v, v, 0b01001110); + } else { + static_assert(scale == -1, "should not be reached"); + } -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_32bit(zmm_t *zmm) { - zmm_t zmm2r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[2]); - zmm_t zmm3r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[3]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); - zmm_t zmm_t3 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[1], zmm2r)); - zmm_t zmm_t4 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[0], zmm3r)); - zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); - zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); - zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); - zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); - zmm[0] = bitonic_merge_zmm_32bit(zmm0); - zmm[1] = 
bitonic_merge_zmm_32bit(zmm1); - zmm[2] = bitonic_merge_zmm_32bit(zmm2); - zmm[3] = bitonic_merge_zmm_32bit(zmm3); -} - -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_32bit(zmm_t *zmm) { - zmm_t zmm4r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[4]); - zmm_t zmm5r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[5]); - zmm_t zmm6r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[6]); - zmm_t zmm7r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[7]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r); - zmm_t zmm_t5 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[3], zmm4r)); - zmm_t zmm_t6 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[2], zmm5r)); - zmm_t zmm_t7 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[1], zmm6r)); - zmm_t zmm_t8 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[0], zmm7r)); - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - zmm[0] = bitonic_merge_zmm_32bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_32bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_32bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_32bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_32bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_32bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_32bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_32bit(zmm_t8); -} - -template -X86_SIMD_SORT_INLINE void sort_16_32bit(type_t *arr, int32_t N) { - typename vtype::opmask_t load_mask = (0x0001 << N) - 0x0001; - typename vtype::zmm_t zmm = - vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); - vtype::mask_storeu(arr, load_mask, sort_zmm_32bit(zmm)); -} - -template -X86_SIMD_SORT_INLINE void sort_32_32bit(type_t *arr, int32_t N) { - if (N <= 16) { - sort_16_32bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - zmm_t zmm1 = vtype::loadu(arr); - typename vtype::opmask_t load_mask = (0x0001 << (N - 16)) - 0x0001; - zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 16); - zmm1 = sort_zmm_32bit(zmm1); - zmm2 = sort_zmm_32bit(zmm2); - bitonic_merge_two_zmm_32bit(&zmm1, &zmm2); - vtype::storeu(arr, zmm1); - vtype::mask_storeu(arr + 16, load_mask, zmm2); -} - -template -X86_SIMD_SORT_INLINE void sort_64_32bit(type_t *arr, int32_t N) { - if (N <= 32) { - sort_32_32bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[4]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 16); - opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF; - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 &= combined_mask & 0xFFFF; - load_mask2 &= (combined_mask >> 16) & 0xFFFF; - zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); - zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 48); - zmm[0] = sort_zmm_32bit(zmm[0]); - zmm[1] = sort_zmm_32bit(zmm[1]); - zmm[2] = sort_zmm_32bit(zmm[2]); - zmm[3] = sort_zmm_32bit(zmm[3]); - bitonic_merge_two_zmm_32bit(&zmm[0], &zmm[1]); - bitonic_merge_two_zmm_32bit(&zmm[2], &zmm[3]); - bitonic_merge_four_zmm_32bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 16, zmm[1]); - vtype::mask_storeu(arr + 32, load_mask1, zmm[2]); - vtype::mask_storeu(arr + 48, load_mask2, zmm[3]); 
-} - -template -X86_SIMD_SORT_INLINE void sort_128_32bit(type_t *arr, int32_t N) { - if (N <= 64) { - sort_64_32bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[8]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 16); - zmm[2] = vtype::loadu(arr + 32); - zmm[3] = vtype::loadu(arr + 48); - zmm[0] = sort_zmm_32bit(zmm[0]); - zmm[1] = sort_zmm_32bit(zmm[1]); - zmm[2] = sort_zmm_32bit(zmm[2]); - zmm[3] = sort_zmm_32bit(zmm[3]); - opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF; - opmask_t load_mask3 = 0xFFFF, load_mask4 = 0xFFFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 &= combined_mask & 0xFFFF; - load_mask2 &= (combined_mask >> 16) & 0xFFFF; - load_mask3 &= (combined_mask >> 32) & 0xFFFF; - load_mask4 &= (combined_mask >> 48) & 0xFFFF; - } - zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); - zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 80); - zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 96); - zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 112); - zmm[4] = sort_zmm_32bit(zmm[4]); - zmm[5] = sort_zmm_32bit(zmm[5]); - zmm[6] = sort_zmm_32bit(zmm[6]); - zmm[7] = sort_zmm_32bit(zmm[7]); - bitonic_merge_two_zmm_32bit(&zmm[0], &zmm[1]); - bitonic_merge_two_zmm_32bit(&zmm[2], &zmm[3]); - bitonic_merge_two_zmm_32bit(&zmm[4], &zmm[5]); - bitonic_merge_two_zmm_32bit(&zmm[6], &zmm[7]); - bitonic_merge_four_zmm_32bit(zmm); - bitonic_merge_four_zmm_32bit(zmm + 4); - bitonic_merge_eight_zmm_32bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 16, zmm[1]); - vtype::storeu(arr + 32, zmm[2]); - vtype::storeu(arr + 48, zmm[3]); - vtype::mask_storeu(arr + 64, load_mask1, zmm[4]); - vtype::mask_storeu(arr + 80, load_mask2, zmm[5]); - vtype::mask_storeu(arr + 96, load_mask3, zmm[6]); - vtype::mask_storeu(arr + 112, load_mask4, zmm[7]); -} - - -template -static void qsort_32bit_(type_t *arr, int64_t left, int64_t right, - int64_t max_iters) { - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 128) { - sort_128_32bit(arr + left, (int32_t)(right + 1 - left)); - return; + return vtype::cast_from(v); } - type_t pivot = get_pivot_scalar(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest, false); - if (pivot != smallest) - qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - qsort_32bit_(arr, pivot_index, right, max_iters - 1); -} + template + X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n( + typename vtype::reg_t reg) { + __m512i v = vtype::cast_to(reg); -template <> -void inline avx512_qsort(int32_t *arr, int64_t fromIndex, int64_t toIndex) { - int64_t arrsize = toIndex - fromIndex; - if (arrsize > 1) { - qsort_32bit_, int32_t>(arr, fromIndex, toIndex - 1, - 2 * (int64_t)log2(arrsize)); - } -} + if constexpr (scale == 2) { + return swap_n(reg); + } else if constexpr (scale == 4) { + __m512i mask = _mm512_set_epi32(12, 13, 14, 15, 8, 9, 10, 11, 4, 5, + 6, 7, 0, 1, 2, 3); + v = _mm512_permutexvar_epi32(mask, v); + } else if constexpr (scale == 8) { + __m512i mask = _mm512_set_epi32(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, + 2, 3, 4, 5, 6, 
7); + v = _mm512_permutexvar_epi32(mask, v); + } else if constexpr (scale == 16) { + return vtype::reverse(reg); + } else { + static_assert(scale == -1, "should not be reached"); + } -template <> -void inline avx512_qsort(float *arr, int64_t fromIndex, int64_t toIndex) { - int64_t arrsize = toIndex - fromIndex; - if (arrsize > 1) { - qsort_32bit_, float>(arr, fromIndex, toIndex - 1, - 2 * (int64_t)log2(arrsize)); + return vtype::cast_from(v); } -} + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n( + typename vtype::reg_t reg, typename vtype::reg_t other) { + __m512i v1 = vtype::cast_to(reg); + __m512i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm512_mask_blend_epi32(0b0101010101010101, v1, v2); + } else if constexpr (scale == 4) { + v1 = _mm512_mask_blend_epi32(0b0011001100110011, v1, v2); + } else if constexpr (scale == 8) { + v1 = _mm512_mask_blend_epi32(0b0000111100001111, v1, v2); + } else if constexpr (scale == 16) { + v1 = _mm512_mask_blend_epi32(0b0000000011111111, v1, v2); + } else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; #endif // AVX512_QSORT_32BIT diff --git a/src/java.base/linux/native/libsimdsort/avx512-64bit-common.h b/src/java.base/linux/native/libsimdsort/avx512-64bit-common.h deleted file mode 100644 index 9993cd22e63..00000000000 --- a/src/java.base/linux/native/libsimdsort/avx512-64bit-common.h +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - * - */ - -// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) - -#ifndef AVX512_64BIT_COMMON -#define AVX512_64BIT_COMMON -#include "avx512-common-qsort.h" - -/* - * Constants used in sorting 8 elements in a ZMM registers. 
Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ZMM 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 - -template <> -struct zmm_vector { - using type_t = int64_t; - using zmm_t = __m512i; - using zmmi_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() { return X86_SIMD_SORT_MAX_INT64; } - static type_t type_min() { return X86_SIMD_SORT_MIN_INT64; } - static zmm_t zmm_max() { - return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? - - static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, - int v8) { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - static opmask_t kxor_opmask(opmask_t x, opmask_t y) { - return _kxor_mask8(x, y); - } - static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } - static opmask_t le(zmm_t x, zmm_t y) { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_LE); - } - static opmask_t ge(zmm_t x, zmm_t y) { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); - } - static opmask_t gt(zmm_t x, zmm_t y) { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_GT); - } - static opmask_t eq(zmm_t x, zmm_t y) { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); - } - template - static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index, - void const *base) { - return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); - } - template - static zmm_t i64gather(__m512i index, void const *base) { - return _mm512_i64gather_epi64(index, base, scale); - } - static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_epi64(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) { - return _mm512_maskz_loadu_epi64(mask, mem); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_epi64(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) { return _mm512_reduce_max_epi64(v); } - static type_t reducemin(zmm_t v) { return _mm512_reduce_min_epi64(v); } - static zmm_t set1(type_t v) { return _mm512_set1_epi64(v); } - template - static zmm_t shuffle(zmm_t zmm) { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) { _mm512_storeu_si512(mem, x); } -}; -template <> -struct zmm_vector { - using type_t = double; - using zmm_t = __m512d; - using zmmi_t = __m512i; - using ymm_t = __m512d; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() { return X86_SIMD_SORT_INFINITY; } - static type_t type_min() { return -X86_SIMD_SORT_INFINITY; } - static zmm_t zmm_max() { return 
_mm512_set1_pd(type_max()); } - - static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, - int v8) { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static zmm_t maskz_loadu(opmask_t mask, void const *mem) { - return _mm512_maskz_loadu_pd(mask, mem); - } - static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } - static opmask_t ge(zmm_t x, zmm_t y) { - return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); - } - static opmask_t gt(zmm_t x, zmm_t y) { - return _mm512_cmp_pd_mask(x, y, _CMP_GT_OQ); - } - static opmask_t eq(zmm_t x, zmm_t y) { - return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); - } - template - static opmask_t fpclass(zmm_t x) { - return _mm512_fpclass_pd_mask(x, type); - } - template - static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index, - void const *base) { - return _mm512_mask_i64gather_pd(src, mask, index, base, scale); - } - template - static zmm_t i64gather(__m512i index, void const *base) { - return _mm512_i64gather_pd(index, base, scale); - } - static zmm_t loadu(void const *mem) { return _mm512_loadu_pd(mem); } - static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_pd(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) { - return _mm512_mask_compressstoreu_pd(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { - return _mm512_mask_loadu_pd(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) { - return _mm512_mask_mov_pd(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) { - return _mm512_mask_storeu_pd(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_pd(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) { - return _mm512_permutexvar_pd(idx, zmm); - } - static type_t reducemax(zmm_t v) { return _mm512_reduce_max_pd(v); } - static type_t reducemin(zmm_t v) { return _mm512_reduce_min_pd(v); } - static zmm_t set1(type_t v) { return _mm512_set1_pd(v); } - template - static zmm_t shuffle(zmm_t zmm) { - return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) { _mm512_storeu_pd(mem, x); } -}; - -/* - * Assumes zmm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) { - const typename vtype::zmmi_t rev_index = vtype::seti(NETWORK_64BIT_2); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - zmm = cmp_merge( - zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); - zmm = cmp_merge( - zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - return zmm; -} - - -#endif diff --git a/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp b/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp index e28ebe19695..6c1fba6ebb6 100644 --- a/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp +++ b/src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp @@ -27,746 +27,317 @@ #ifndef AVX512_QSORT_64BIT #define AVX512_QSORT_64BIT -#include "avx512-64bit-common.h" +#include "xss-common-includes.h" +#include "xss-common-qsort.h" -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE zmm_t 
bitonic_merge_zmm_64bit(zmm_t zmm) { - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 +/* + * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm); + +struct avx512_64bit_swizzle_ops; + +template <> +struct zmm_vector { + using type_t = int64_t; + using reg_t = __m512i; + using regi_t = __m512i; + using halfreg_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 256; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_64bit_swizzle_ops; + + static type_t type_max() { return X86_SIMD_SORT_MAX_INT64; } + static type_t type_min() { return X86_SIMD_SORT_MIN_INT64; } + static reg_t zmm_max() { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? + + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, + int v8) { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4, type_t v5, + type_t v6, type_t v7, type_t v8) { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) { + return _kxor_mask8(x, y); + } + static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } + static opmask_t le(reg_t x, reg_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_LE); + } + static opmask_t ge(reg_t x, reg_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t gt(reg_t x, reg_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_GT); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) { + return ((0x1ull << num_to_read) - 0x1ull); + } + static opmask_t eq(reg_t x, reg_t y) { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + template + static reg_t mask_i64gather(reg_t src, opmask_t mask, __m512i index, + void const *base) { + return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); + } + template + static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, + void const *base) { + return _mm512_mask_i32gather_epi64(src, mask, index, base, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) { + return set(arr[ind[7]], arr[ind[6]], arr[ind[5]], arr[ind[4]], + arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epi64(x, y); } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm512_maskz_loadu_epi64(mask, mem); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { + return 
_mm512_mask_storeu_epi64(mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epi64(x, y); } + static reg_t permutexvar(__m512i idx, reg_t zmm) { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(reg_t v) { return _mm512_reduce_max_epi64(v); } + static type_t reducemin(reg_t v) { return _mm512_reduce_min_epi64(v); } + static reg_t set1(type_t v) { return _mm512_set1_epi64(v); } + template + static reg_t shuffle(reg_t zmm) { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, reg_t x) { _mm512_storeu_si512(mem, x); } + static reg_t reverse(reg_t zmm) { + const regi_t rev_index = seti(NETWORK_64BIT_2); + return permutexvar(rev_index, zmm); + } + static reg_t sort_vec(reg_t x) { + return sort_zmm_64bit>(x); + } + static reg_t cast_from(__m512i v) { return v; } + static __m512i cast_to(reg_t v) { return v; } + static int double_compressstore(type_t *left_addr, type_t *right_addr, + opmask_t k, reg_t reg) { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } +}; +template <> +struct zmm_vector { + using type_t = double; + using reg_t = __m512d; + using regi_t = __m512i; + using halfreg_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; +#ifdef XSS_MINIMAL_NETWORK_SORT + static constexpr int network_sort_threshold = numlanes; +#else + static constexpr int network_sort_threshold = 256; +#endif + static constexpr int partition_unroll_factor = 8; + + using swizzle_ops = avx512_64bit_swizzle_ops; + + static type_t type_max() { return X86_SIMD_SORT_INFINITY; } + static type_t type_min() { return -X86_SIMD_SORT_INFINITY; } + static reg_t zmm_max() { return _mm512_set1_pd(type_max()); } + static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, + int v8) { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4, type_t v5, + type_t v6, type_t v7, type_t v8) { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) { + return _mm512_maskz_loadu_pd(mask, mem); + } + static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } + static opmask_t ge(reg_t x, reg_t y) { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + static opmask_t gt(reg_t x, reg_t y) { + return _mm512_cmp_pd_mask(x, y, _CMP_GT_OQ); + } + static opmask_t eq(reg_t x, reg_t y) { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) { return mask; } + template + static opmask_t fpclass(reg_t x) { + return _mm512_fpclass_pd_mask(x, type); + } + template + static reg_t mask_i64gather(reg_t src, opmask_t mask, __m512i index, + void const *base) { + return _mm512_mask_i64gather_pd(src, mask, index, base, scale); + } + template + static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, + void const *base) { + return _mm512_mask_i32gather_pd(src, mask, index, base, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) { + return set(arr[ind[7]], arr[ind[6]], arr[ind[5]], arr[ind[4]], + arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static reg_t loadu(void const *mem) { return _mm512_loadu_pd(mem); } + static reg_t max(reg_t x, reg_t y) { return _mm512_max_pd(x, y); } + static void 
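// [Illustrative sketch, not part of this patch] The zmm_vector<int64_t> and
// zmm_vector<double> specializations above expose one uniform static
// interface (type_t, reg_t, numlanes, loadu/storeu, min/max, ...), so the
// generic sort and partition templates never touch intrinsics directly.
// A scalar analogue of that traits pattern, using hypothetical names:
#include <algorithm>
#include <cstdint>

template <typename T>
struct scalar_vector {                 // stand-in for zmm_vector<T>
    using type_t = T;
    struct reg_t { T lane[8]; };       // "register" = 8 scalar lanes
    static const int numlanes = 8;

    static reg_t loadu(const T *mem) {
        reg_t r;
        for (int i = 0; i < numlanes; i++) r.lane[i] = mem[i];
        return r;
    }
    static void storeu(T *mem, reg_t x) {
        for (int i = 0; i < numlanes; i++) mem[i] = x.lane[i];
    }
    static reg_t min(reg_t a, reg_t b) {
        for (int i = 0; i < numlanes; i++) a.lane[i] = std::min(a.lane[i], b.lane[i]);
        return a;
    }
    static reg_t max(reg_t a, reg_t b) {
        for (int i = 0; i < numlanes; i++) a.lane[i] = std::max(a.lane[i], b.lane[i]);
        return a;
    }
};

// Generic code written against the trait works for any element type, e.g.
// a compare-exchange of two blocks of numlanes elements:
template <typename vtype, typename T = typename vtype::type_t>
void coex_blocks(T *a, T *b) {
    auto va = vtype::loadu(a), vb = vtype::loadu(b);
    vtype::storeu(a, vtype::min(va, vb));
    vtype::storeu(b, vtype::max(va, vb));
}
// usage: coex_blocks<scalar_vector<int64_t>>(p, p + 8);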
mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) { return _mm512_min_pd(x, y); } + static reg_t permutexvar(__m512i idx, reg_t zmm) { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(reg_t v) { return _mm512_reduce_max_pd(v); } + static type_t reducemin(reg_t v) { return _mm512_reduce_min_pd(v); } + static reg_t set1(type_t v) { return _mm512_set1_pd(v); } + template + static reg_t shuffle(reg_t zmm) { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, reg_t x) { _mm512_storeu_pd(mem, x); } + static reg_t reverse(reg_t zmm) { + const regi_t rev_index = seti(NETWORK_64BIT_2); + return permutexvar(rev_index, zmm); + } + static reg_t sort_vec(reg_t x) { + return sort_zmm_64bit>(x); + } + static reg_t cast_from(__m512i v) { return _mm512_castsi512_pd(v); } + static __m512i cast_to(reg_t v) { return _mm512_castpd_si512(v); } + static int double_compressstore(type_t *left_addr, type_t *right_addr, + opmask_t k, reg_t reg) { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } +}; + +/* + * Assumes zmm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm) { + const typename vtype::regi_t rev_index = vtype::seti(NETWORK_64BIT_2); zmm = cmp_merge( - zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), zmm), 0xF0); - // 2) half_cleaner[4] + zmm, vtype::template shuffle(zmm), 0xAA); zmm = cmp_merge( - zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), 0xCC); - // 3) half_cleaner[1] + zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); + zmm = cmp_merge( + zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); return zmm; } -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, - zmm_t &zmm2) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - zmm2 = vtype::permutexvar(rev_index, zmm2); - zmm_t zmm3 = vtype::min(zmm1, zmm2); - zmm_t zmm4 = vtype::max(zmm1, zmm2); - // 2) Recursive half cleaner for each - zmm1 = bitonic_merge_zmm_64bit(zmm3); - zmm2 = bitonic_merge_zmm_64bit(zmm4); -} -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network - zmm_t zmm2r = vtype::permutexvar(rev_index, zmm[2]); - zmm_t zmm3r = vtype::permutexvar(rev_index, zmm[3]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); - // 2) Recursive half clearer: 16 - zmm_t zmm_t3 = 
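// [Illustrative sketch, not part of this patch] The cmp_merge stages of
// sort_zmm_64bit above, driven by the 0xAA / 0xCC / 0xF0 blend masks and the
// NETWORK_64BIT_* permutations, encode the classic 8-input bitonic sorting
// network. A scalar reference version of that network:
#include <algorithm>
#include <cstdint>

inline void bitonic_sort8_scalar(int64_t v[8]) {
    for (int k = 2; k <= 8; k <<= 1) {            // size of sorted runs to build
        for (int j = k >> 1; j > 0; j >>= 1) {    // compare-exchange stride
            for (int i = 0; i < 8; i++) {
                int l = i ^ j;
                if (l > i) {
                    bool ascending = ((i & k) == 0);
                    if ((ascending && v[i] > v[l]) ||
                        (!ascending && v[i] < v[l]))
                        std::swap(v[i], v[l]);
                }
            }
        }
    }
}
// Each (k, j) pass corresponds to one cmp_merge(shuffle/permutexvar, mask)
// call above; the vector code performs all 8 lane comparisons of a pass with
// a single masked min/max.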
vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm2r)); - zmm_t zmm_t4 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm3r)); - zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); - zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); - zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); - zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); - zmm[0] = bitonic_merge_zmm_64bit(zmm0); - zmm[1] = bitonic_merge_zmm_64bit(zmm1); - zmm[2] = bitonic_merge_zmm_64bit(zmm2); - zmm[3] = bitonic_merge_zmm_64bit(zmm3); -} -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t zmm4r = vtype::permutexvar(rev_index, zmm[4]); - zmm_t zmm5r = vtype::permutexvar(rev_index, zmm[5]); - zmm_t zmm6r = vtype::permutexvar(rev_index, zmm[6]); - zmm_t zmm7r = vtype::permutexvar(rev_index, zmm[7]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r); - zmm_t zmm_t5 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm4r)); - zmm_t zmm_t6 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm5r)); - zmm_t zmm_t7 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm6r)); - zmm_t zmm_t8 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm7r)); - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); -} -template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t zmm8r = vtype::permutexvar(rev_index, zmm[8]); - zmm_t zmm9r = vtype::permutexvar(rev_index, zmm[9]); - zmm_t zmm10r = vtype::permutexvar(rev_index, zmm[10]); - zmm_t zmm11r = vtype::permutexvar(rev_index, zmm[11]); - zmm_t zmm12r = vtype::permutexvar(rev_index, zmm[12]); - zmm_t zmm13r = vtype::permutexvar(rev_index, zmm[13]); - zmm_t zmm14r = vtype::permutexvar(rev_index, zmm[14]); - zmm_t zmm15r = vtype::permutexvar(rev_index, zmm[15]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm15r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm14r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm13r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm12r); - zmm_t zmm_t5 = vtype::min(zmm[4], zmm11r); - zmm_t zmm_t6 = vtype::min(zmm[5], zmm10r); - zmm_t zmm_t7 = vtype::min(zmm[6], zmm9r); - zmm_t zmm_t8 = vtype::min(zmm[7], zmm8r); - zmm_t zmm_t9 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm8r)); - zmm_t zmm_t10 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm9r)); - zmm_t zmm_t11 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm10r)); - zmm_t zmm_t12 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm11r)); - zmm_t zmm_t13 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm12r)); - zmm_t zmm_t14 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm13r)); - zmm_t zmm_t15 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm14r)); - zmm_t zmm_t16 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm15r)); - // Recusive half clear 16 zmm regs - COEX(zmm_t1, zmm_t5); - COEX(zmm_t2, zmm_t6); - COEX(zmm_t3, zmm_t7); - 
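// [Illustrative sketch, not part of this patch] The bitonic_merge_two/_four/
// _eight helpers being removed here all follow the same recipe: reverse one
// sorted block, take lane-wise min/max against the other, then run a bitonic
// half-cleaner on each result. A scalar model for two sorted blocks of 8:
#include <algorithm>
#include <cstdint>

inline void bitonic_half_cleaner8(int64_t v[8]) {
    for (int j = 4; j > 0; j >>= 1)               // strides 4, 2, 1
        for (int i = 0; i < 8; i++) {
            int l = i ^ j;
            if (l > i && v[i] > v[l]) std::swap(v[i], v[l]);
        }
}

// Precondition: a[0..7] and b[0..7] each sorted ascending.
// Postcondition: both sorted, and every element of a <= every element of b.
inline void bitonic_merge_two8(int64_t a[8], int64_t b[8]) {
    std::reverse(b, b + 8);                       // (a, b) is now one bitonic run
    for (int i = 0; i < 8; i++) {                 // first merge step: coex a[i], b[i]
        int64_t lo = std::min(a[i], b[i]);
        int64_t hi = std::max(a[i], b[i]);
        a[i] = lo;
        b[i] = hi;
    }
    bitonic_half_cleaner8(a);                     // each half is bitonic; clean it
    bitonic_half_cleaner8(b);
}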
COEX(zmm_t4, zmm_t8); - COEX(zmm_t9, zmm_t13); - COEX(zmm_t10, zmm_t14); - COEX(zmm_t11, zmm_t15); - COEX(zmm_t12, zmm_t16); - // - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t9, zmm_t11); - COEX(zmm_t10, zmm_t12); - COEX(zmm_t13, zmm_t15); - COEX(zmm_t14, zmm_t16); - // - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - COEX(zmm_t9, zmm_t10); - COEX(zmm_t11, zmm_t12); - COEX(zmm_t13, zmm_t14); - COEX(zmm_t15, zmm_t16); - // - zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); - zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); - zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); - zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); - zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); - zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); - zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); - zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); - zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); -} -template -X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]); - zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]); - zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]); - zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]); - zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]); - zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]); - zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]); - zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]); - zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]); - zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]); - zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]); - zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]); - zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]); - zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]); - zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]); - zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r); - zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r); - zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r); - zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r); - zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r); - zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r); - zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r); - zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r); - zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r); - zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r); - zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r); - zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r); - zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r); - zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r)); - zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r)); - zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r)); - zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r)); - zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r)); - zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r)); - zmm_t zmm_t23 = 
vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r)); - zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r)); - zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r)); - zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r)); - zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r)); - zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r)); - zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r)); - zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r)); - zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r)); - zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r)); - // Recusive half clear 16 zmm regs - COEX(zmm_t1, zmm_t9); - COEX(zmm_t2, zmm_t10); - COEX(zmm_t3, zmm_t11); - COEX(zmm_t4, zmm_t12); - COEX(zmm_t5, zmm_t13); - COEX(zmm_t6, zmm_t14); - COEX(zmm_t7, zmm_t15); - COEX(zmm_t8, zmm_t16); - COEX(zmm_t17, zmm_t25); - COEX(zmm_t18, zmm_t26); - COEX(zmm_t19, zmm_t27); - COEX(zmm_t20, zmm_t28); - COEX(zmm_t21, zmm_t29); - COEX(zmm_t22, zmm_t30); - COEX(zmm_t23, zmm_t31); - COEX(zmm_t24, zmm_t32); - // - COEX(zmm_t1, zmm_t5); - COEX(zmm_t2, zmm_t6); - COEX(zmm_t3, zmm_t7); - COEX(zmm_t4, zmm_t8); - COEX(zmm_t9, zmm_t13); - COEX(zmm_t10, zmm_t14); - COEX(zmm_t11, zmm_t15); - COEX(zmm_t12, zmm_t16); - COEX(zmm_t17, zmm_t21); - COEX(zmm_t18, zmm_t22); - COEX(zmm_t19, zmm_t23); - COEX(zmm_t20, zmm_t24); - COEX(zmm_t25, zmm_t29); - COEX(zmm_t26, zmm_t30); - COEX(zmm_t27, zmm_t31); - COEX(zmm_t28, zmm_t32); - // - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t9, zmm_t11); - COEX(zmm_t10, zmm_t12); - COEX(zmm_t13, zmm_t15); - COEX(zmm_t14, zmm_t16); - COEX(zmm_t17, zmm_t19); - COEX(zmm_t18, zmm_t20); - COEX(zmm_t21, zmm_t23); - COEX(zmm_t22, zmm_t24); - COEX(zmm_t25, zmm_t27); - COEX(zmm_t26, zmm_t28); - COEX(zmm_t29, zmm_t31); - COEX(zmm_t30, zmm_t32); - // - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - COEX(zmm_t9, zmm_t10); - COEX(zmm_t11, zmm_t12); - COEX(zmm_t13, zmm_t14); - COEX(zmm_t15, zmm_t16); - COEX(zmm_t17, zmm_t18); - COEX(zmm_t19, zmm_t20); - COEX(zmm_t21, zmm_t22); - COEX(zmm_t23, zmm_t24); - COEX(zmm_t25, zmm_t26); - COEX(zmm_t27, zmm_t28); - COEX(zmm_t29, zmm_t30); - COEX(zmm_t31, zmm_t32); - // - zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); - zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); - zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); - zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); - zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); - zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); - zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); - zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); - zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); - zmm[16] = bitonic_merge_zmm_64bit(zmm_t17); - zmm[17] = bitonic_merge_zmm_64bit(zmm_t18); - zmm[18] = bitonic_merge_zmm_64bit(zmm_t19); - zmm[19] = bitonic_merge_zmm_64bit(zmm_t20); - zmm[20] = bitonic_merge_zmm_64bit(zmm_t21); - zmm[21] = bitonic_merge_zmm_64bit(zmm_t22); - zmm[22] = bitonic_merge_zmm_64bit(zmm_t23); - zmm[23] = bitonic_merge_zmm_64bit(zmm_t24); - zmm[24] = 
bitonic_merge_zmm_64bit(zmm_t25); - zmm[25] = bitonic_merge_zmm_64bit(zmm_t26); - zmm[26] = bitonic_merge_zmm_64bit(zmm_t27); - zmm[27] = bitonic_merge_zmm_64bit(zmm_t28); - zmm[28] = bitonic_merge_zmm_64bit(zmm_t29); - zmm[29] = bitonic_merge_zmm_64bit(zmm_t30); - zmm[30] = bitonic_merge_zmm_64bit(zmm_t31); - zmm[31] = bitonic_merge_zmm_64bit(zmm_t32); -} +struct avx512_64bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n( + typename vtype::reg_t reg) { + __m512i v = vtype::cast_to(reg); -template -X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N) { - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype::zmm_t zmm = - vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); - vtype::mask_storeu(arr, load_mask, sort_zmm_64bit(zmm)); -} - -template -X86_SIMD_SORT_INLINE void sort_16_64bit(type_t *arr, int32_t N) { - if (N <= 8) { - sort_8_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - zmm_t zmm1 = vtype::loadu(arr); - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 8); - zmm1 = sort_zmm_64bit(zmm1); - zmm2 = sort_zmm_64bit(zmm2); - bitonic_merge_two_zmm_64bit(zmm1, zmm2); - vtype::storeu(arr, zmm1); - vtype::mask_storeu(arr + 8, load_mask, zmm2); -} - -template -X86_SIMD_SORT_INLINE void sort_32_64bit(type_t *arr, int32_t N) { - if (N <= 16) { - sort_16_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[4]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 16); - zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 24); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_four_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::mask_storeu(arr + 16, load_mask1, zmm[2]); - vtype::mask_storeu(arr + 24, load_mask2, zmm[3]); -} - -template -X86_SIMD_SORT_INLINE void sort_64_64bit(type_t *arr, int32_t N) { - if (N <= 32) { - sort_32_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[8]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - // N-32 >= 1 - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); - zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 40); - zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 48); - zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 56); - zmm[4] = 
sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_eight_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::mask_storeu(arr + 32, load_mask1, zmm[4]); - vtype::mask_storeu(arr + 40, load_mask2, zmm[5]); - vtype::mask_storeu(arr + 48, load_mask3, zmm[6]); - vtype::mask_storeu(arr + 56, load_mask4, zmm[7]); -} - -template -X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) { - if (N <= 64) { - sort_64_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[16]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[4] = vtype::loadu(arr + 32); - zmm[5] = vtype::loadu(arr + 40); - zmm[6] = vtype::loadu(arr + 48); - zmm[7] = vtype::loadu(arr + 56); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - } - zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); - zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 72); - zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 80); - zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 88); - zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 96); - zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 104); - zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 112); - zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 120); - zmm[8] = sort_zmm_64bit(zmm[8]); - zmm[9] = sort_zmm_64bit(zmm[9]); - zmm[10] = sort_zmm_64bit(zmm[10]); - zmm[11] = sort_zmm_64bit(zmm[11]); - zmm[12] = sort_zmm_64bit(zmm[12]); - zmm[13] = sort_zmm_64bit(zmm[13]); - zmm[14] = sort_zmm_64bit(zmm[14]); - zmm[15] = sort_zmm_64bit(zmm[15]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); - bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); - bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); - bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_four_zmm_64bit(zmm + 
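// [Illustrative sketch, not part of this patch] The load_mask bookkeeping in
// the small-array sorters above (and get_partial_loadmask in the new vector
// classes) handles a partial tail: set one bit per lane that may be read,
// split the bitmask into per-register 8-bit masks, and let masked loads fill
// the unused lanes with the type's max value so they sort to the end.
#include <cstdint>
#include <cstdio>

int main() {
    int n_remaining = 21;                             // e.g. 21 of 32 elements left
    uint64_t combined = (1ull << n_remaining) - 1;    // low 21 bits set
    for (int reg = 0; reg < 4; reg++) {
        uint8_t lane_mask = (uint8_t)((combined >> (8 * reg)) & 0xFF);
        printf("register %d load mask = 0x%02X\n", reg, lane_mask);
    }
    // prints: 0xFF, 0xFF, 0x1F (5 lanes), 0x00
    return 0;
}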
8); - bitonic_merge_four_zmm_64bit(zmm + 12); - bitonic_merge_eight_zmm_64bit(zmm); - bitonic_merge_eight_zmm_64bit(zmm + 8); - bitonic_merge_sixteen_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::storeu(arr + 32, zmm[4]); - vtype::storeu(arr + 40, zmm[5]); - vtype::storeu(arr + 48, zmm[6]); - vtype::storeu(arr + 56, zmm[7]); - vtype::mask_storeu(arr + 64, load_mask1, zmm[8]); - vtype::mask_storeu(arr + 72, load_mask2, zmm[9]); - vtype::mask_storeu(arr + 80, load_mask3, zmm[10]); - vtype::mask_storeu(arr + 88, load_mask4, zmm[11]); - vtype::mask_storeu(arr + 96, load_mask5, zmm[12]); - vtype::mask_storeu(arr + 104, load_mask6, zmm[13]); - vtype::mask_storeu(arr + 112, load_mask7, zmm[14]); - vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); -} - -template -X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) { - if (N <= 128) { - sort_128_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[32]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[4] = vtype::loadu(arr + 32); - zmm[5] = vtype::loadu(arr + 40); - zmm[6] = vtype::loadu(arr + 48); - zmm[7] = vtype::loadu(arr + 56); - zmm[8] = vtype::loadu(arr + 64); - zmm[9] = vtype::loadu(arr + 72); - zmm[10] = vtype::loadu(arr + 80); - zmm[11] = vtype::loadu(arr + 88); - zmm[12] = vtype::loadu(arr + 96); - zmm[13] = vtype::loadu(arr + 104); - zmm[14] = vtype::loadu(arr + 112); - zmm[15] = vtype::loadu(arr + 120); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - zmm[8] = sort_zmm_64bit(zmm[8]); - zmm[9] = sort_zmm_64bit(zmm[9]); - zmm[10] = sort_zmm_64bit(zmm[10]); - zmm[11] = sort_zmm_64bit(zmm[11]); - zmm[12] = sort_zmm_64bit(zmm[12]); - zmm[13] = sort_zmm_64bit(zmm[13]); - zmm[14] = sort_zmm_64bit(zmm[14]); - zmm[15] = sort_zmm_64bit(zmm[15]); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF; - opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF; - opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF; - opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF; - if (N != 256) { - uint64_t combined_mask; - if (N < 192) { - combined_mask = (0x1ull << (N - 128)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - load_mask9 = 0x00; - load_mask10 = 0x0; - load_mask11 = 0x00; - load_mask12 = 0x00; - load_mask13 = 0x00; - load_mask14 = 0x00; - load_mask15 = 0x00; - load_mask16 = 0x00; + if constexpr (scale == 2) { + v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110); + } else if constexpr (scale == 4) { + v = _mm512_shuffle_i64x2(v, v, 0b10110001); + } else if constexpr (scale == 8) { + v = _mm512_shuffle_i64x2(v, v, 0b01001110); } else { - 
combined_mask = (0x1ull << (N - 192)) - 0x1ull; - load_mask9 = (combined_mask)&0xFF; - load_mask10 = (combined_mask >> 8) & 0xFF; - load_mask11 = (combined_mask >> 16) & 0xFF; - load_mask12 = (combined_mask >> 24) & 0xFF; - load_mask13 = (combined_mask >> 32) & 0xFF; - load_mask14 = (combined_mask >> 40) & 0xFF; - load_mask15 = (combined_mask >> 48) & 0xFF; - load_mask16 = (combined_mask >> 56) & 0xFF; + static_assert(scale == -1, "should not be reached"); } - } - zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128); - zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136); - zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144); - zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152); - zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160); - zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168); - zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176); - zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184); - if (N < 192) { - zmm[24] = vtype::zmm_max(); - zmm[25] = vtype::zmm_max(); - zmm[26] = vtype::zmm_max(); - zmm[27] = vtype::zmm_max(); - zmm[28] = vtype::zmm_max(); - zmm[29] = vtype::zmm_max(); - zmm[30] = vtype::zmm_max(); - zmm[31] = vtype::zmm_max(); - } else { - zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192); - zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200); - zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208); - zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216); - zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224); - zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232); - zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240); - zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248); - } - zmm[16] = sort_zmm_64bit(zmm[16]); - zmm[17] = sort_zmm_64bit(zmm[17]); - zmm[18] = sort_zmm_64bit(zmm[18]); - zmm[19] = sort_zmm_64bit(zmm[19]); - zmm[20] = sort_zmm_64bit(zmm[20]); - zmm[21] = sort_zmm_64bit(zmm[21]); - zmm[22] = sort_zmm_64bit(zmm[22]); - zmm[23] = sort_zmm_64bit(zmm[23]); - zmm[24] = sort_zmm_64bit(zmm[24]); - zmm[25] = sort_zmm_64bit(zmm[25]); - zmm[26] = sort_zmm_64bit(zmm[26]); - zmm[27] = sort_zmm_64bit(zmm[27]); - zmm[28] = sort_zmm_64bit(zmm[28]); - zmm[29] = sort_zmm_64bit(zmm[29]); - zmm[30] = sort_zmm_64bit(zmm[30]); - zmm[31] = sort_zmm_64bit(zmm[31]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); - bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); - bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); - bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); - bitonic_merge_two_zmm_64bit(zmm[16], zmm[17]); - bitonic_merge_two_zmm_64bit(zmm[18], zmm[19]); - bitonic_merge_two_zmm_64bit(zmm[20], zmm[21]); - bitonic_merge_two_zmm_64bit(zmm[22], zmm[23]); - bitonic_merge_two_zmm_64bit(zmm[24], zmm[25]); - bitonic_merge_two_zmm_64bit(zmm[26], zmm[27]); - bitonic_merge_two_zmm_64bit(zmm[28], zmm[29]); - bitonic_merge_two_zmm_64bit(zmm[30], zmm[31]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_four_zmm_64bit(zmm + 8); - bitonic_merge_four_zmm_64bit(zmm + 12); - bitonic_merge_four_zmm_64bit(zmm + 16); - bitonic_merge_four_zmm_64bit(zmm + 20); - bitonic_merge_four_zmm_64bit(zmm + 24); - 
bitonic_merge_four_zmm_64bit(zmm + 28); - bitonic_merge_eight_zmm_64bit(zmm); - bitonic_merge_eight_zmm_64bit(zmm + 8); - bitonic_merge_eight_zmm_64bit(zmm + 16); - bitonic_merge_eight_zmm_64bit(zmm + 24); - bitonic_merge_sixteen_zmm_64bit(zmm); - bitonic_merge_sixteen_zmm_64bit(zmm + 16); - bitonic_merge_32_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::storeu(arr + 32, zmm[4]); - vtype::storeu(arr + 40, zmm[5]); - vtype::storeu(arr + 48, zmm[6]); - vtype::storeu(arr + 56, zmm[7]); - vtype::storeu(arr + 64, zmm[8]); - vtype::storeu(arr + 72, zmm[9]); - vtype::storeu(arr + 80, zmm[10]); - vtype::storeu(arr + 88, zmm[11]); - vtype::storeu(arr + 96, zmm[12]); - vtype::storeu(arr + 104, zmm[13]); - vtype::storeu(arr + 112, zmm[14]); - vtype::storeu(arr + 120, zmm[15]); - vtype::mask_storeu(arr + 128, load_mask1, zmm[16]); - vtype::mask_storeu(arr + 136, load_mask2, zmm[17]); - vtype::mask_storeu(arr + 144, load_mask3, zmm[18]); - vtype::mask_storeu(arr + 152, load_mask4, zmm[19]); - vtype::mask_storeu(arr + 160, load_mask5, zmm[20]); - vtype::mask_storeu(arr + 168, load_mask6, zmm[21]); - vtype::mask_storeu(arr + 176, load_mask7, zmm[22]); - vtype::mask_storeu(arr + 184, load_mask8, zmm[23]); - if (N > 192) { - vtype::mask_storeu(arr + 192, load_mask9, zmm[24]); - vtype::mask_storeu(arr + 200, load_mask10, zmm[25]); - vtype::mask_storeu(arr + 208, load_mask11, zmm[26]); - vtype::mask_storeu(arr + 216, load_mask12, zmm[27]); - vtype::mask_storeu(arr + 224, load_mask13, zmm[28]); - vtype::mask_storeu(arr + 232, load_mask14, zmm[29]); - vtype::mask_storeu(arr + 240, load_mask15, zmm[30]); - vtype::mask_storeu(arr + 248, load_mask16, zmm[31]); - } -} -template -static void qsort_64bit_(type_t *arr, int64_t left, int64_t right, - int64_t max_iters) { - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 256) { - sort_256_64bit(arr + left, (int32_t)(right + 1 - left)); - return; + return vtype::cast_from(v); } - type_t pivot = get_pivot_scalar(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest, false); - if (pivot != smallest) - qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - qsort_64bit_(arr, pivot_index, right, max_iters - 1); -} + template + X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n( + typename vtype::reg_t reg) { + __m512i v = vtype::cast_to(reg); -template <> -void inline avx512_qsort(int64_t *arr, int64_t fromIndex, int64_t toIndex) { - int64_t arrsize = toIndex - fromIndex; - if (arrsize > 1) { - qsort_64bit_, int64_t>(arr, fromIndex, toIndex - 1, - 2 * (int64_t)log2(arrsize)); + if constexpr (scale == 2) { + return swap_n(reg); + } else if constexpr (scale == 4) { + constexpr uint64_t mask = 0b00011011; + v = _mm512_permutex_epi64(v, mask); + } else if constexpr (scale == 8) { + return vtype::reverse(reg); + } else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); } -} -template <> -void inline avx512_qsort(double *arr, int64_t fromIndex, int64_t toIndex) { - int64_t arrsize = toIndex - fromIndex; - if (arrsize > 1) { - qsort_64bit_, double>(arr, fromIndex, 
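// [Illustrative sketch, not part of this patch] qsort_64bit_ above is an
// introsort-style quicksort: the caller passes a recursion budget of
// 2 * log2(n), and once the budget is exhausted (quicksort is not making
// progress) it falls back to std::sort. A scalar skeleton of that control
// flow, with a hypothetical scalar base case standing in for the bitonic
// network path:
#include <algorithm>
#include <cmath>
#include <cstdint>

template <typename T>
void qsort_depth_limited(T *arr, int64_t left, int64_t right, int64_t max_iters) {
    if (max_iters <= 0) {                        // budget exhausted: bail out
        std::sort(arr + left, arr + right + 1);
        return;
    }
    if (right + 1 - left <= 64) {                // small base case (stand-in for
        std::sort(arr + left, arr + right + 1);  // the sort_256_64bit network)
        return;
    }
    T pivot = arr[left + (right - left) / 2];
    T *mid = std::partition(arr + left, arr + right + 1,
                            [&](const T &x) { return x < pivot; });
    int64_t p = mid - arr;
    qsort_depth_limited(arr, left, p - 1, max_iters - 1);
    qsort_depth_limited(arr, p, right, max_iters - 1);
}
// usage: qsort_depth_limited(a, 0, n - 1, 2 * (int64_t)std::log2(n));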
toIndex - 1, - 2 * (int64_t)log2(arrsize)); + template + X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n( + typename vtype::reg_t reg, typename vtype::reg_t other) { + __m512i v1 = vtype::cast_to(reg); + __m512i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm512_mask_blend_epi64(0b01010101, v1, v2); + } else if constexpr (scale == 4) { + v1 = _mm512_mask_blend_epi64(0b00110011, v1, v2); + } else if constexpr (scale == 8) { + v1 = _mm512_mask_blend_epi64(0b00001111, v1, v2); + } else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); } -} +}; -#endif // AVX512_QSORT_64BIT +#endif diff --git a/src/java.base/linux/native/libsimdsort/avx512-common-qsort.h b/src/java.base/linux/native/libsimdsort/avx512-common-qsort.h deleted file mode 100644 index 7e1c1e31a31..00000000000 --- a/src/java.base/linux/native/libsimdsort/avx512-common-qsort.h +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. - * Copyright (c) 2021 Serge Sans Paille. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - * - */ - -// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) -#ifndef AVX512_QSORT_COMMON -#define AVX512_QSORT_COMMON - -/* - * Quicksort using AVX-512. The ideas and code are based on these two research - * papers [1] and [2]. On a high level, the idea is to vectorize quicksort - * partitioning using AVX-512 compressstore instructions. If the array size is - * < 128, then use Bitonic sorting network implemented on 512-bit registers. - * The precise network definitions depend on the dtype and are defined in - * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and - * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting - * network. The core implementations of the vectorized qsort functions - * avx512_qsort(T*, int64_t) are modified versions of avx2 quicksort - * presented in the paper [2] and source code associated with that paper [3]. 
- * - * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types - * https://drops.dagstuhl.de/opus/volltexte/2021/13775/ - * - * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel - * Skylake https://arxiv.org/pdf/1704.08579.pdf - * - * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: - * MIT - * - * [4] - * http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 - * - */ - -#include -#include -#include -#include -#include - -/* -Workaround for the bug in GCC12 (that was fixed in GCC 12.3.1). -More details are available at: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593 -*/ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wuninitialized" -#include -#pragma GCC diagnostic pop - -#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYH 0x7c00 -#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 -#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() -#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) -#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) -#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) -#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) -#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) -#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) -#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) -#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) -#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d - -#ifdef _MSC_VER -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __forceinline -#elif defined(__CYGWIN__) -/* - * Force inline in cygwin to work around a compiler bug. 
See - * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 - */ -#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#elif defined(__GNUC__) -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#else -#define X86_SIMD_SORT_INLINE static -#define X86_SIMD_SORT_FINLINE static -#endif - -#define LIKELY(x) __builtin_expect((x), 1) -#define UNLIKELY(x) __builtin_expect((x), 0) - -template -struct zmm_vector; - -template -struct ymm_vector; - -// Regular quicksort routines: -template -void avx512_qsort(T *arr, int64_t arrsize); - -template -void inline avx512_qsort(T *arr, int64_t from_index, int64_t to_index); - -template -bool is_a_nan(T elem) { - return std::isnan(elem); -} - -template -X86_SIMD_SORT_INLINE T get_pivot_scalar(T *arr, const int64_t left, const int64_t right) { - // median of 8 equally spaced elements - int64_t NUM_ELEMENTS = 8; - int64_t MID = NUM_ELEMENTS / 2; - int64_t size = (right - left) / NUM_ELEMENTS; - T temp[NUM_ELEMENTS]; - for (int64_t i = 0; i < NUM_ELEMENTS; i++) temp[i] = arr[left + (i * size)]; - std::sort(temp, temp + NUM_ELEMENTS); - return temp[MID]; -} - -template -bool comparison_func_ge(const T &a, const T &b) { - return a < b; -} - -template -bool comparison_func_gt(const T &a, const T &b) { - return a <= b; -} - -/* - * COEX == Compare and Exchange two registers by swapping min and max values - */ -template -static void COEX(mm_t &a, mm_t &b) { - mm_t temp = a; - a = vtype::min(a, b); - b = vtype::max(temp, b); -} -template -static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) { - zmm_t min = vtype::min(in2, in1); - zmm_t max = vtype::max(in2, in1); - return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max -} -/* - * Parition one ZMM register based on the pivot and returns the - * number of elements that are greater than or equal to the pivot. - */ -template -static inline int32_t partition_vec(type_t *arr, int64_t left, int64_t right, - const zmm_t curr_vec, const zmm_t pivot_vec, - zmm_t *smallest_vec, zmm_t *biggest_vec, bool use_gt) { - /* which elements are larger than or equal to the pivot */ - typename vtype::opmask_t mask; - if (use_gt) mask = vtype::gt(curr_vec, pivot_vec); - else mask = vtype::ge(curr_vec, pivot_vec); - //mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_ge_pivot = _mm_popcnt_u32((int32_t)mask); - vtype::mask_compressstoreu(arr + left, vtype::knot_opmask(mask), - curr_vec); - vtype::mask_compressstoreu(arr + right - amount_ge_pivot, mask, - curr_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - return amount_ge_pivot; -} -/* - * Parition an array based on the pivot and returns the index of the - * first element that is greater than or equal to the pivot. - */ -template -static inline int64_t partition_avx512(type_t *arr, int64_t left, int64_t right, - type_t pivot, type_t *smallest, - type_t *biggest, bool use_gt) { - auto comparison_func = use_gt ? 
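// [Illustrative sketch, not part of this patch] partition_vec above processes
// one register's worth of elements: build a ge/gt mask against the pivot,
// popcount it, compress-store the "smaller" lanes at the left cursor and the
// "greater-or-equal" lanes at the right cursor. The same step in scalar form:
#include <algorithm>
#include <cstdint>

template <typename T, int numlanes>
int partition_block(T *arr, int64_t left, int64_t right,
                    const T (&block)[numlanes], T pivot) {
    T lo[numlanes], hi[numlanes];
    int amount_ge_pivot = 0, nlo = 0;
    for (int i = 0; i < numlanes; i++) {
        if (block[i] >= pivot) hi[amount_ge_pivot++] = block[i];
        else                   lo[nlo++] = block[i];
    }
    std::copy(lo, lo + nlo, arr + left);                                // knot(mask) side
    std::copy(hi, hi + amount_ge_pivot, arr + right - amount_ge_pivot); // mask side
    return amount_ge_pivot;
}
// The caller keeps two store cursors and advances them exactly as in
// partition_avx512: l_store += numlanes - amount_ge_pivot, r_store -= amount_ge_pivot.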
comparison_func_gt : comparison_func_ge; - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, arr[left], comparison_func); - *biggest = std::max(*biggest, arr[left], comparison_func); - if (!comparison_func(arr[left], pivot)) { - std::swap(arr[left], arr[--right]); - } else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - zmm_t vec = vtype::loadu(arr + left); - int32_t amount_ge_pivot = - partition_vec(arr, left, left + vtype::numlanes, vec, - pivot_vec, &min_vec, &max_vec, use_gt); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_ge_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - zmm_t vec_left = vtype::loadu(arr + left); - zmm_t vec_right = vtype::loadu(arr + (right - vtype::numlanes)); - // store points of the vectors - int64_t r_store = right - vtype::numlanes; - int64_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - zmm_t curr_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - curr_vec = vtype::loadu(arr + right); - } else { - curr_vec = vtype::loadu(arr + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_ge_pivot = - partition_vec(arr, l_store, r_store + vtype::numlanes, - curr_vec, pivot_vec, &min_vec, &max_vec, use_gt); - ; - r_store -= amount_ge_pivot; - l_store += (vtype::numlanes - amount_ge_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_ge_pivot = - partition_vec(arr, l_store, r_store + vtype::numlanes, vec_left, - pivot_vec, &min_vec, &max_vec, use_gt); - l_store += (vtype::numlanes - amount_ge_pivot); - amount_ge_pivot = - partition_vec(arr, l_store, l_store + vtype::numlanes, vec_right, - pivot_vec, &min_vec, &max_vec, use_gt); - l_store += (vtype::numlanes - amount_ge_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -static inline int64_t partition_avx512_unrolled(type_t *arr, int64_t left, - int64_t right, type_t pivot, - type_t *smallest, - type_t *biggest, bool use_gt) { - if (right - left <= 2 * num_unroll * vtype::numlanes) { - return partition_avx512(arr, left, right, pivot, smallest, - biggest, use_gt); - } - - auto comparison_func = use_gt ? 
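// [Illustrative sketch, not part of this patch] The pivot fed into these
// partition routines comes from get_pivot_scalar shown earlier: the median of
// 8 equally spaced samples of the current range. Restated compactly:
#include <algorithm>
#include <cstdint>

template <typename T>
T pivot_median_of_8(const T *arr, int64_t left, int64_t right) {
    const int64_t k = 8;
    int64_t step = (right - left) / k;     // equally spaced sample positions
    T sample[k];
    for (int64_t i = 0; i < k; i++) sample[i] = arr[left + i * step];
    std::sort(sample, sample + k);
    return sample[k / 2];
}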
comparison_func_gt : comparison_func_ge; - /* make array length divisible by 8*vtype::numlanes , shortening the array - */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { - *smallest = std::min(*smallest, arr[left], comparison_func); - *biggest = std::max(*biggest, arr[left], comparison_func); - if (!comparison_func(arr[left], pivot)) { - std::swap(arr[left], arr[--right]); - } else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); - - // We will now have atleast 16 registers worth of data to process: - // left and right vtype::numlanes values are partitioned at the end - zmm_t vec_left[num_unroll], vec_right[num_unroll]; -#pragma GCC unroll 8 - for (int ii = 0; ii < num_unroll; ++ii) { - vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii); - vec_right[ii] = - vtype::loadu(arr + (right - vtype::numlanes * (num_unroll - ii))); - } - // store points of the vectors - int64_t r_store = right - vtype::numlanes; - int64_t l_store = left; - // indices for loading the elements - left += num_unroll * vtype::numlanes; - right -= num_unroll * vtype::numlanes; - while (right - left != 0) { - zmm_t curr_vec[num_unroll]; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll * vtype::numlanes; -#pragma GCC unroll 8 - for (int ii = 0; ii < num_unroll; ++ii) { - curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes); - } - } else { -#pragma GCC unroll 8 - for (int ii = 0; ii < num_unroll; ++ii) { - curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes); - } - left += num_unroll * vtype::numlanes; - } -// partition the current vector and save it on both sides of the array -#pragma GCC unroll 8 - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot = partition_vec( - arr, l_store, r_store + vtype::numlanes, curr_vec[ii], - pivot_vec, &min_vec, &max_vec, use_gt); - l_store += (vtype::numlanes - amount_ge_pivot); - r_store -= amount_ge_pivot; - } - } - -/* partition and save vec_left[8] and vec_right[8] */ -#pragma GCC unroll 8 - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot = - partition_vec(arr, l_store, r_store + vtype::numlanes, - vec_left[ii], pivot_vec, &min_vec, &max_vec, use_gt); - l_store += (vtype::numlanes - amount_ge_pivot); - r_store -= amount_ge_pivot; - } -#pragma GCC unroll 8 - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot = - partition_vec(arr, l_store, r_store + vtype::numlanes, - vec_right[ii], pivot_vec, &min_vec, &max_vec, use_gt); - l_store += (vtype::numlanes - amount_ge_pivot); - r_store -= amount_ge_pivot; - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -// to_index (exclusive) -template -static int64_t vectorized_partition(type_t *arr, int64_t from_index, int64_t to_index, type_t pivot, bool use_gt) { - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, from_index, to_index, pivot, &smallest, &biggest, use_gt); - return pivot_index; -} - -// partitioning functions -template -void avx512_dual_pivot_partition(T 
*arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2){ - const T pivot1 = arr[index_pivot1]; - const T pivot2 = arr[index_pivot2]; - - const int64_t low = from_index; - const int64_t high = to_index; - const int64_t start = low + 1; - const int64_t end = high - 1; - - - std::swap(arr[index_pivot1], arr[low]); - std::swap(arr[index_pivot2], arr[end]); - - - const int64_t pivot_index2 = vectorized_partition, T>(arr, start, end, pivot2, true); // use_gt = true - std::swap(arr[end], arr[pivot_index2]); - int64_t upper = pivot_index2; - - // if all other elements are greater than pivot2 (and pivot1), no need to do further partitioning - if (upper == start) { - pivot_indices[0] = low; - pivot_indices[1] = upper; - return; - } - - const int64_t pivot_index1 = vectorized_partition, T>(arr, start, upper, pivot1, false); // use_ge (use_gt = false) - int64_t lower = pivot_index1 - 1; - std::swap(arr[low], arr[lower]); - - pivot_indices[0] = lower; - pivot_indices[1] = upper; -} - -template -void avx512_single_pivot_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot){ - const T pivot = arr[index_pivot]; - - const int64_t low = from_index; - const int64_t high = to_index; - const int64_t end = high - 1; - - - const int64_t pivot_index1 = vectorized_partition, T>(arr, low, high, pivot, false); // use_gt = false (use_ge) - int64_t lower = pivot_index1; - - const int64_t pivot_index2 = vectorized_partition, T>(arr, pivot_index1, high, pivot, true); // use_gt = true - int64_t upper = pivot_index2; - - pivot_indices[0] = lower; - pivot_indices[1] = upper; -} - -template -void inline avx512_fast_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2) { - if (index_pivot1 != index_pivot2) { - avx512_dual_pivot_partition(arr, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); - } - else { - avx512_single_pivot_partition(arr, from_index, to_index, pivot_indices, index_pivot1); - } -} - -template -void inline insertion_sort(T *arr, int32_t from_index, int32_t to_index) { - for (int i, k = from_index; ++k < to_index; ) { - T ai = arr[i = k]; - - if (ai < arr[i - 1]) { - while (--i >= from_index && ai < arr[i]) { - arr[i + 1] = arr[i]; - } - arr[i + 1] = ai; - } - } -} - -template -void inline avx512_fast_sort(T *arr, int64_t from_index, int64_t to_index, const int32_t INS_SORT_THRESHOLD) { - int32_t size = to_index - from_index; - - if (size <= INS_SORT_THRESHOLD) { - insertion_sort(arr, from_index, to_index); - } - else { - avx512_qsort(arr, from_index, to_index); - } -} - - - -#endif // AVX512_QSORT_COMMON diff --git a/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp b/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp index 6bd0c5871d6..35b71c421a5 100644 --- a/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp +++ b/src/java.base/linux/native/libsimdsort/avx512-linux-qsort.cpp @@ -21,12 +21,15 @@ * questions. 
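// [Illustrative sketch, not part of this patch] avx512_fast_sort above is a
// simple threshold dispatch: tiny ranges go to the scalar insertion sort,
// everything else to the vectorized quicksort. A scalar restatement, with a
// hypothetical vector_qsort callback standing in for avx512_qsort:
#include <cstdint>

template <typename T>
void insertion_sort_ref(T *arr, int32_t from_index, int32_t to_index) {
    for (int32_t k = from_index + 1; k < to_index; k++) {   // range is [from, to)
        T ai = arr[k];
        int32_t i = k - 1;
        while (i >= from_index && arr[i] > ai) {
            arr[i + 1] = arr[i];                             // shift larger elements up
            i--;
        }
        arr[i + 1] = ai;
    }
}

template <typename T, typename QSort>
void fast_sort_ref(T *arr, int64_t from_index, int64_t to_index,
                   int32_t ins_sort_threshold, QSort vector_qsort) {
    if (to_index - from_index <= ins_sort_threshold)
        insertion_sort_ref(arr, (int32_t)from_index, (int32_t)to_index);
    else
        vector_qsort(arr, from_index, to_index);
}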
* */ +#include "simdsort-support.hpp" +#ifdef __SIMDSORT_SUPPORTED_LINUX #pragma GCC target("avx512dq", "avx512f") #include "avx512-32bit-qsort.hpp" #include "avx512-64bit-qsort.hpp" #include "classfile_constants.h" + #define DLL_PUBLIC __attribute__((visibility("default"))) #define INSERTION_SORT_THRESHOLD_32BIT 16 #define INSERTION_SORT_THRESHOLD_64BIT 20 @@ -36,35 +39,41 @@ extern "C" { DLL_PUBLIC void avx512_sort(void *array, int elem_type, int32_t from_index, int32_t to_index) { switch(elem_type) { case JVM_T_INT: - avx512_fast_sort((int32_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); + avx512_fast_sort((int32_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); break; case JVM_T_LONG: - avx512_fast_sort((int64_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT); + avx512_fast_sort((int64_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT); break; case JVM_T_FLOAT: - avx512_fast_sort((float*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); + avx512_fast_sort((float*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT); break; case JVM_T_DOUBLE: - avx512_fast_sort((double*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT); + avx512_fast_sort((double*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT); break; + default: + assert(false, "Unexpected type"); } } DLL_PUBLIC void avx512_partition(void *array, int elem_type, int32_t from_index, int32_t to_index, int32_t *pivot_indices, int32_t index_pivot1, int32_t index_pivot2) { switch(elem_type) { case JVM_T_INT: - avx512_fast_partition((int32_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + avx512_fast_partition((int32_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); break; case JVM_T_LONG: - avx512_fast_partition((int64_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + avx512_fast_partition((int64_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); break; case JVM_T_FLOAT: - avx512_fast_partition((float*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + avx512_fast_partition((float*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); break; case JVM_T_DOUBLE: - avx512_fast_partition((double*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + avx512_fast_partition((double*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); break; + default: + assert(false, "Unexpected type"); } } } + +#endif diff --git a/src/java.base/linux/native/libsimdsort/simdsort-support.hpp b/src/java.base/linux/native/libsimdsort/simdsort-support.hpp new file mode 100644 index 00000000000..c73fd7920d8 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/simdsort-support.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Intel Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). 
+ * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +#ifndef SIMDSORT_SUPPORT_HPP +#define SIMDSORT_SUPPORT_HPP +#include +#include + +#undef assert +#define assert(cond, msg) { if (!(cond)) { fprintf(stderr, "assert fails %s %d: %s\n", __FILE__, __LINE__, msg); abort(); }} + + +// GCC >= 7.5 is needed to build AVX2 portions of libsimdsort using C++17 features +#if defined(_LP64) && (defined(__GNUC__) && ((__GNUC__ > 7) || ((__GNUC__ == 7) && (__GNUC_MINOR__ >= 5)))) +#define __SIMDSORT_SUPPORTED_LINUX +#endif + +#endif //SIMDSORT_SUPPORT_HPP \ No newline at end of file diff --git a/src/java.base/linux/native/libsimdsort/xss-common-includes.h b/src/java.base/linux/native/libsimdsort/xss-common-includes.h new file mode 100644 index 00000000000..68121cf1b7d --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/xss-common-includes.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef XSS_COMMON_INCLUDES +#define XSS_COMMON_INCLUDES +#include +#include +#include +#include +/* +Workaround for the bug in GCC12 (that was fixed in GCC 12.3.1). 
+More details are available at: +https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105593 +*/ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wuninitialized" +#include +#pragma GCC diagnostic pop +#include +#include + +#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYH 0x7c00 +#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 +#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() +#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) +#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) +#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) +#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) +#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) +#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) +#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) +#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) +#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d + +#define PRAGMA(x) _Pragma(#x) +#define UNUSED(x) (void)(x) + +/* Compiler specific macros specific */ +#if defined(__GNUC__) +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static inline __attribute__((always_inline)) +#else +#define X86_SIMD_SORT_INLINE static +#define X86_SIMD_SORT_FINLINE static +#endif + +#if __GNUC__ >= 8 +#define X86_SIMD_SORT_UNROLL_LOOP(num) PRAGMA(GCC unroll num) +#else +#define X86_SIMD_SORT_UNROLL_LOOP(num) +#endif + +typedef size_t arrsize_t; + +template +struct zmm_vector; + +template +struct ymm_vector; + +template +struct avx2_vector; + +#endif // XSS_COMMON_INCLUDES diff --git a/src/java.base/linux/native/libsimdsort/xss-common-qsort.h b/src/java.base/linux/native/libsimdsort/xss-common-qsort.h new file mode 100644 index 00000000000..07279a487c4 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/xss-common-qsort.h @@ -0,0 +1,528 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). 
+ * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef XSS_COMMON_QSORT +#define XSS_COMMON_QSORT + +/* + * Quicksort using AVX-512. The ideas and code are based on these two research + * papers [1] and [2]. On a high level, the idea is to vectorize quicksort + * partitioning using AVX-512 compressstore instructions. If the array size is + * < 128, then use Bitonic sorting network implemented on 512-bit registers. + * The precise network definitions depend on the dtype and are defined in + * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and + * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting + * network. The core implementations of the vectorized qsort functions + * avx512_qsort(T*, arrsize_t) are modified versions of avx2 quicksort + * presented in the paper [2] and source code associated with that paper [3]. + * + * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types + * https://drops.dagstuhl.de/opus/volltexte/2021/13775/ + * + * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel + * Skylake https://arxiv.org/pdf/1704.08579.pdf + * + * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: + * MIT + * + * [4] http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 + * + */ + +#include "xss-common-includes.h" +#include "xss-pivot-selection.hpp" +#include "xss-network-qsort.hpp" + + +template +bool is_a_nan(T elem) { + return std::isnan(elem); +} + +template +X86_SIMD_SORT_INLINE T get_pivot_scalar(T *arr, const int64_t left, const int64_t right) { + // median of 8 equally spaced elements + int64_t NUM_ELEMENTS = 8; + int64_t MID = NUM_ELEMENTS / 2; + int64_t size = (right - left) / NUM_ELEMENTS; + T temp[NUM_ELEMENTS]; + for (int64_t i = 0; i < NUM_ELEMENTS; i++) temp[i] = arr[left + (i * size)]; + std::sort(temp, temp + NUM_ELEMENTS); + return temp[MID]; +} + +template +bool comparison_func_ge(const T &a, const T &b) { + return a < b; +} + +template +bool comparison_func_gt(const T &a, const T &b) { + return a <= b; +} + +/* + * COEX == Compare and Exchange two registers by swapping min and max values + */ +template +X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b) { + mm_t temp = a; + a = vtype::min(a, b); + b = vtype::max(temp, b); +} + +template +X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) { + reg_t min = vtype::min(in2, in1); + reg_t max = vtype::max(in2, in1); + return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max +} + +template +int avx512_double_compressstore(type_t *left_addr, type_t *right_addr, + typename vtype::opmask_t k, reg_t reg) { + int amount_ge_pivot = _mm_popcnt_u32((int)k); + + vtype::mask_compressstoreu(left_addr, vtype::knot_opmask(k), reg); + vtype::mask_compressstoreu(right_addr + vtype::numlanes - amount_ge_pivot, + k, reg); + + return amount_ge_pivot; +} + +// Generic function dispatches to AVX2 or AVX512 code +template +X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, type_t *r_store, + const reg_t curr_vec, + 
const reg_t pivot_vec, + reg_t &smallest_vec, + reg_t &biggest_vec, bool use_gt) { + //typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); + typename vtype::opmask_t mask; + if (use_gt) mask = vtype::gt(curr_vec, pivot_vec); + else mask = vtype::ge(curr_vec, pivot_vec); + + int amount_ge_pivot = + vtype::double_compressstore(l_store, r_store, mask, curr_vec); + + smallest_vec = vtype::min(curr_vec, smallest_vec); + biggest_vec = vtype::max(curr_vec, biggest_vec); + + return amount_ge_pivot; +} + +/* + * Parition an array based on the pivot and returns the index of the + * first element that is greater than or equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, arrsize_t left, + arrsize_t right, type_t pivot, + type_t *smallest, + type_t *biggest, + bool use_gt) { + auto comparison_func = use_gt ? comparison_func_gt : comparison_func_ge; + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[left], comparison_func); + *biggest = std::max(*biggest, arr[left], comparison_func); + if (!comparison_func(arr[left], pivot)) { + std::swap(arr[left], arr[--right]); + } else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + reg_t vec = vtype::loadu(arr + left); + arrsize_t unpartitioned = right - left - vtype::numlanes; + arrsize_t l_store = left; + + arrsize_t amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + vec, pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; + } + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left = vtype::loadu(arr + left); + reg_t vec_right = vtype::loadu(arr + (right - vtype::numlanes)); + // store points of the vectors + arrsize_t unpartitioned = right - left - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + reg_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((l_store + unpartitioned + vtype::numlanes) - right < + left - l_store) { + right -= vtype::numlanes; + curr_vec = vtype::loadu(arr + right); + } else { + curr_vec = vtype::loadu(arr + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + arrsize_t amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + curr_vec, pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + } + + /* partition and save vec_left and vec_right */ + arrsize_t amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + vec_left, pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + + amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + vec_right, pivot_vec, min_vec, 
max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE arrsize_t +partition_avx512_unrolled(type_t *arr, arrsize_t left, arrsize_t right, + type_t pivot, type_t *smallest, type_t *biggest, bool use_gt) { + if constexpr (num_unroll == 0) { + return partition_avx512(arr, left, right, pivot, smallest, + biggest, use_gt); + } + + /* Use regular partition_avx512 for smaller arrays */ + if (right - left < 3 * num_unroll * vtype::numlanes) { + return partition_avx512(arr, left, right, pivot, smallest, + biggest, use_gt); + } + + auto comparison_func = use_gt ? comparison_func_gt : comparison_func_ge; + /* make array length divisible by vtype::numlanes, shortening the array */ + for (int32_t i = ((right - left) % (vtype::numlanes)); i > 0; --i) { + *smallest = std::min(*smallest, arr[left], comparison_func); + *biggest = std::max(*biggest, arr[left], comparison_func); + if (!comparison_func(arr[left], pivot)) { + std::swap(arr[left], arr[--right]); + } else { + ++left; + } + } + + arrsize_t unpartitioned = right - left - vtype::numlanes; + arrsize_t l_store = left; + + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + /* Calculate and load more registers to make the rest of the array a + * multiple of num_unroll. These registers will be partitioned at the very + * end. */ + int vecsToPartition = ((right - left) / vtype::numlanes) % num_unroll; + reg_t vec_align[num_unroll]; + for (int i = 0; i < vecsToPartition; i++) { + vec_align[i] = vtype::loadu(arr + left + i * vtype::numlanes); + } + left += vecsToPartition * vtype::numlanes; + + /* We will now have atleast 3*num_unroll registers worth of data to + * process. Load left and right vtype::numlanes*num_unroll values into + * registers to make space for in-place parition. 
The vec_left and + * vec_right registers are partitioned at the end */ + reg_t vec_left[num_unroll], vec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii); + vec_right[ii] = + vtype::loadu(arr + (right - vtype::numlanes * (num_unroll - ii))); + } + /* indices for loading the elements */ + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((l_store + unpartitioned + vtype::numlanes) - right < + left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes); + /* + * error: '_mm_prefetch' needs target feature mmx on clang-cl + */ +#if !(defined(_MSC_VER) && defined(__clang__)) + _mm_prefetch((char *)(arr + right + ii * vtype::numlanes - + num_unroll * vtype::numlanes), + _MM_HINT_T0); +#endif + } + } else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes); + /* + * error: '_mm_prefetch' needs target feature mmx on clang-cl + */ +#if !(defined(_MSC_VER) && defined(__clang__)) + _mm_prefetch((char *)(arr + left + ii * vtype::numlanes + + num_unroll * vtype::numlanes), + _MM_HINT_T0); +#endif + } + left += num_unroll * vtype::numlanes; + } + /* partition the current vector and save it on both sides of the array + * */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, arr + l_store + unpartitioned, curr_vec[ii], + pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + } + } + + /* partition and save vec_left[num_unroll] and vec_right[num_unroll] */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arrsize_t amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + vec_left[ii], pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arrsize_t amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + vec_right[ii], pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + } + + /* partition and save vec_align[vecsToPartition] */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < vecsToPartition; ++ii) { + arrsize_t amount_ge_pivot = + partition_vec(arr + l_store, arr + l_store + unpartitioned, + vec_align[ii], pivot_vec, min_vec, max_vec, use_gt); + l_store += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + } + + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +void sort_n(typename vtype::type_t *arr, int N); + +template +static void qsort_(type_t *arr, arrsize_t left, arrsize_t right, + arrsize_t max_iters) { + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1, comparison_func_ge); + return; + } 
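The max_iters budget that opens qsort_ above is an introsort-style safeguard: simd_fast_sort (defined further down) seeds it with 2 * log2(n), so a run of poor pivots can only waste a logarithmic number of partitioning passes before the whole range is handed to std::sort. For illustration only (this sketch is not part of the change; the function name and the small-range threshold are made up), here is the same control flow with a plain scalar Hoare partition standing in for the vectorized one:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Depth-budgeted quicksort: the recursion receives a budget of roughly
// 2*log2(n); once it is exhausted the remaining range is finished by
// std::sort, which bounds the worst case at O(n log n) just like qsort_.
static void depth_budgeted_qsort(int32_t *arr, int64_t left, int64_t right,
                                 int64_t budget) {
    if (right - left <= 16 || budget <= 0) {          // small range or budget spent
        std::sort(arr + left, arr + right + 1);
        return;
    }
    int32_t pivot = arr[left + (right - left) / 2];   // scalar stand-in for get_pivot_blocks
    int64_t i = left, j = right;
    while (i <= j) {                                  // plain Hoare-style partition
        while (arr[i] < pivot) ++i;
        while (arr[j] > pivot) --j;
        if (i <= j) std::swap(arr[i++], arr[j--]);
    }
    depth_budgeted_qsort(arr, left, j, budget - 1);
    depth_budgeted_qsort(arr, i, right, budget - 1);
}

// usage: depth_budgeted_qsort(data, 0, n - 1, 2 * (int64_t)std::log2(n));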
+ /* + * Base case: use bitonic networks to sort arrays <= + * vtype::network_sort_threshold + */ + if (right + 1 - left <= vtype::network_sort_threshold) { + sort_n( + arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_blocks(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + + arrsize_t pivot_index = + partition_avx512_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest, false); + + if (pivot != smallest) + qsort_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) qsort_(arr, pivot_index, right, max_iters - 1); +} + +// Hooks for OpenJDK sort +// to_index (exclusive) +template +static int64_t vectorized_partition(type_t *arr, int64_t from_index, int64_t to_index, type_t pivot, bool use_gt) { + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, from_index, to_index, pivot, &smallest, &biggest, use_gt); + return pivot_index; +} + +// partitioning functions +template +X86_SIMD_SORT_INLINE void simd_dual_pivot_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2){ + const T pivot1 = arr[index_pivot1]; + const T pivot2 = arr[index_pivot2]; + + const int64_t low = from_index; + const int64_t high = to_index; + const int64_t start = low + 1; + const int64_t end = high - 1; + + + std::swap(arr[index_pivot1], arr[low]); + std::swap(arr[index_pivot2], arr[end]); + + + const int64_t pivot_index2 = vectorized_partition(arr, start, end, pivot2, true); // use_gt = true + std::swap(arr[end], arr[pivot_index2]); + int64_t upper = pivot_index2; + + // if all other elements are greater than pivot2 (and pivot1), no need to do further partitioning + if (upper == start) { + pivot_indices[0] = low; + pivot_indices[1] = upper; + return; + } + + const int64_t pivot_index1 = vectorized_partition(arr, start, upper, pivot1, false); // use_ge (use_gt = false) + int64_t lower = pivot_index1 - 1; + std::swap(arr[low], arr[lower]); + + pivot_indices[0] = lower; + pivot_indices[1] = upper; +} + +template +X86_SIMD_SORT_INLINE void simd_single_pivot_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot) { + const T pivot = arr[index_pivot]; + + const int64_t low = from_index; + const int64_t high = to_index; + const int64_t end = high - 1; + + + const int64_t pivot_index1 = vectorized_partition(arr, low, high, pivot, false); // use_gt = false (use_ge) + int64_t lower = pivot_index1; + + const int64_t pivot_index2 = vectorized_partition(arr, pivot_index1, high, pivot, true); // use_gt = true + int64_t upper = pivot_index2; + + pivot_indices[0] = lower; + pivot_indices[1] = upper; +} + +template +X86_SIMD_SORT_INLINE void simd_fast_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2) { + if (index_pivot1 != index_pivot2) { + simd_dual_pivot_partition(arr, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); + } + else { + simd_single_pivot_partition(arr, from_index, to_index, pivot_indices, index_pivot1); + } +} + +template +X86_SIMD_SORT_INLINE void insertion_sort(T *arr, int32_t from_index, int32_t to_index) { + for (int i, k = from_index; ++k < to_index; ) { + T ai = arr[i = k]; + if (ai < arr[i - 1]) { + while (--i >= from_index && ai < arr[i]) { + arr[i + 1] = arr[i]; + } + arr[i + 1] = ai; + } + } +} + +template +X86_SIMD_SORT_INLINE 
void simd_fast_sort(T *arr, arrsize_t from_index, arrsize_t to_index, const arrsize_t INS_SORT_THRESHOLD) +{ + arrsize_t arrsize = to_index - from_index; + if (arrsize <= INS_SORT_THRESHOLD) { + insertion_sort(arr, from_index, to_index); + } else { + qsort_(arr, from_index, to_index - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +#define DEFINE_METHODS(ISA, VTYPE) \ + template \ + X86_SIMD_SORT_INLINE void ISA##_fast_sort( \ + T *arr, arrsize_t from_index, arrsize_t to_index, const arrsize_t INS_SORT_THRESHOLD) \ + { \ + simd_fast_sort(arr, from_index, to_index, INS_SORT_THRESHOLD); \ + } \ + template \ + X86_SIMD_SORT_INLINE void ISA##_fast_partition( \ + T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2) \ + { \ + simd_fast_partition(arr, from_index, to_index, pivot_indices, index_pivot1, index_pivot2); \ + } + +DEFINE_METHODS(avx2, avx2_vector) +DEFINE_METHODS(avx512, zmm_vector) + +#endif // XSS_COMMON_QSORT diff --git a/src/java.base/linux/native/libsimdsort/xss-network-qsort.hpp b/src/java.base/linux/native/libsimdsort/xss-network-qsort.hpp new file mode 100644 index 00000000000..d0a6188b63b --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/xss-network-qsort.hpp @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +#ifndef XSS_NETWORK_QSORT +#define XSS_NETWORK_QSORT + +#include "xss-common-qsort.h" +#include "xss-optimal-networks.hpp" + +template +X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) { + if constexpr (numVecs == 1) { + UNUSED(regs); + return; + } else if constexpr (numVecs == 2) { + COEX(regs[0], regs[1]); + } else if constexpr (numVecs == 4) { + optimal_sort_4(regs); + } else if constexpr (numVecs == 8) { + optimal_sort_8(regs); + } else if constexpr (numVecs == 16) { + optimal_sort_16(regs); + } else if constexpr (numVecs == 32) { + optimal_sort_32(regs); + } else { + static_assert(numVecs == -1, "should not reach here"); + } +} + +/* + * Swizzle ops explained: + * swap_n: swap neighbouring blocks of size within block of + * size reg i = [7,6,5,4,3,2,1,0] swap_n<2>: = + * [[6,7],[4,5],[2,3],[0,1]] swap_n<4>: = [[5,4,7,6],[1,0,3,2]] swap_n<8>: = + * [[3,2,1,0,7,6,5,4]] reverse_n: reverse elements within block of size + * reg i = [7,6,5,4,3,2,1,0] rev_n<2>: = + * [[6,7],[4,5],[2,3],[0,1]] rev_n<4>: = [[4,5,6,7],[0,1,2,3]] rev_n<8>: = + * [[0,1,2,3,4,5,6,7]] merge_n: merge blocks of elements from + * two regs reg b,a = [a,a,a,a,a,a,a,a], [b,b,b,b,b,b,b,b] merge_n<2> = + * [a,b,a,b,a,b,a,b] merge_n<4> = [a,a,b,b,a,a,b,b] merge_n<8> = + * [a,a,a,a,b,b,b,b] + */ + +template +X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) { + using reg_t = typename vtype::reg_t; + using swizzle = typename vtype::swizzle_ops; + if constexpr (scale <= 1) { + UNUSED(reg); + return; + } else { + if constexpr (first) { + // Use reverse then merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + reg_t &v = reg[i]; + reg_t rev = swizzle::template reverse_n(v); + COEX(rev, v); + v = swizzle::template merge_n(v, rev); + } + } else { + // Use swap then merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + reg_t &v = reg[i]; + reg_t swap = swizzle::template swap_n(v); + COEX(swap, v); + v = swizzle::template merge_n(v, swap); + } + } + internal_merge_n_vec(reg); + } +} + +template +X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs) { + using swizzle = typename vtype::swizzle_ops; + if constexpr (numVecs <= 1) { + UNUSED(regs); + return; + } + + // Reverse upper half of vectors + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2; i < numVecs; i++) { + regs[i] = swizzle::template reverse_n(regs[i]); + } + // Do compare exchanges + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + COEX(regs[i], regs[numVecs - 1 - i]); + } + + merge_substep_n_vec(regs); + merge_substep_n_vec(regs + numVecs / 2); +} + +template +X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs) { + // Do cross vector merges + merge_substep_n_vec(regs); + + // Do internal vector merges + internal_merge_n_vec(regs); +} + +template +X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs) { + if constexpr (numPer > vtype::numlanes) { + UNUSED(regs); + return; + } else { + merge_step_n_vec(regs); + merge_n_vec(regs); + } +} + +template +X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) { + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * vtype::numlanes) { + sort_n_vec(arr, N); + return; + } + } + + reg_t vecs[numVecs]; + + // Generate masks for loading and storing + typename vtype::opmask_t ioMasks[numVecs - numVecs / 2]; + 
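The ioMasks computed next drive the masked tail handling: lanes past N are loaded from vtype::zmm_max() instead of memory, so the sentinel values end up at the upper end of the sorted columns and the matching masked store never writes them back. For illustration only (not part of the change; it assumes AVX-512F and 32-bit lanes, whereas the real code goes through the vtype abstraction), the same idea for a single 16-lane vector:

#include <immintrin.h>
#include <climits>
#include <cstdint>

// Handle a partial vector of n valid int32 lanes (0 <= n <= 16).
// Compile with -mavx512f; 'live' plays the role of one ioMasks entry.
static void masked_tail_demo(int32_t *arr, int n) {
    __mmask16 live = (__mmask16)((1u << n) - 1);                // low n bits set
    __m512i sentinel = _mm512_set1_epi32(INT_MAX);              // stand-in for vtype::zmm_max()
    __m512i v = _mm512_mask_loadu_epi32(sentinel, live, arr);   // dead lanes become INT_MAX
    /* ... run the sorting/merging network on v here ... */
    _mm512_mask_storeu_epi32(arr, live, v);                     // dead lanes are never stored
}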
X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read = + std::min((uint64_t)std::max(0, N - i * vtype::numlanes), + (uint64_t)vtype::numlanes); + ioMasks[j] = vtype::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + vecs[i] = vtype::loadu(arr + i * vtype::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + vecs[i] = vtype::mask_loadu(vtype::zmm_max(), ioMasks[j], + arr + i * vtype::numlanes); + } + + /* Run the initial sorting network to sort the columns of the [numVecs x + * num_lanes] matrix + */ + bitonic_sort_n_vec(vecs); + + // Merge the vectors using bitonic merging networks + merge_n_vec(vecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + vtype::storeu(arr + i * vtype::numlanes, vecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + vtype::mask_storeu(arr + i * vtype::numlanes, ioMasks[j], vecs[i]); + } +} + +template +X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) { + constexpr int numVecs = maxN / vtype::numlanes; + constexpr bool isMultiple = (maxN == (vtype::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be vtype::numlanes times a power of 2"); + + sort_n_vec(arr, N); +} +#endif diff --git a/src/java.base/linux/native/libsimdsort/xss-optimal-networks.hpp b/src/java.base/linux/native/libsimdsort/xss-optimal-networks.hpp new file mode 100644 index 00000000000..584b8c84118 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/xss-optimal-networks.hpp @@ -0,0 +1,342 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) All of these sources +// files are generated from the optimal networks described in +// https://bertdobbelaere.github.io/sorting_networks.html + +template +X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) { + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + + COEX(vecs[1], vecs[2]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) { + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + COEX(vecs[4], vecs[6]); + COEX(vecs[5], vecs[7]); + + COEX(vecs[0], vecs[4]); + COEX(vecs[1], vecs[5]); + COEX(vecs[2], vecs[6]); + COEX(vecs[3], vecs[7]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + COEX(vecs[4], vecs[5]); + COEX(vecs[6], vecs[7]); + + COEX(vecs[2], vecs[4]); + COEX(vecs[3], vecs[5]); + + COEX(vecs[1], vecs[4]); + COEX(vecs[3], vecs[6]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[4]); + COEX(vecs[5], vecs[6]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) { + COEX(vecs[0], vecs[13]); + COEX(vecs[1], vecs[12]); + COEX(vecs[2], vecs[15]); + COEX(vecs[3], vecs[14]); + COEX(vecs[4], vecs[8]); + COEX(vecs[5], vecs[6]); + COEX(vecs[7], vecs[11]); + COEX(vecs[9], vecs[10]); + + COEX(vecs[0], vecs[5]); + COEX(vecs[1], vecs[7]); + COEX(vecs[2], vecs[9]); + COEX(vecs[3], vecs[4]); + COEX(vecs[6], vecs[13]); + COEX(vecs[8], vecs[14]); + COEX(vecs[10], vecs[15]); + COEX(vecs[11], vecs[12]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + COEX(vecs[4], vecs[5]); + COEX(vecs[6], vecs[8]); + COEX(vecs[7], vecs[9]); + COEX(vecs[10], vecs[11]); + COEX(vecs[12], vecs[13]); + COEX(vecs[14], vecs[15]); + + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + COEX(vecs[4], vecs[10]); + COEX(vecs[5], vecs[11]); + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); + COEX(vecs[12], vecs[14]); + COEX(vecs[13], vecs[15]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[12]); + COEX(vecs[4], vecs[6]); + COEX(vecs[5], vecs[7]); + COEX(vecs[8], vecs[10]); + COEX(vecs[9], vecs[11]); + COEX(vecs[13], vecs[14]); + + COEX(vecs[1], vecs[4]); + COEX(vecs[2], vecs[6]); + COEX(vecs[5], vecs[8]); + COEX(vecs[7], vecs[10]); + COEX(vecs[9], vecs[13]); + COEX(vecs[11], vecs[14]); + + COEX(vecs[2], vecs[4]); + COEX(vecs[3], vecs[6]); + COEX(vecs[9], vecs[12]); + COEX(vecs[11], vecs[13]); + + COEX(vecs[3], vecs[5]); + COEX(vecs[6], vecs[8]); + COEX(vecs[7], vecs[9]); + COEX(vecs[10], vecs[12]); + + COEX(vecs[3], vecs[4]); + COEX(vecs[5], vecs[6]); + COEX(vecs[7], vecs[8]); + COEX(vecs[9], vecs[10]); + COEX(vecs[11], vecs[12]); + + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) { + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[3]); + COEX(vecs[4], vecs[5]); + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); + COEX(vecs[10], vecs[11]); + COEX(vecs[12], vecs[13]); + COEX(vecs[14], vecs[15]); + COEX(vecs[16], vecs[17]); + COEX(vecs[18], vecs[19]); + COEX(vecs[20], vecs[21]); + COEX(vecs[22], vecs[23]); + COEX(vecs[24], vecs[25]); + COEX(vecs[26], vecs[27]); + COEX(vecs[28], vecs[29]); + COEX(vecs[30], vecs[31]); + + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + COEX(vecs[4], vecs[6]); + COEX(vecs[5], vecs[7]); + COEX(vecs[8], vecs[10]); + COEX(vecs[9], vecs[11]); + COEX(vecs[12], vecs[14]); + COEX(vecs[13], vecs[15]); + COEX(vecs[16], vecs[18]); + COEX(vecs[17], vecs[19]); + COEX(vecs[20], vecs[22]); + COEX(vecs[21], 
vecs[23]); + COEX(vecs[24], vecs[26]); + COEX(vecs[25], vecs[27]); + COEX(vecs[28], vecs[30]); + COEX(vecs[29], vecs[31]); + + COEX(vecs[0], vecs[4]); + COEX(vecs[1], vecs[5]); + COEX(vecs[2], vecs[6]); + COEX(vecs[3], vecs[7]); + COEX(vecs[8], vecs[12]); + COEX(vecs[9], vecs[13]); + COEX(vecs[10], vecs[14]); + COEX(vecs[11], vecs[15]); + COEX(vecs[16], vecs[20]); + COEX(vecs[17], vecs[21]); + COEX(vecs[18], vecs[22]); + COEX(vecs[19], vecs[23]); + COEX(vecs[24], vecs[28]); + COEX(vecs[25], vecs[29]); + COEX(vecs[26], vecs[30]); + COEX(vecs[27], vecs[31]); + + COEX(vecs[0], vecs[8]); + COEX(vecs[1], vecs[9]); + COEX(vecs[2], vecs[10]); + COEX(vecs[3], vecs[11]); + COEX(vecs[4], vecs[12]); + COEX(vecs[5], vecs[13]); + COEX(vecs[6], vecs[14]); + COEX(vecs[7], vecs[15]); + COEX(vecs[16], vecs[24]); + COEX(vecs[17], vecs[25]); + COEX(vecs[18], vecs[26]); + COEX(vecs[19], vecs[27]); + COEX(vecs[20], vecs[28]); + COEX(vecs[21], vecs[29]); + COEX(vecs[22], vecs[30]); + COEX(vecs[23], vecs[31]); + + COEX(vecs[0], vecs[16]); + COEX(vecs[1], vecs[8]); + COEX(vecs[2], vecs[4]); + COEX(vecs[3], vecs[12]); + COEX(vecs[5], vecs[10]); + COEX(vecs[6], vecs[9]); + COEX(vecs[7], vecs[14]); + COEX(vecs[11], vecs[13]); + COEX(vecs[15], vecs[31]); + COEX(vecs[17], vecs[24]); + COEX(vecs[18], vecs[20]); + COEX(vecs[19], vecs[28]); + COEX(vecs[21], vecs[26]); + COEX(vecs[22], vecs[25]); + COEX(vecs[23], vecs[30]); + COEX(vecs[27], vecs[29]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[5]); + COEX(vecs[4], vecs[8]); + COEX(vecs[6], vecs[22]); + COEX(vecs[7], vecs[11]); + COEX(vecs[9], vecs[25]); + COEX(vecs[10], vecs[12]); + COEX(vecs[13], vecs[14]); + COEX(vecs[17], vecs[18]); + COEX(vecs[19], vecs[21]); + COEX(vecs[20], vecs[24]); + COEX(vecs[23], vecs[27]); + COEX(vecs[26], vecs[28]); + COEX(vecs[29], vecs[30]); + + COEX(vecs[1], vecs[17]); + COEX(vecs[2], vecs[18]); + COEX(vecs[3], vecs[19]); + COEX(vecs[4], vecs[20]); + COEX(vecs[5], vecs[10]); + COEX(vecs[7], vecs[23]); + COEX(vecs[8], vecs[24]); + COEX(vecs[11], vecs[27]); + COEX(vecs[12], vecs[28]); + COEX(vecs[13], vecs[29]); + COEX(vecs[14], vecs[30]); + COEX(vecs[21], vecs[26]); + + COEX(vecs[3], vecs[17]); + COEX(vecs[4], vecs[16]); + COEX(vecs[5], vecs[21]); + COEX(vecs[6], vecs[18]); + COEX(vecs[7], vecs[9]); + COEX(vecs[8], vecs[20]); + COEX(vecs[10], vecs[26]); + COEX(vecs[11], vecs[23]); + COEX(vecs[13], vecs[25]); + COEX(vecs[14], vecs[28]); + COEX(vecs[15], vecs[27]); + COEX(vecs[22], vecs[24]); + + COEX(vecs[1], vecs[4]); + COEX(vecs[3], vecs[8]); + COEX(vecs[5], vecs[16]); + COEX(vecs[7], vecs[17]); + COEX(vecs[9], vecs[21]); + COEX(vecs[10], vecs[22]); + COEX(vecs[11], vecs[19]); + COEX(vecs[12], vecs[20]); + COEX(vecs[14], vecs[24]); + COEX(vecs[15], vecs[26]); + COEX(vecs[23], vecs[28]); + COEX(vecs[27], vecs[30]); + + COEX(vecs[2], vecs[5]); + COEX(vecs[7], vecs[8]); + COEX(vecs[9], vecs[18]); + COEX(vecs[11], vecs[17]); + COEX(vecs[12], vecs[16]); + COEX(vecs[13], vecs[22]); + COEX(vecs[14], vecs[20]); + COEX(vecs[15], vecs[19]); + COEX(vecs[23], vecs[24]); + COEX(vecs[26], vecs[29]); + + COEX(vecs[2], vecs[4]); + COEX(vecs[6], vecs[12]); + COEX(vecs[9], vecs[16]); + COEX(vecs[10], vecs[11]); + COEX(vecs[13], vecs[17]); + COEX(vecs[14], vecs[18]); + COEX(vecs[15], vecs[22]); + COEX(vecs[19], vecs[25]); + COEX(vecs[20], vecs[21]); + COEX(vecs[27], vecs[29]); + + COEX(vecs[5], vecs[6]); + COEX(vecs[8], vecs[12]); + COEX(vecs[9], vecs[10]); + COEX(vecs[11], vecs[13]); + COEX(vecs[14], vecs[16]); + COEX(vecs[15], vecs[17]); + 
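Every line in these generated networks is the COEX primitive from xss-common-qsort.h: compare two registers and leave the lane-wise minima in the first and the maxima in the second. Because each exchange acts on whole registers, one pass of a k-input network sorts every column of a [k x numlanes] matrix at once. For illustration only (not part of the change), the scalar analogue of COEX and of the 5-comparator optimal_sort_4 network above:

#include <algorithm>

// Scalar analogue of COEX<vtype>: afterwards a holds the minimum, b the maximum.
static inline void coex(int &a, int &b) {
    int lo = std::min(a, b), hi = std::max(a, b);
    a = lo;
    b = hi;
}

// Same comparator sequence as optimal_sort_4: sorts any 4 values in 5 exchanges.
static void scalar_sort_4(int v[4]) {
    coex(v[0], v[2]); coex(v[1], v[3]);
    coex(v[0], v[1]); coex(v[2], v[3]);
    coex(v[1], v[2]);
}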
COEX(vecs[18], vecs[20]); + COEX(vecs[19], vecs[23]); + COEX(vecs[21], vecs[22]); + COEX(vecs[25], vecs[26]); + + COEX(vecs[3], vecs[5]); + COEX(vecs[6], vecs[7]); + COEX(vecs[8], vecs[9]); + COEX(vecs[10], vecs[12]); + COEX(vecs[11], vecs[14]); + COEX(vecs[13], vecs[16]); + COEX(vecs[15], vecs[18]); + COEX(vecs[17], vecs[20]); + COEX(vecs[19], vecs[21]); + COEX(vecs[22], vecs[23]); + COEX(vecs[24], vecs[25]); + COEX(vecs[26], vecs[28]); + + COEX(vecs[3], vecs[4]); + COEX(vecs[5], vecs[6]); + COEX(vecs[7], vecs[8]); + COEX(vecs[9], vecs[10]); + COEX(vecs[11], vecs[12]); + COEX(vecs[13], vecs[14]); + COEX(vecs[15], vecs[16]); + COEX(vecs[17], vecs[18]); + COEX(vecs[19], vecs[20]); + COEX(vecs[21], vecs[22]); + COEX(vecs[23], vecs[24]); + COEX(vecs[25], vecs[26]); + COEX(vecs[27], vecs[28]); +} diff --git a/src/java.base/linux/native/libsimdsort/xss-pivot-selection.hpp b/src/java.base/linux/native/libsimdsort/xss-pivot-selection.hpp new file mode 100644 index 00000000000..d65a30b56d6 --- /dev/null +++ b/src/java.base/linux/native/libsimdsort/xss-pivot-selection.hpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved. + * Copyright (c) 2021 Serge Sans Paille. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ * + */ + +// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort) + +template +X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); + +template +X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, const arrsize_t left, + const arrsize_t right) { + using reg_t = typename vtype::reg_t; + type_t samples[vtype::numlanes]; + arrsize_t delta = (right - left) / vtype::numlanes; + for (int i = 0; i < vtype::numlanes; i++) { + samples[i] = arr[left + i * delta]; + } + reg_t rand_vec = vtype::loadu(samples); + reg_t sort = vtype::sort_vec(rand_vec); + + return ((type_t *)&sort)[vtype::numlanes / 2]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr, const arrsize_t left, + const arrsize_t right) { + if (right - left <= 1024) { + return get_pivot(arr, left, right); + } + + using reg_t = typename vtype::reg_t; + constexpr int numVecs = 5; + + arrsize_t width = (right - vtype::numlanes) - left; + arrsize_t delta = width / numVecs; + + reg_t vecs[numVecs]; + // Load data + for (int i = 0; i < numVecs; i++) { + vecs[i] = vtype::loadu(arr + left + delta * i); + } + + // Implement sorting network (from + // https://bertdobbelaere.github.io/sorting_networks.html) + COEX(vecs[0], vecs[3]); + COEX(vecs[1], vecs[4]); + + COEX(vecs[0], vecs[2]); + COEX(vecs[1], vecs[3]); + + COEX(vecs[0], vecs[1]); + COEX(vecs[2], vecs[4]); + + COEX(vecs[1], vecs[2]); + COEX(vecs[3], vecs[4]); + + COEX(vecs[2], vecs[3]); + + // Calculate median of the middle vector + reg_t &vec = vecs[numVecs / 2]; + vec = vtype::sort_vec(vec); + + type_t data[vtype::numlanes]; + vtype::storeu(data, vec); + return data[vtype::numlanes / 2]; +}
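Both pivot routines above come down to the same idea: sample the range at evenly spaced positions, sort the samples, and take the middle one, which is cheap and reasonably robust against already-sorted or adversarial inputs (get_pivot sorts the samples inside a single register, while get_pivot_blocks first runs a five-register network over wider blocks and takes the median of the middle register). For illustration only (not part of the change; the sample count and helper name are made up), the scalar version of that strategy:

#include <algorithm>
#include <cstdint>

// Median of k evenly spaced samples from arr[left..right] (right inclusive,
// as in get_pivot). For ranges smaller than k the step collapses to 0 and all
// samples read arr[left], which is still a valid, if uninformative, pivot.
template <typename T>
static T sampled_pivot(const T *arr, int64_t left, int64_t right) {
    constexpr int k = 16;                  // sample count; numlanes in the real code
    T samples[k];
    const int64_t step = (right - left) / k;
    for (int i = 0; i < k; i++) {
        samples[i] = arr[left + i * step];
    }
    std::sort(samples, samples + k);
    return samples[k / 2];
}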