8309130: x86_64 AVX512 intrinsics for Arrays.sort methods (int, long, float and double arrays)
Reviewed-by: jbhateja, sviswanathan, psandoz, kvn
This commit is contained in:
parent
6c6beba256
commit
a4e9168bab
@ -227,10 +227,30 @@ ifeq ($(ENABLE_FALLBACK_LINKER), true)
|
||||
NAME := fallbackLinker, \
|
||||
CFLAGS := $(CFLAGS_JDKLIB) $(LIBFFI_CFLAGS), \
|
||||
LDFLAGS := $(LDFLAGS_JDKLIB) \
|
||||
$(call SET_SHARED_LIBRARY_ORIGIN), \
|
||||
$(call SET_SHARED_LIBRARY_ORIGIN), \
|
||||
LIBS := $(LIBFFI_LIBS), \
|
||||
LIBS_windows := $(LIBFFI_LIBS) ws2_32.lib, \
|
||||
))
|
||||
|
||||
TARGETS += $(BUILD_LIBFALLBACKLINKER)
|
||||
endif
|
||||
|
||||
################################################################################
|
||||
|
||||
ifeq ($(call isTargetOs, linux)+$(call isTargetCpu, x86_64)+$(INCLUDE_COMPILER2)+$(filter $(TOOLCHAIN_TYPE), gcc), true+true+true+gcc)
|
||||
$(eval $(call SetupJdkLibrary, BUILD_LIB_SIMD_SORT, \
|
||||
NAME := simdsort, \
|
||||
TOOLCHAIN := TOOLCHAIN_LINK_CXX, \
|
||||
OPTIMIZATION := HIGH, \
|
||||
CFLAGS := $(CFLAGS_JDKLIB), \
|
||||
CXXFLAGS := $(CXXFLAGS_JDKLIB), \
|
||||
LDFLAGS := $(LDFLAGS_JDKLIB) \
|
||||
$(call SET_SHARED_LIBRARY_ORIGIN), \
|
||||
LIBS := $(LIBCXX), \
|
||||
LIBS_linux := -lc -lm -ldl, \
|
||||
))
|
||||
|
||||
TARGETS += $(BUILD_LIB_SIMD_SORT)
|
||||
endif
|
||||
|
||||
################################################################################
|
||||
|
@ -4172,6 +4172,26 @@ void StubGenerator::generate_compiler_stubs() {
|
||||
= CAST_FROM_FN_PTR(address, SharedRuntime::montgomery_square);
|
||||
}
|
||||
|
||||
// Load x86_64_sort library on supported hardware to enable avx512 sort and partition intrinsics
|
||||
if (UseAVX > 2 && VM_Version::supports_avx512dq()) {
|
||||
void *libsimdsort = nullptr;
|
||||
char ebuf_[1024];
|
||||
char dll_name_simd_sort[JVM_MAXPATHLEN];
|
||||
if (os::dll_locate_lib(dll_name_simd_sort, sizeof(dll_name_simd_sort), Arguments::get_dll_dir(), "simdsort")) {
|
||||
libsimdsort = os::dll_load(dll_name_simd_sort, ebuf_, sizeof ebuf_);
|
||||
}
|
||||
// Get addresses for avx512 sort and partition routines
|
||||
if (libsimdsort != nullptr) {
|
||||
log_info(library)("Loaded library %s, handle " INTPTR_FORMAT, JNI_LIB_PREFIX "simdsort" JNI_LIB_SUFFIX, p2i(libsimdsort));
|
||||
|
||||
snprintf(ebuf_, sizeof(ebuf_), "avx512_sort");
|
||||
StubRoutines::_array_sort = (address)os::dll_lookup(libsimdsort, ebuf_);
|
||||
|
||||
snprintf(ebuf_, sizeof(ebuf_), "avx512_partition");
|
||||
StubRoutines::_array_partition = (address)os::dll_lookup(libsimdsort, ebuf_);
|
||||
}
|
||||
}
|
||||
|
||||
// Get svml stub routine addresses
|
||||
void *libjsvml = nullptr;
|
||||
char ebuf[1024];
|
||||
|
@ -341,6 +341,14 @@ class methodHandle;
|
||||
do_name( copyOf_name, "copyOf") \
|
||||
do_signature(copyOf_signature, "([Ljava/lang/Object;ILjava/lang/Class;)[Ljava/lang/Object;") \
|
||||
\
|
||||
do_intrinsic(_arraySort, java_util_DualPivotQuicksort, arraySort_name, arraySort_signature, F_S) \
|
||||
do_name( arraySort_name, "sort") \
|
||||
do_signature(arraySort_signature, "(Ljava/lang/Class;Ljava/lang/Object;JIILjava/util/DualPivotQuicksort$SortOperation;)V") \
|
||||
\
|
||||
do_intrinsic(_arrayPartition, java_util_DualPivotQuicksort, arrayPartition_name, arrayPartition_signature, F_S) \
|
||||
do_name( arrayPartition_name, "partition") \
|
||||
do_signature(arrayPartition_signature, "(Ljava/lang/Class;Ljava/lang/Object;JIIIILjava/util/DualPivotQuicksort$PartitionOperation;)[I") \
|
||||
\
|
||||
do_intrinsic(_copyOfRange, java_util_Arrays, copyOfRange_name, copyOfRange_signature, F_S) \
|
||||
do_name( copyOfRange_name, "copyOfRange") \
|
||||
do_signature(copyOfRange_signature, "([Ljava/lang/Object;IILjava/lang/Class;)[Ljava/lang/Object;") \
|
||||
|
@ -145,6 +145,7 @@ class SerializeClosure;
|
||||
template(java_util_Vector, "java/util/Vector") \
|
||||
template(java_util_AbstractList, "java/util/AbstractList") \
|
||||
template(java_util_Hashtable, "java/util/Hashtable") \
|
||||
template(java_util_DualPivotQuicksort, "java/util/DualPivotQuicksort") \
|
||||
template(java_lang_Compiler, "java/lang/Compiler") \
|
||||
template(jdk_internal_misc_Signal, "jdk/internal/misc/Signal") \
|
||||
template(jdk_internal_util_Preconditions, "jdk/internal/util/Preconditions") \
|
||||
|
@ -387,6 +387,12 @@ void ShenandoahBarrierC2Support::verify(RootNode* root) {
|
||||
verify_type t;
|
||||
} args[6];
|
||||
} calls[] = {
|
||||
"array_partition_stub",
|
||||
{ { TypeFunc::Parms, ShenandoahStore }, { TypeFunc::Parms+4, ShenandoahStore }, { -1, ShenandoahNone },
|
||||
{ -1, ShenandoahNone }, { -1, ShenandoahNone }, { -1, ShenandoahNone } },
|
||||
"arraysort_stub",
|
||||
{ { TypeFunc::Parms, ShenandoahStore }, { -1, ShenandoahNone }, { -1, ShenandoahNone },
|
||||
{ -1, ShenandoahNone}, { -1, ShenandoahNone}, { -1, ShenandoahNone} },
|
||||
"aescrypt_encryptBlock",
|
||||
{ { TypeFunc::Parms, ShenandoahLoad }, { TypeFunc::Parms+1, ShenandoahStore }, { TypeFunc::Parms+2, ShenandoahLoad },
|
||||
{ -1, ShenandoahNone}, { -1, ShenandoahNone}, { -1, ShenandoahNone} },
|
||||
|
@ -331,6 +331,8 @@
|
||||
static_field(StubRoutines, _checkcast_arraycopy_uninit, address) \
|
||||
static_field(StubRoutines, _unsafe_arraycopy, address) \
|
||||
static_field(StubRoutines, _generic_arraycopy, address) \
|
||||
static_field(StubRoutines, _array_sort, address) \
|
||||
static_field(StubRoutines, _array_partition, address) \
|
||||
\
|
||||
static_field(StubRoutines, _aescrypt_encryptBlock, address) \
|
||||
static_field(StubRoutines, _aescrypt_decryptBlock, address) \
|
||||
|
@ -614,6 +614,8 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) {
|
||||
case vmIntrinsics::_min_strict:
|
||||
case vmIntrinsics::_max_strict:
|
||||
case vmIntrinsics::_arraycopy:
|
||||
case vmIntrinsics::_arraySort:
|
||||
case vmIntrinsics::_arrayPartition:
|
||||
case vmIntrinsics::_indexOfL:
|
||||
case vmIntrinsics::_indexOfU:
|
||||
case vmIntrinsics::_indexOfUL:
|
||||
|
@ -1575,6 +1575,8 @@ void ConnectionGraph::process_call_arguments(CallNode *call) {
|
||||
strcmp(call->as_CallLeaf()->_name, "bigIntegerRightShiftWorker") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "bigIntegerLeftShiftWorker") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "vectorizedMismatch") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "arraysort_stub") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "array_partition_stub") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "get_class_id_intrinsic") == 0)
|
||||
))) {
|
||||
call->dump();
|
||||
|
@ -293,6 +293,9 @@ bool LibraryCallKit::try_to_inline(int predicate) {
|
||||
|
||||
case vmIntrinsics::_arraycopy: return inline_arraycopy();
|
||||
|
||||
case vmIntrinsics::_arraySort: return inline_array_sort();
|
||||
case vmIntrinsics::_arrayPartition: return inline_array_partition();
|
||||
|
||||
case vmIntrinsics::_compareToL: return inline_string_compareTo(StrIntrinsicNode::LL);
|
||||
case vmIntrinsics::_compareToU: return inline_string_compareTo(StrIntrinsicNode::UU);
|
||||
case vmIntrinsics::_compareToLU: return inline_string_compareTo(StrIntrinsicNode::LU);
|
||||
@ -5361,6 +5364,101 @@ void LibraryCallKit::create_new_uncommon_trap(CallStaticJavaNode* uncommon_trap_
|
||||
uncommon_trap_call->set_req(0, top()); // not used anymore, kill it
|
||||
}
|
||||
|
||||
//------------------------------inline_array_partition-----------------------
|
||||
bool LibraryCallKit::inline_array_partition() {
|
||||
|
||||
const char *stubName = "array_partition_stub";
|
||||
|
||||
Node* elementType = null_check(argument(0));
|
||||
Node* obj = argument(1);
|
||||
Node* offset = argument(2);
|
||||
Node* fromIndex = argument(4);
|
||||
Node* toIndex = argument(5);
|
||||
Node* indexPivot1 = argument(6);
|
||||
Node* indexPivot2 = argument(7);
|
||||
|
||||
const TypeInstPtr* elem_klass = gvn().type(elementType)->isa_instptr();
|
||||
ciType* elem_type = elem_klass->const_oop()->as_instance()->java_mirror_type();
|
||||
BasicType bt = elem_type->basic_type();
|
||||
address stubAddr = nullptr;
|
||||
stubAddr = StubRoutines::select_array_partition_function();
|
||||
// stub not loaded
|
||||
if (stubAddr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
// get the address of the array
|
||||
const TypeAryPtr* obj_t = _gvn.type(obj)->isa_aryptr();
|
||||
if (obj_t == nullptr || obj_t->elem() == Type::BOTTOM ) {
|
||||
return false; // failed input validation
|
||||
}
|
||||
Node* obj_adr = make_unsafe_address(obj, offset);
|
||||
|
||||
// create the pivotIndices array of type int and size = 2
|
||||
Node* size = intcon(2);
|
||||
Node* klass_node = makecon(TypeKlassPtr::make(ciTypeArrayKlass::make(T_INT)));
|
||||
Node* pivotIndices = new_array(klass_node, size, 0); // no arguments to push
|
||||
AllocateArrayNode* alloc = tightly_coupled_allocation(pivotIndices);
|
||||
guarantee(alloc != nullptr, "created above");
|
||||
Node* pivotIndices_adr = basic_plus_adr(pivotIndices, arrayOopDesc::base_offset_in_bytes(T_INT));
|
||||
|
||||
// pass the basic type enum to the stub
|
||||
Node* elemType = intcon(bt);
|
||||
|
||||
// Call the stub
|
||||
make_runtime_call(RC_LEAF|RC_NO_FP, OptoRuntime::array_partition_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
obj_adr, elemType, fromIndex, toIndex, pivotIndices_adr,
|
||||
indexPivot1, indexPivot2);
|
||||
|
||||
if (!stopped()) {
|
||||
set_result(pivotIndices);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
//------------------------------inline_array_sort-----------------------
|
||||
bool LibraryCallKit::inline_array_sort() {
|
||||
|
||||
const char *stubName;
|
||||
stubName = "arraysort_stub";
|
||||
|
||||
Node* elementType = null_check(argument(0));
|
||||
Node* obj = argument(1);
|
||||
Node* offset = argument(2);
|
||||
Node* fromIndex = argument(4);
|
||||
Node* toIndex = argument(5);
|
||||
|
||||
const TypeInstPtr* elem_klass = gvn().type(elementType)->isa_instptr();
|
||||
ciType* elem_type = elem_klass->const_oop()->as_instance()->java_mirror_type();
|
||||
BasicType bt = elem_type->basic_type();
|
||||
address stubAddr = nullptr;
|
||||
stubAddr = StubRoutines::select_arraysort_function();
|
||||
//stub not loaded
|
||||
if (stubAddr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// get address of the array
|
||||
const TypeAryPtr* obj_t = _gvn.type(obj)->isa_aryptr();
|
||||
if (obj_t == nullptr || obj_t->elem() == Type::BOTTOM ) {
|
||||
return false; // failed input validation
|
||||
}
|
||||
Node* obj_adr = make_unsafe_address(obj, offset);
|
||||
|
||||
// pass the basic type enum to the stub
|
||||
Node* elemType = intcon(bt);
|
||||
|
||||
// Call the stub.
|
||||
make_runtime_call(RC_LEAF|RC_NO_FP, OptoRuntime::array_sort_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
obj_adr, elemType, fromIndex, toIndex);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
//------------------------------inline_arraycopy-----------------------
|
||||
// public static native void java.lang.System.arraycopy(Object src, int srcPos,
|
||||
// Object dest, int destPos,
|
||||
|
@ -277,7 +277,8 @@ class LibraryCallKit : public GraphKit {
|
||||
JVMState* arraycopy_restore_alloc_state(AllocateArrayNode* alloc, int& saved_reexecute_sp);
|
||||
void arraycopy_move_allocation_here(AllocateArrayNode* alloc, Node* dest, JVMState* saved_jvms_before_guards, int saved_reexecute_sp,
|
||||
uint new_idx);
|
||||
|
||||
bool inline_array_sort();
|
||||
bool inline_array_partition();
|
||||
typedef enum { LS_get_add, LS_get_set, LS_cmp_swap, LS_cmp_swap_weak, LS_cmp_exchange } LoadStoreKind;
|
||||
bool inline_unsafe_load_store(BasicType type, LoadStoreKind kind, AccessKind access_kind);
|
||||
bool inline_unsafe_fence(vmIntrinsics::ID id);
|
||||
|
@ -857,6 +857,49 @@ const TypeFunc* OptoRuntime::array_fill_Type() {
|
||||
return TypeFunc::make(domain, range);
|
||||
}
|
||||
|
||||
const TypeFunc* OptoRuntime::array_partition_Type() {
|
||||
// create input type (domain)
|
||||
int num_args = 7;
|
||||
int argcnt = num_args;
|
||||
const Type** fields = TypeTuple::fields(argcnt);
|
||||
int argp = TypeFunc::Parms;
|
||||
fields[argp++] = TypePtr::NOTNULL; // array
|
||||
fields[argp++] = TypeInt::INT; // element type
|
||||
fields[argp++] = TypeInt::INT; // low
|
||||
fields[argp++] = TypeInt::INT; // end
|
||||
fields[argp++] = TypePtr::NOTNULL; // pivot_indices (int array)
|
||||
fields[argp++] = TypeInt::INT; // indexPivot1
|
||||
fields[argp++] = TypeInt::INT; // indexPivot2
|
||||
assert(argp == TypeFunc::Parms+argcnt, "correct decoding");
|
||||
const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms+argcnt, fields);
|
||||
|
||||
// no result type needed
|
||||
fields = TypeTuple::fields(1);
|
||||
fields[TypeFunc::Parms+0] = nullptr; // void
|
||||
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields);
|
||||
return TypeFunc::make(domain, range);
|
||||
}
|
||||
|
||||
const TypeFunc* OptoRuntime::array_sort_Type() {
|
||||
// create input type (domain)
|
||||
int num_args = 4;
|
||||
int argcnt = num_args;
|
||||
const Type** fields = TypeTuple::fields(argcnt);
|
||||
int argp = TypeFunc::Parms;
|
||||
fields[argp++] = TypePtr::NOTNULL; // array
|
||||
fields[argp++] = TypeInt::INT; // element type
|
||||
fields[argp++] = TypeInt::INT; // fromIndex
|
||||
fields[argp++] = TypeInt::INT; // toIndex
|
||||
assert(argp == TypeFunc::Parms+argcnt, "correct decoding");
|
||||
const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms+argcnt, fields);
|
||||
|
||||
// no result type needed
|
||||
fields = TypeTuple::fields(1);
|
||||
fields[TypeFunc::Parms+0] = nullptr; // void
|
||||
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields);
|
||||
return TypeFunc::make(domain, range);
|
||||
}
|
||||
|
||||
// for aescrypt encrypt/decrypt operations, just three pointers returning void (length is constant)
|
||||
const TypeFunc* OptoRuntime::aescrypt_block_Type() {
|
||||
// create input type (domain)
|
||||
|
@ -268,6 +268,8 @@ private:
|
||||
|
||||
static const TypeFunc* array_fill_Type();
|
||||
|
||||
static const TypeFunc* array_sort_Type();
|
||||
static const TypeFunc* array_partition_Type();
|
||||
static const TypeFunc* aescrypt_block_Type();
|
||||
static const TypeFunc* cipherBlockChaining_aescrypt_Type();
|
||||
static const TypeFunc* electronicCodeBook_aescrypt_Type();
|
||||
|
@ -176,6 +176,9 @@ address StubRoutines::_hf2f = nullptr;
|
||||
address StubRoutines::_vector_f_math[VectorSupport::NUM_VEC_SIZES][VectorSupport::NUM_SVML_OP] = {{nullptr}, {nullptr}};
|
||||
address StubRoutines::_vector_d_math[VectorSupport::NUM_VEC_SIZES][VectorSupport::NUM_SVML_OP] = {{nullptr}, {nullptr}};
|
||||
|
||||
address StubRoutines::_array_sort = nullptr;
|
||||
address StubRoutines::_array_partition = nullptr;
|
||||
|
||||
address StubRoutines::_cont_thaw = nullptr;
|
||||
address StubRoutines::_cont_returnBarrier = nullptr;
|
||||
address StubRoutines::_cont_returnBarrierExc = nullptr;
|
||||
|
@ -153,6 +153,8 @@ class StubRoutines: AllStatic {
|
||||
static BufferBlob* _compiler_stubs_code; // code buffer for C2 intrinsics
|
||||
static BufferBlob* _final_stubs_code; // code buffer for all other routines
|
||||
|
||||
static address _array_sort;
|
||||
static address _array_partition;
|
||||
// Leaf routines which implement arraycopy and their addresses
|
||||
// arraycopy operands aligned on element type boundary
|
||||
static address _jbyte_arraycopy;
|
||||
@ -375,6 +377,8 @@ class StubRoutines: AllStatic {
|
||||
static UnsafeArrayCopyStub UnsafeArrayCopy_stub() { return CAST_TO_FN_PTR(UnsafeArrayCopyStub, _unsafe_arraycopy); }
|
||||
|
||||
static address generic_arraycopy() { return _generic_arraycopy; }
|
||||
static address select_arraysort_function() { return _array_sort; }
|
||||
static address select_array_partition_function() { return _array_partition; }
|
||||
|
||||
static address jbyte_fill() { return _jbyte_fill; }
|
||||
static address jshort_fill() { return _jshort_fill; }
|
||||
|
441
src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp
Normal file
441
src/java.base/linux/native/libsimdsort/avx512-32bit-qsort.hpp
Normal file
@ -0,0 +1,441 @@
|
||||
/*
|
||||
* Copyright (c) 2021, 2023, Intel Corporation. All rights reserved.
|
||||
* Copyright (c) 2021 Serge Sans Paille. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*
|
||||
*/
|
||||
|
||||
// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort)
|
||||
|
||||
#ifndef AVX512_QSORT_32BIT
|
||||
#define AVX512_QSORT_32BIT
|
||||
|
||||
#include "avx512-common-qsort.h"
|
||||
|
||||
/*
|
||||
* Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic
|
||||
* sorting network (see
|
||||
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
|
||||
*/
|
||||
#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1
|
||||
#define NETWORK_32BIT_2 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3
|
||||
#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
|
||||
#define NETWORK_32BIT_4 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2
|
||||
#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
|
||||
#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4
|
||||
#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8
|
||||
|
||||
template <>
|
||||
struct zmm_vector<int32_t> {
|
||||
using type_t = int32_t;
|
||||
using zmm_t = __m512i;
|
||||
using ymm_t = __m256i;
|
||||
using opmask_t = __mmask16;
|
||||
static const uint8_t numlanes = 16;
|
||||
|
||||
static type_t type_max() { return X86_SIMD_SORT_MAX_INT32; }
|
||||
static type_t type_min() { return X86_SIMD_SORT_MIN_INT32; }
|
||||
static zmm_t zmm_max() { return _mm512_set1_epi32(type_max()); }
|
||||
|
||||
static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); }
|
||||
static opmask_t ge(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT);
|
||||
}
|
||||
static opmask_t gt(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_GT);
|
||||
}
|
||||
template <int scale>
|
||||
static ymm_t i64gather(__m512i index, void const *base) {
|
||||
return _mm512_i64gather_epi32(index, base, scale);
|
||||
}
|
||||
static zmm_t merge(ymm_t y1, ymm_t y2) {
|
||||
zmm_t z1 = _mm512_castsi256_si512(y1);
|
||||
return _mm512_inserti32x8(z1, y2, 1);
|
||||
}
|
||||
static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); }
|
||||
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_compressstoreu_epi32(mem, mask, x);
|
||||
}
|
||||
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) {
|
||||
return _mm512_mask_loadu_epi32(x, mask, mem);
|
||||
}
|
||||
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) {
|
||||
return _mm512_mask_mov_epi32(x, mask, y);
|
||||
}
|
||||
static void mask_storeu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_storeu_epi32(mem, mask, x);
|
||||
}
|
||||
static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_epi32(x, y); }
|
||||
static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_epi32(x, y); }
|
||||
static zmm_t permutexvar(__m512i idx, zmm_t zmm) {
|
||||
return _mm512_permutexvar_epi32(idx, zmm);
|
||||
}
|
||||
static type_t reducemax(zmm_t v) { return _mm512_reduce_max_epi32(v); }
|
||||
static type_t reducemin(zmm_t v) { return _mm512_reduce_min_epi32(v); }
|
||||
static zmm_t set1(type_t v) { return _mm512_set1_epi32(v); }
|
||||
template <uint8_t mask>
|
||||
static zmm_t shuffle(zmm_t zmm) {
|
||||
return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask);
|
||||
}
|
||||
static void storeu(void *mem, zmm_t x) {
|
||||
return _mm512_storeu_si512(mem, x);
|
||||
}
|
||||
|
||||
static ymm_t max(ymm_t x, ymm_t y) { return _mm256_max_epi32(x, y); }
|
||||
static ymm_t min(ymm_t x, ymm_t y) { return _mm256_min_epi32(x, y); }
|
||||
};
|
||||
template <>
|
||||
struct zmm_vector<float> {
|
||||
using type_t = float;
|
||||
using zmm_t = __m512;
|
||||
using ymm_t = __m256;
|
||||
using opmask_t = __mmask16;
|
||||
static const uint8_t numlanes = 16;
|
||||
|
||||
static type_t type_max() { return X86_SIMD_SORT_INFINITYF; }
|
||||
static type_t type_min() { return -X86_SIMD_SORT_INFINITYF; }
|
||||
static zmm_t zmm_max() { return _mm512_set1_ps(type_max()); }
|
||||
|
||||
static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); }
|
||||
static opmask_t ge(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
|
||||
}
|
||||
static opmask_t gt(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
|
||||
}
|
||||
template <int scale>
|
||||
static ymm_t i64gather(__m512i index, void const *base) {
|
||||
return _mm512_i64gather_ps(index, base, scale);
|
||||
}
|
||||
static zmm_t merge(ymm_t y1, ymm_t y2) {
|
||||
zmm_t z1 = _mm512_castsi512_ps(
|
||||
_mm512_castsi256_si512(_mm256_castps_si256(y1)));
|
||||
return _mm512_insertf32x8(z1, y2, 1);
|
||||
}
|
||||
static zmm_t loadu(void const *mem) { return _mm512_loadu_ps(mem); }
|
||||
static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_ps(x, y); }
|
||||
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_compressstoreu_ps(mem, mask, x);
|
||||
}
|
||||
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) {
|
||||
return _mm512_mask_loadu_ps(x, mask, mem);
|
||||
}
|
||||
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) {
|
||||
return _mm512_mask_mov_ps(x, mask, y);
|
||||
}
|
||||
static void mask_storeu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_storeu_ps(mem, mask, x);
|
||||
}
|
||||
static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_ps(x, y); }
|
||||
static zmm_t permutexvar(__m512i idx, zmm_t zmm) {
|
||||
return _mm512_permutexvar_ps(idx, zmm);
|
||||
}
|
||||
static type_t reducemax(zmm_t v) { return _mm512_reduce_max_ps(v); }
|
||||
static type_t reducemin(zmm_t v) { return _mm512_reduce_min_ps(v); }
|
||||
static zmm_t set1(type_t v) { return _mm512_set1_ps(v); }
|
||||
template <uint8_t mask>
|
||||
static zmm_t shuffle(zmm_t zmm) {
|
||||
return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask);
|
||||
}
|
||||
static void storeu(void *mem, zmm_t x) { return _mm512_storeu_ps(mem, x); }
|
||||
|
||||
static ymm_t max(ymm_t x, ymm_t y) { return _mm256_max_ps(x, y); }
|
||||
static ymm_t min(ymm_t x, ymm_t y) { return _mm256_min_ps(x, y); }
|
||||
};
|
||||
|
||||
/*
|
||||
* Assumes zmm is random and performs a full sorting network defined in
|
||||
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
|
||||
*/
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE zmm_t sort_zmm_32bit(zmm_t zmm) {
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm), 0xAAAA);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(zmm), 0xCCCC);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm), 0xAAAA);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_3), zmm),
|
||||
0xF0F0);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm), 0xCCCC);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm), 0xAAAA);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm),
|
||||
0xFF00);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm),
|
||||
0xF0F0);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm), 0xCCCC);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm), 0xAAAA);
|
||||
return zmm;
|
||||
}
|
||||
|
||||
// Assumes zmm is bitonic and performs a recursive half cleaner
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_32bit(zmm_t zmm) {
|
||||
// 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_7), zmm),
|
||||
0xFF00);
|
||||
// 2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc ..
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm),
|
||||
0xF0F0);
|
||||
// 3) half_cleaner[4]
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm), 0xCCCC);
|
||||
// 3) half_cleaner[1]
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm), 0xAAAA);
|
||||
return zmm;
|
||||
}
|
||||
|
||||
// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_32bit(zmm_t *zmm1,
|
||||
zmm_t *zmm2) {
|
||||
// 1) First step of a merging network: coex of zmm1 and zmm2 reversed
|
||||
*zmm2 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), *zmm2);
|
||||
zmm_t zmm3 = vtype::min(*zmm1, *zmm2);
|
||||
zmm_t zmm4 = vtype::max(*zmm1, *zmm2);
|
||||
// 2) Recursive half cleaner for each
|
||||
*zmm1 = bitonic_merge_zmm_32bit<vtype>(zmm3);
|
||||
*zmm2 = bitonic_merge_zmm_32bit<vtype>(zmm4);
|
||||
}
|
||||
|
||||
// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
|
||||
// half cleaner
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_32bit(zmm_t *zmm) {
|
||||
zmm_t zmm2r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[2]);
|
||||
zmm_t zmm3r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[3]);
|
||||
zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
|
||||
zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
|
||||
zmm_t zmm_t3 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5),
|
||||
vtype::max(zmm[1], zmm2r));
|
||||
zmm_t zmm_t4 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5),
|
||||
vtype::max(zmm[0], zmm3r));
|
||||
zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
|
||||
zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
|
||||
zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
|
||||
zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
|
||||
zmm[0] = bitonic_merge_zmm_32bit<vtype>(zmm0);
|
||||
zmm[1] = bitonic_merge_zmm_32bit<vtype>(zmm1);
|
||||
zmm[2] = bitonic_merge_zmm_32bit<vtype>(zmm2);
|
||||
zmm[3] = bitonic_merge_zmm_32bit<vtype>(zmm3);
|
||||
}
|
||||
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_32bit(zmm_t *zmm) {
|
||||
zmm_t zmm4r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[4]);
|
||||
zmm_t zmm5r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[5]);
|
||||
zmm_t zmm6r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[6]);
|
||||
zmm_t zmm7r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[7]);
|
||||
zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r);
|
||||
zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r);
|
||||
zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r);
|
||||
zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r);
|
||||
zmm_t zmm_t5 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5),
|
||||
vtype::max(zmm[3], zmm4r));
|
||||
zmm_t zmm_t6 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5),
|
||||
vtype::max(zmm[2], zmm5r));
|
||||
zmm_t zmm_t7 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5),
|
||||
vtype::max(zmm[1], zmm6r));
|
||||
zmm_t zmm_t8 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5),
|
||||
vtype::max(zmm[0], zmm7r));
|
||||
COEX<vtype>(zmm_t1, zmm_t3);
|
||||
COEX<vtype>(zmm_t2, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t7);
|
||||
COEX<vtype>(zmm_t6, zmm_t8);
|
||||
COEX<vtype>(zmm_t1, zmm_t2);
|
||||
COEX<vtype>(zmm_t3, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t6);
|
||||
COEX<vtype>(zmm_t7, zmm_t8);
|
||||
zmm[0] = bitonic_merge_zmm_32bit<vtype>(zmm_t1);
|
||||
zmm[1] = bitonic_merge_zmm_32bit<vtype>(zmm_t2);
|
||||
zmm[2] = bitonic_merge_zmm_32bit<vtype>(zmm_t3);
|
||||
zmm[3] = bitonic_merge_zmm_32bit<vtype>(zmm_t4);
|
||||
zmm[4] = bitonic_merge_zmm_32bit<vtype>(zmm_t5);
|
||||
zmm[5] = bitonic_merge_zmm_32bit<vtype>(zmm_t6);
|
||||
zmm[6] = bitonic_merge_zmm_32bit<vtype>(zmm_t7);
|
||||
zmm[7] = bitonic_merge_zmm_32bit<vtype>(zmm_t8);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_16_32bit(type_t *arr, int32_t N) {
|
||||
typename vtype::opmask_t load_mask = (0x0001 << N) - 0x0001;
|
||||
typename vtype::zmm_t zmm =
|
||||
vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
|
||||
vtype::mask_storeu(arr, load_mask, sort_zmm_32bit<vtype>(zmm));
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_32_32bit(type_t *arr, int32_t N) {
|
||||
if (N <= 16) {
|
||||
sort_16_32bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
zmm_t zmm1 = vtype::loadu(arr);
|
||||
typename vtype::opmask_t load_mask = (0x0001 << (N - 16)) - 0x0001;
|
||||
zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 16);
|
||||
zmm1 = sort_zmm_32bit<vtype>(zmm1);
|
||||
zmm2 = sort_zmm_32bit<vtype>(zmm2);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm1, &zmm2);
|
||||
vtype::storeu(arr, zmm1);
|
||||
vtype::mask_storeu(arr + 16, load_mask, zmm2);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_64_32bit(type_t *arr, int32_t N) {
|
||||
if (N <= 32) {
|
||||
sort_32_32bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
using opmask_t = typename vtype::opmask_t;
|
||||
zmm_t zmm[4];
|
||||
zmm[0] = vtype::loadu(arr);
|
||||
zmm[1] = vtype::loadu(arr + 16);
|
||||
opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF;
|
||||
uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull;
|
||||
load_mask1 &= combined_mask & 0xFFFF;
|
||||
load_mask2 &= (combined_mask >> 16) & 0xFFFF;
|
||||
zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32);
|
||||
zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 48);
|
||||
zmm[0] = sort_zmm_32bit<vtype>(zmm[0]);
|
||||
zmm[1] = sort_zmm_32bit<vtype>(zmm[1]);
|
||||
zmm[2] = sort_zmm_32bit<vtype>(zmm[2]);
|
||||
zmm[3] = sort_zmm_32bit<vtype>(zmm[3]);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm[0], &zmm[1]);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm[2], &zmm[3]);
|
||||
bitonic_merge_four_zmm_32bit<vtype>(zmm);
|
||||
vtype::storeu(arr, zmm[0]);
|
||||
vtype::storeu(arr + 16, zmm[1]);
|
||||
vtype::mask_storeu(arr + 32, load_mask1, zmm[2]);
|
||||
vtype::mask_storeu(arr + 48, load_mask2, zmm[3]);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_128_32bit(type_t *arr, int32_t N) {
|
||||
if (N <= 64) {
|
||||
sort_64_32bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
using opmask_t = typename vtype::opmask_t;
|
||||
zmm_t zmm[8];
|
||||
zmm[0] = vtype::loadu(arr);
|
||||
zmm[1] = vtype::loadu(arr + 16);
|
||||
zmm[2] = vtype::loadu(arr + 32);
|
||||
zmm[3] = vtype::loadu(arr + 48);
|
||||
zmm[0] = sort_zmm_32bit<vtype>(zmm[0]);
|
||||
zmm[1] = sort_zmm_32bit<vtype>(zmm[1]);
|
||||
zmm[2] = sort_zmm_32bit<vtype>(zmm[2]);
|
||||
zmm[3] = sort_zmm_32bit<vtype>(zmm[3]);
|
||||
opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF;
|
||||
opmask_t load_mask3 = 0xFFFF, load_mask4 = 0xFFFF;
|
||||
if (N != 128) {
|
||||
uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
|
||||
load_mask1 &= combined_mask & 0xFFFF;
|
||||
load_mask2 &= (combined_mask >> 16) & 0xFFFF;
|
||||
load_mask3 &= (combined_mask >> 32) & 0xFFFF;
|
||||
load_mask4 &= (combined_mask >> 48) & 0xFFFF;
|
||||
}
|
||||
zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
|
||||
zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 80);
|
||||
zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 96);
|
||||
zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 112);
|
||||
zmm[4] = sort_zmm_32bit<vtype>(zmm[4]);
|
||||
zmm[5] = sort_zmm_32bit<vtype>(zmm[5]);
|
||||
zmm[6] = sort_zmm_32bit<vtype>(zmm[6]);
|
||||
zmm[7] = sort_zmm_32bit<vtype>(zmm[7]);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm[0], &zmm[1]);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm[2], &zmm[3]);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm[4], &zmm[5]);
|
||||
bitonic_merge_two_zmm_32bit<vtype>(&zmm[6], &zmm[7]);
|
||||
bitonic_merge_four_zmm_32bit<vtype>(zmm);
|
||||
bitonic_merge_four_zmm_32bit<vtype>(zmm + 4);
|
||||
bitonic_merge_eight_zmm_32bit<vtype>(zmm);
|
||||
vtype::storeu(arr, zmm[0]);
|
||||
vtype::storeu(arr + 16, zmm[1]);
|
||||
vtype::storeu(arr + 32, zmm[2]);
|
||||
vtype::storeu(arr + 48, zmm[3]);
|
||||
vtype::mask_storeu(arr + 64, load_mask1, zmm[4]);
|
||||
vtype::mask_storeu(arr + 80, load_mask2, zmm[5]);
|
||||
vtype::mask_storeu(arr + 96, load_mask3, zmm[6]);
|
||||
vtype::mask_storeu(arr + 112, load_mask4, zmm[7]);
|
||||
}
|
||||
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
static void qsort_32bit_(type_t *arr, int64_t left, int64_t right,
|
||||
int64_t max_iters) {
|
||||
/*
|
||||
* Resort to std::sort if quicksort isnt making any progress
|
||||
*/
|
||||
if (max_iters <= 0) {
|
||||
std::sort(arr + left, arr + right + 1);
|
||||
return;
|
||||
}
|
||||
/*
|
||||
* Base case: use bitonic networks to sort arrays <= 128
|
||||
*/
|
||||
if (right + 1 - left <= 128) {
|
||||
sort_128_32bit<vtype>(arr + left, (int32_t)(right + 1 - left));
|
||||
return;
|
||||
}
|
||||
|
||||
type_t pivot = get_pivot_scalar<type_t>(arr, left, right);
|
||||
type_t smallest = vtype::type_max();
|
||||
type_t biggest = vtype::type_min();
|
||||
int64_t pivot_index = partition_avx512_unrolled<vtype, 2>(
|
||||
arr, left, right + 1, pivot, &smallest, &biggest, false);
|
||||
if (pivot != smallest)
|
||||
qsort_32bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
|
||||
if (pivot != biggest)
|
||||
qsort_32bit_<vtype>(arr, pivot_index, right, max_iters - 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
void inline avx512_qsort<int32_t>(int32_t *arr, int64_t fromIndex, int64_t toIndex) {
|
||||
int64_t arrsize = toIndex - fromIndex;
|
||||
if (arrsize > 1) {
|
||||
qsort_32bit_<zmm_vector<int32_t>, int32_t>(arr, fromIndex, toIndex - 1,
|
||||
2 * (int64_t)log2(arrsize));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void inline avx512_qsort<float>(float *arr, int64_t fromIndex, int64_t toIndex) {
|
||||
int64_t arrsize = toIndex - fromIndex;
|
||||
if (arrsize > 1) {
|
||||
qsort_32bit_<zmm_vector<float>, float>(arr, fromIndex, toIndex - 1,
|
||||
2 * (int64_t)log2(arrsize));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // AVX512_QSORT_32BIT
|
212
src/java.base/linux/native/libsimdsort/avx512-64bit-common.h
Normal file
212
src/java.base/linux/native/libsimdsort/avx512-64bit-common.h
Normal file
@ -0,0 +1,212 @@
|
||||
/*
|
||||
* Copyright (c) 2021, 2023, Intel Corporation. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*
|
||||
*/
|
||||
|
||||
// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort)
|
||||
|
||||
#ifndef AVX512_64BIT_COMMON
|
||||
#define AVX512_64BIT_COMMON
|
||||
#include "avx512-common-qsort.h"
|
||||
|
||||
/*
|
||||
* Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic
|
||||
* sorting network (see
|
||||
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
|
||||
*/
|
||||
// ZMM 7, 6, 5, 4, 3, 2, 1, 0
|
||||
#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3
|
||||
#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7
|
||||
#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2
|
||||
#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4
|
||||
|
||||
template <>
|
||||
struct zmm_vector<int64_t> {
|
||||
using type_t = int64_t;
|
||||
using zmm_t = __m512i;
|
||||
using zmmi_t = __m512i;
|
||||
using ymm_t = __m512i;
|
||||
using opmask_t = __mmask8;
|
||||
static const uint8_t numlanes = 8;
|
||||
|
||||
static type_t type_max() { return X86_SIMD_SORT_MAX_INT64; }
|
||||
static type_t type_min() { return X86_SIMD_SORT_MIN_INT64; }
|
||||
static zmm_t zmm_max() {
|
||||
return _mm512_set1_epi64(type_max());
|
||||
} // TODO: this should broadcast bits as is?
|
||||
|
||||
static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7,
|
||||
int v8) {
|
||||
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
|
||||
}
|
||||
static opmask_t kxor_opmask(opmask_t x, opmask_t y) {
|
||||
return _kxor_mask8(x, y);
|
||||
}
|
||||
static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); }
|
||||
static opmask_t le(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_LE);
|
||||
}
|
||||
static opmask_t ge(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT);
|
||||
}
|
||||
static opmask_t gt(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_GT);
|
||||
}
|
||||
static opmask_t eq(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ);
|
||||
}
|
||||
template <int scale>
|
||||
static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index,
|
||||
void const *base) {
|
||||
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
|
||||
}
|
||||
template <int scale>
|
||||
static zmm_t i64gather(__m512i index, void const *base) {
|
||||
return _mm512_i64gather_epi64(index, base, scale);
|
||||
}
|
||||
static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); }
|
||||
static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_epi64(x, y); }
|
||||
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_compressstoreu_epi64(mem, mask, x);
|
||||
}
|
||||
static zmm_t maskz_loadu(opmask_t mask, void const *mem) {
|
||||
return _mm512_maskz_loadu_epi64(mask, mem);
|
||||
}
|
||||
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) {
|
||||
return _mm512_mask_loadu_epi64(x, mask, mem);
|
||||
}
|
||||
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) {
|
||||
return _mm512_mask_mov_epi64(x, mask, y);
|
||||
}
|
||||
static void mask_storeu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_storeu_epi64(mem, mask, x);
|
||||
}
|
||||
static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_epi64(x, y); }
|
||||
static zmm_t permutexvar(__m512i idx, zmm_t zmm) {
|
||||
return _mm512_permutexvar_epi64(idx, zmm);
|
||||
}
|
||||
static type_t reducemax(zmm_t v) { return _mm512_reduce_max_epi64(v); }
|
||||
static type_t reducemin(zmm_t v) { return _mm512_reduce_min_epi64(v); }
|
||||
static zmm_t set1(type_t v) { return _mm512_set1_epi64(v); }
|
||||
template <uint8_t mask>
|
||||
static zmm_t shuffle(zmm_t zmm) {
|
||||
__m512d temp = _mm512_castsi512_pd(zmm);
|
||||
return _mm512_castpd_si512(
|
||||
_mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask));
|
||||
}
|
||||
static void storeu(void *mem, zmm_t x) { _mm512_storeu_si512(mem, x); }
|
||||
};
|
||||
template <>
|
||||
struct zmm_vector<double> {
|
||||
using type_t = double;
|
||||
using zmm_t = __m512d;
|
||||
using zmmi_t = __m512i;
|
||||
using ymm_t = __m512d;
|
||||
using opmask_t = __mmask8;
|
||||
static const uint8_t numlanes = 8;
|
||||
|
||||
static type_t type_max() { return X86_SIMD_SORT_INFINITY; }
|
||||
static type_t type_min() { return -X86_SIMD_SORT_INFINITY; }
|
||||
static zmm_t zmm_max() { return _mm512_set1_pd(type_max()); }
|
||||
|
||||
static zmmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7,
|
||||
int v8) {
|
||||
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
|
||||
}
|
||||
|
||||
static zmm_t maskz_loadu(opmask_t mask, void const *mem) {
|
||||
return _mm512_maskz_loadu_pd(mask, mem);
|
||||
}
|
||||
static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); }
|
||||
static opmask_t ge(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ);
|
||||
}
|
||||
static opmask_t gt(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_pd_mask(x, y, _CMP_GT_OQ);
|
||||
}
|
||||
static opmask_t eq(zmm_t x, zmm_t y) {
|
||||
return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ);
|
||||
}
|
||||
template <int type>
|
||||
static opmask_t fpclass(zmm_t x) {
|
||||
return _mm512_fpclass_pd_mask(x, type);
|
||||
}
|
||||
template <int scale>
|
||||
static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index,
|
||||
void const *base) {
|
||||
return _mm512_mask_i64gather_pd(src, mask, index, base, scale);
|
||||
}
|
||||
template <int scale>
|
||||
static zmm_t i64gather(__m512i index, void const *base) {
|
||||
return _mm512_i64gather_pd(index, base, scale);
|
||||
}
|
||||
static zmm_t loadu(void const *mem) { return _mm512_loadu_pd(mem); }
|
||||
static zmm_t max(zmm_t x, zmm_t y) { return _mm512_max_pd(x, y); }
|
||||
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_compressstoreu_pd(mem, mask, x);
|
||||
}
|
||||
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) {
|
||||
return _mm512_mask_loadu_pd(x, mask, mem);
|
||||
}
|
||||
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) {
|
||||
return _mm512_mask_mov_pd(x, mask, y);
|
||||
}
|
||||
static void mask_storeu(void *mem, opmask_t mask, zmm_t x) {
|
||||
return _mm512_mask_storeu_pd(mem, mask, x);
|
||||
}
|
||||
static zmm_t min(zmm_t x, zmm_t y) { return _mm512_min_pd(x, y); }
|
||||
static zmm_t permutexvar(__m512i idx, zmm_t zmm) {
|
||||
return _mm512_permutexvar_pd(idx, zmm);
|
||||
}
|
||||
static type_t reducemax(zmm_t v) { return _mm512_reduce_max_pd(v); }
|
||||
static type_t reducemin(zmm_t v) { return _mm512_reduce_min_pd(v); }
|
||||
static zmm_t set1(type_t v) { return _mm512_set1_pd(v); }
|
||||
template <uint8_t mask>
|
||||
static zmm_t shuffle(zmm_t zmm) {
|
||||
return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask);
|
||||
}
|
||||
static void storeu(void *mem, zmm_t x) { _mm512_storeu_pd(mem, x); }
|
||||
};
|
||||
|
||||
/*
|
||||
* Assumes zmm is random and performs a full sorting network defined in
|
||||
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
|
||||
*/
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) {
|
||||
const typename vtype::zmmi_t rev_index = vtype::seti(NETWORK_64BIT_2);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
|
||||
zmm = cmp_merge<vtype>(zmm, vtype::permutexvar(rev_index, zmm), 0xF0);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC);
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
|
||||
return zmm;
|
||||
}
|
||||
|
||||
|
||||
#endif
|
772
src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp
Normal file
772
src/java.base/linux/native/libsimdsort/avx512-64bit-qsort.hpp
Normal file
@ -0,0 +1,772 @@
|
||||
/*
|
||||
* Copyright (c) 2021, 2023, Intel Corporation. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*
|
||||
*/
|
||||
|
||||
// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort)
|
||||
|
||||
#ifndef AVX512_QSORT_64BIT
|
||||
#define AVX512_QSORT_64BIT
|
||||
|
||||
#include "avx512-64bit-common.h"
|
||||
|
||||
// Assumes zmm is bitonic and performs a recursive half cleaner
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) {
|
||||
// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), zmm), 0xF0);
|
||||
// 2) half_cleaner[4]
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), 0xCC);
|
||||
// 3) half_cleaner[1]
|
||||
zmm = cmp_merge<vtype>(
|
||||
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
|
||||
return zmm;
|
||||
}
|
||||
// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1,
|
||||
zmm_t &zmm2) {
|
||||
const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
|
||||
// 1) First step of a merging network: coex of zmm1 and zmm2 reversed
|
||||
zmm2 = vtype::permutexvar(rev_index, zmm2);
|
||||
zmm_t zmm3 = vtype::min(zmm1, zmm2);
|
||||
zmm_t zmm4 = vtype::max(zmm1, zmm2);
|
||||
// 2) Recursive half cleaner for each
|
||||
zmm1 = bitonic_merge_zmm_64bit<vtype>(zmm3);
|
||||
zmm2 = bitonic_merge_zmm_64bit<vtype>(zmm4);
|
||||
}
|
||||
// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
|
||||
// half cleaner
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) {
|
||||
const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
|
||||
// 1) First step of a merging network
|
||||
zmm_t zmm2r = vtype::permutexvar(rev_index, zmm[2]);
|
||||
zmm_t zmm3r = vtype::permutexvar(rev_index, zmm[3]);
|
||||
zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
|
||||
zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
|
||||
// 2) Recursive half clearer: 16
|
||||
zmm_t zmm_t3 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm2r));
|
||||
zmm_t zmm_t4 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm3r));
|
||||
zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
|
||||
zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
|
||||
zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
|
||||
zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
|
||||
zmm[0] = bitonic_merge_zmm_64bit<vtype>(zmm0);
|
||||
zmm[1] = bitonic_merge_zmm_64bit<vtype>(zmm1);
|
||||
zmm[2] = bitonic_merge_zmm_64bit<vtype>(zmm2);
|
||||
zmm[3] = bitonic_merge_zmm_64bit<vtype>(zmm3);
|
||||
}
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) {
|
||||
const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
|
||||
zmm_t zmm4r = vtype::permutexvar(rev_index, zmm[4]);
|
||||
zmm_t zmm5r = vtype::permutexvar(rev_index, zmm[5]);
|
||||
zmm_t zmm6r = vtype::permutexvar(rev_index, zmm[6]);
|
||||
zmm_t zmm7r = vtype::permutexvar(rev_index, zmm[7]);
|
||||
zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r);
|
||||
zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r);
|
||||
zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r);
|
||||
zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r);
|
||||
zmm_t zmm_t5 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm4r));
|
||||
zmm_t zmm_t6 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm5r));
|
||||
zmm_t zmm_t7 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm6r));
|
||||
zmm_t zmm_t8 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm7r));
|
||||
COEX<vtype>(zmm_t1, zmm_t3);
|
||||
COEX<vtype>(zmm_t2, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t7);
|
||||
COEX<vtype>(zmm_t6, zmm_t8);
|
||||
COEX<vtype>(zmm_t1, zmm_t2);
|
||||
COEX<vtype>(zmm_t3, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t6);
|
||||
COEX<vtype>(zmm_t7, zmm_t8);
|
||||
zmm[0] = bitonic_merge_zmm_64bit<vtype>(zmm_t1);
|
||||
zmm[1] = bitonic_merge_zmm_64bit<vtype>(zmm_t2);
|
||||
zmm[2] = bitonic_merge_zmm_64bit<vtype>(zmm_t3);
|
||||
zmm[3] = bitonic_merge_zmm_64bit<vtype>(zmm_t4);
|
||||
zmm[4] = bitonic_merge_zmm_64bit<vtype>(zmm_t5);
|
||||
zmm[5] = bitonic_merge_zmm_64bit<vtype>(zmm_t6);
|
||||
zmm[6] = bitonic_merge_zmm_64bit<vtype>(zmm_t7);
|
||||
zmm[7] = bitonic_merge_zmm_64bit<vtype>(zmm_t8);
|
||||
}
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) {
|
||||
const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
|
||||
zmm_t zmm8r = vtype::permutexvar(rev_index, zmm[8]);
|
||||
zmm_t zmm9r = vtype::permutexvar(rev_index, zmm[9]);
|
||||
zmm_t zmm10r = vtype::permutexvar(rev_index, zmm[10]);
|
||||
zmm_t zmm11r = vtype::permutexvar(rev_index, zmm[11]);
|
||||
zmm_t zmm12r = vtype::permutexvar(rev_index, zmm[12]);
|
||||
zmm_t zmm13r = vtype::permutexvar(rev_index, zmm[13]);
|
||||
zmm_t zmm14r = vtype::permutexvar(rev_index, zmm[14]);
|
||||
zmm_t zmm15r = vtype::permutexvar(rev_index, zmm[15]);
|
||||
zmm_t zmm_t1 = vtype::min(zmm[0], zmm15r);
|
||||
zmm_t zmm_t2 = vtype::min(zmm[1], zmm14r);
|
||||
zmm_t zmm_t3 = vtype::min(zmm[2], zmm13r);
|
||||
zmm_t zmm_t4 = vtype::min(zmm[3], zmm12r);
|
||||
zmm_t zmm_t5 = vtype::min(zmm[4], zmm11r);
|
||||
zmm_t zmm_t6 = vtype::min(zmm[5], zmm10r);
|
||||
zmm_t zmm_t7 = vtype::min(zmm[6], zmm9r);
|
||||
zmm_t zmm_t8 = vtype::min(zmm[7], zmm8r);
|
||||
zmm_t zmm_t9 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm8r));
|
||||
zmm_t zmm_t10 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm9r));
|
||||
zmm_t zmm_t11 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm10r));
|
||||
zmm_t zmm_t12 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm11r));
|
||||
zmm_t zmm_t13 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm12r));
|
||||
zmm_t zmm_t14 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm13r));
|
||||
zmm_t zmm_t15 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm14r));
|
||||
zmm_t zmm_t16 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm15r));
|
||||
// Recusive half clear 16 zmm regs
|
||||
COEX<vtype>(zmm_t1, zmm_t5);
|
||||
COEX<vtype>(zmm_t2, zmm_t6);
|
||||
COEX<vtype>(zmm_t3, zmm_t7);
|
||||
COEX<vtype>(zmm_t4, zmm_t8);
|
||||
COEX<vtype>(zmm_t9, zmm_t13);
|
||||
COEX<vtype>(zmm_t10, zmm_t14);
|
||||
COEX<vtype>(zmm_t11, zmm_t15);
|
||||
COEX<vtype>(zmm_t12, zmm_t16);
|
||||
//
|
||||
COEX<vtype>(zmm_t1, zmm_t3);
|
||||
COEX<vtype>(zmm_t2, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t7);
|
||||
COEX<vtype>(zmm_t6, zmm_t8);
|
||||
COEX<vtype>(zmm_t9, zmm_t11);
|
||||
COEX<vtype>(zmm_t10, zmm_t12);
|
||||
COEX<vtype>(zmm_t13, zmm_t15);
|
||||
COEX<vtype>(zmm_t14, zmm_t16);
|
||||
//
|
||||
COEX<vtype>(zmm_t1, zmm_t2);
|
||||
COEX<vtype>(zmm_t3, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t6);
|
||||
COEX<vtype>(zmm_t7, zmm_t8);
|
||||
COEX<vtype>(zmm_t9, zmm_t10);
|
||||
COEX<vtype>(zmm_t11, zmm_t12);
|
||||
COEX<vtype>(zmm_t13, zmm_t14);
|
||||
COEX<vtype>(zmm_t15, zmm_t16);
|
||||
//
|
||||
zmm[0] = bitonic_merge_zmm_64bit<vtype>(zmm_t1);
|
||||
zmm[1] = bitonic_merge_zmm_64bit<vtype>(zmm_t2);
|
||||
zmm[2] = bitonic_merge_zmm_64bit<vtype>(zmm_t3);
|
||||
zmm[3] = bitonic_merge_zmm_64bit<vtype>(zmm_t4);
|
||||
zmm[4] = bitonic_merge_zmm_64bit<vtype>(zmm_t5);
|
||||
zmm[5] = bitonic_merge_zmm_64bit<vtype>(zmm_t6);
|
||||
zmm[6] = bitonic_merge_zmm_64bit<vtype>(zmm_t7);
|
||||
zmm[7] = bitonic_merge_zmm_64bit<vtype>(zmm_t8);
|
||||
zmm[8] = bitonic_merge_zmm_64bit<vtype>(zmm_t9);
|
||||
zmm[9] = bitonic_merge_zmm_64bit<vtype>(zmm_t10);
|
||||
zmm[10] = bitonic_merge_zmm_64bit<vtype>(zmm_t11);
|
||||
zmm[11] = bitonic_merge_zmm_64bit<vtype>(zmm_t12);
|
||||
zmm[12] = bitonic_merge_zmm_64bit<vtype>(zmm_t13);
|
||||
zmm[13] = bitonic_merge_zmm_64bit<vtype>(zmm_t14);
|
||||
zmm[14] = bitonic_merge_zmm_64bit<vtype>(zmm_t15);
|
||||
zmm[15] = bitonic_merge_zmm_64bit<vtype>(zmm_t16);
|
||||
}
|
||||
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
|
||||
X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm) {
|
||||
const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2);
|
||||
zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]);
|
||||
zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]);
|
||||
zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]);
|
||||
zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]);
|
||||
zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]);
|
||||
zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]);
|
||||
zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]);
|
||||
zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]);
|
||||
zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]);
|
||||
zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]);
|
||||
zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]);
|
||||
zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]);
|
||||
zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]);
|
||||
zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]);
|
||||
zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]);
|
||||
zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]);
|
||||
zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r);
|
||||
zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r);
|
||||
zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r);
|
||||
zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r);
|
||||
zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r);
|
||||
zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r);
|
||||
zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r);
|
||||
zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r);
|
||||
zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r);
|
||||
zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r);
|
||||
zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r);
|
||||
zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r);
|
||||
zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r);
|
||||
zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r);
|
||||
zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r);
|
||||
zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r);
|
||||
zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r));
|
||||
zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r));
|
||||
zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r));
|
||||
zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r));
|
||||
zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r));
|
||||
zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r));
|
||||
zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r));
|
||||
zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r));
|
||||
zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r));
|
||||
zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r));
|
||||
zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r));
|
||||
zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r));
|
||||
zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r));
|
||||
zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r));
|
||||
zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r));
|
||||
zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r));
|
||||
// Recusive half clear 16 zmm regs
|
||||
COEX<vtype>(zmm_t1, zmm_t9);
|
||||
COEX<vtype>(zmm_t2, zmm_t10);
|
||||
COEX<vtype>(zmm_t3, zmm_t11);
|
||||
COEX<vtype>(zmm_t4, zmm_t12);
|
||||
COEX<vtype>(zmm_t5, zmm_t13);
|
||||
COEX<vtype>(zmm_t6, zmm_t14);
|
||||
COEX<vtype>(zmm_t7, zmm_t15);
|
||||
COEX<vtype>(zmm_t8, zmm_t16);
|
||||
COEX<vtype>(zmm_t17, zmm_t25);
|
||||
COEX<vtype>(zmm_t18, zmm_t26);
|
||||
COEX<vtype>(zmm_t19, zmm_t27);
|
||||
COEX<vtype>(zmm_t20, zmm_t28);
|
||||
COEX<vtype>(zmm_t21, zmm_t29);
|
||||
COEX<vtype>(zmm_t22, zmm_t30);
|
||||
COEX<vtype>(zmm_t23, zmm_t31);
|
||||
COEX<vtype>(zmm_t24, zmm_t32);
|
||||
//
|
||||
COEX<vtype>(zmm_t1, zmm_t5);
|
||||
COEX<vtype>(zmm_t2, zmm_t6);
|
||||
COEX<vtype>(zmm_t3, zmm_t7);
|
||||
COEX<vtype>(zmm_t4, zmm_t8);
|
||||
COEX<vtype>(zmm_t9, zmm_t13);
|
||||
COEX<vtype>(zmm_t10, zmm_t14);
|
||||
COEX<vtype>(zmm_t11, zmm_t15);
|
||||
COEX<vtype>(zmm_t12, zmm_t16);
|
||||
COEX<vtype>(zmm_t17, zmm_t21);
|
||||
COEX<vtype>(zmm_t18, zmm_t22);
|
||||
COEX<vtype>(zmm_t19, zmm_t23);
|
||||
COEX<vtype>(zmm_t20, zmm_t24);
|
||||
COEX<vtype>(zmm_t25, zmm_t29);
|
||||
COEX<vtype>(zmm_t26, zmm_t30);
|
||||
COEX<vtype>(zmm_t27, zmm_t31);
|
||||
COEX<vtype>(zmm_t28, zmm_t32);
|
||||
//
|
||||
COEX<vtype>(zmm_t1, zmm_t3);
|
||||
COEX<vtype>(zmm_t2, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t7);
|
||||
COEX<vtype>(zmm_t6, zmm_t8);
|
||||
COEX<vtype>(zmm_t9, zmm_t11);
|
||||
COEX<vtype>(zmm_t10, zmm_t12);
|
||||
COEX<vtype>(zmm_t13, zmm_t15);
|
||||
COEX<vtype>(zmm_t14, zmm_t16);
|
||||
COEX<vtype>(zmm_t17, zmm_t19);
|
||||
COEX<vtype>(zmm_t18, zmm_t20);
|
||||
COEX<vtype>(zmm_t21, zmm_t23);
|
||||
COEX<vtype>(zmm_t22, zmm_t24);
|
||||
COEX<vtype>(zmm_t25, zmm_t27);
|
||||
COEX<vtype>(zmm_t26, zmm_t28);
|
||||
COEX<vtype>(zmm_t29, zmm_t31);
|
||||
COEX<vtype>(zmm_t30, zmm_t32);
|
||||
//
|
||||
COEX<vtype>(zmm_t1, zmm_t2);
|
||||
COEX<vtype>(zmm_t3, zmm_t4);
|
||||
COEX<vtype>(zmm_t5, zmm_t6);
|
||||
COEX<vtype>(zmm_t7, zmm_t8);
|
||||
COEX<vtype>(zmm_t9, zmm_t10);
|
||||
COEX<vtype>(zmm_t11, zmm_t12);
|
||||
COEX<vtype>(zmm_t13, zmm_t14);
|
||||
COEX<vtype>(zmm_t15, zmm_t16);
|
||||
COEX<vtype>(zmm_t17, zmm_t18);
|
||||
COEX<vtype>(zmm_t19, zmm_t20);
|
||||
COEX<vtype>(zmm_t21, zmm_t22);
|
||||
COEX<vtype>(zmm_t23, zmm_t24);
|
||||
COEX<vtype>(zmm_t25, zmm_t26);
|
||||
COEX<vtype>(zmm_t27, zmm_t28);
|
||||
COEX<vtype>(zmm_t29, zmm_t30);
|
||||
COEX<vtype>(zmm_t31, zmm_t32);
|
||||
//
|
||||
zmm[0] = bitonic_merge_zmm_64bit<vtype>(zmm_t1);
|
||||
zmm[1] = bitonic_merge_zmm_64bit<vtype>(zmm_t2);
|
||||
zmm[2] = bitonic_merge_zmm_64bit<vtype>(zmm_t3);
|
||||
zmm[3] = bitonic_merge_zmm_64bit<vtype>(zmm_t4);
|
||||
zmm[4] = bitonic_merge_zmm_64bit<vtype>(zmm_t5);
|
||||
zmm[5] = bitonic_merge_zmm_64bit<vtype>(zmm_t6);
|
||||
zmm[6] = bitonic_merge_zmm_64bit<vtype>(zmm_t7);
|
||||
zmm[7] = bitonic_merge_zmm_64bit<vtype>(zmm_t8);
|
||||
zmm[8] = bitonic_merge_zmm_64bit<vtype>(zmm_t9);
|
||||
zmm[9] = bitonic_merge_zmm_64bit<vtype>(zmm_t10);
|
||||
zmm[10] = bitonic_merge_zmm_64bit<vtype>(zmm_t11);
|
||||
zmm[11] = bitonic_merge_zmm_64bit<vtype>(zmm_t12);
|
||||
zmm[12] = bitonic_merge_zmm_64bit<vtype>(zmm_t13);
|
||||
zmm[13] = bitonic_merge_zmm_64bit<vtype>(zmm_t14);
|
||||
zmm[14] = bitonic_merge_zmm_64bit<vtype>(zmm_t15);
|
||||
zmm[15] = bitonic_merge_zmm_64bit<vtype>(zmm_t16);
|
||||
zmm[16] = bitonic_merge_zmm_64bit<vtype>(zmm_t17);
|
||||
zmm[17] = bitonic_merge_zmm_64bit<vtype>(zmm_t18);
|
||||
zmm[18] = bitonic_merge_zmm_64bit<vtype>(zmm_t19);
|
||||
zmm[19] = bitonic_merge_zmm_64bit<vtype>(zmm_t20);
|
||||
zmm[20] = bitonic_merge_zmm_64bit<vtype>(zmm_t21);
|
||||
zmm[21] = bitonic_merge_zmm_64bit<vtype>(zmm_t22);
|
||||
zmm[22] = bitonic_merge_zmm_64bit<vtype>(zmm_t23);
|
||||
zmm[23] = bitonic_merge_zmm_64bit<vtype>(zmm_t24);
|
||||
zmm[24] = bitonic_merge_zmm_64bit<vtype>(zmm_t25);
|
||||
zmm[25] = bitonic_merge_zmm_64bit<vtype>(zmm_t26);
|
||||
zmm[26] = bitonic_merge_zmm_64bit<vtype>(zmm_t27);
|
||||
zmm[27] = bitonic_merge_zmm_64bit<vtype>(zmm_t28);
|
||||
zmm[28] = bitonic_merge_zmm_64bit<vtype>(zmm_t29);
|
||||
zmm[29] = bitonic_merge_zmm_64bit<vtype>(zmm_t30);
|
||||
zmm[30] = bitonic_merge_zmm_64bit<vtype>(zmm_t31);
|
||||
zmm[31] = bitonic_merge_zmm_64bit<vtype>(zmm_t32);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N) {
|
||||
typename vtype::opmask_t load_mask = (0x01 << N) - 0x01;
|
||||
typename vtype::zmm_t zmm =
|
||||
vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
|
||||
vtype::mask_storeu(arr, load_mask, sort_zmm_64bit<vtype>(zmm));
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_16_64bit(type_t *arr, int32_t N) {
|
||||
if (N <= 8) {
|
||||
sort_8_64bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
zmm_t zmm1 = vtype::loadu(arr);
|
||||
typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
|
||||
zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 8);
|
||||
zmm1 = sort_zmm_64bit<vtype>(zmm1);
|
||||
zmm2 = sort_zmm_64bit<vtype>(zmm2);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm1, zmm2);
|
||||
vtype::storeu(arr, zmm1);
|
||||
vtype::mask_storeu(arr + 8, load_mask, zmm2);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_32_64bit(type_t *arr, int32_t N) {
|
||||
if (N <= 16) {
|
||||
sort_16_64bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
using opmask_t = typename vtype::opmask_t;
|
||||
zmm_t zmm[4];
|
||||
zmm[0] = vtype::loadu(arr);
|
||||
zmm[1] = vtype::loadu(arr + 8);
|
||||
opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF;
|
||||
uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull;
|
||||
load_mask1 = (combined_mask)&0xFF;
|
||||
load_mask2 = (combined_mask >> 8) & 0xFF;
|
||||
zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 16);
|
||||
zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 24);
|
||||
zmm[0] = sort_zmm_64bit<vtype>(zmm[0]);
|
||||
zmm[1] = sort_zmm_64bit<vtype>(zmm[1]);
|
||||
zmm[2] = sort_zmm_64bit<vtype>(zmm[2]);
|
||||
zmm[3] = sort_zmm_64bit<vtype>(zmm[3]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[0], zmm[1]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[2], zmm[3]);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm);
|
||||
vtype::storeu(arr, zmm[0]);
|
||||
vtype::storeu(arr + 8, zmm[1]);
|
||||
vtype::mask_storeu(arr + 16, load_mask1, zmm[2]);
|
||||
vtype::mask_storeu(arr + 24, load_mask2, zmm[3]);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_64_64bit(type_t *arr, int32_t N) {
|
||||
if (N <= 32) {
|
||||
sort_32_64bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
using opmask_t = typename vtype::opmask_t;
|
||||
zmm_t zmm[8];
|
||||
zmm[0] = vtype::loadu(arr);
|
||||
zmm[1] = vtype::loadu(arr + 8);
|
||||
zmm[2] = vtype::loadu(arr + 16);
|
||||
zmm[3] = vtype::loadu(arr + 24);
|
||||
zmm[0] = sort_zmm_64bit<vtype>(zmm[0]);
|
||||
zmm[1] = sort_zmm_64bit<vtype>(zmm[1]);
|
||||
zmm[2] = sort_zmm_64bit<vtype>(zmm[2]);
|
||||
zmm[3] = sort_zmm_64bit<vtype>(zmm[3]);
|
||||
opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF;
|
||||
opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF;
|
||||
// N-32 >= 1
|
||||
uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull;
|
||||
load_mask1 = (combined_mask)&0xFF;
|
||||
load_mask2 = (combined_mask >> 8) & 0xFF;
|
||||
load_mask3 = (combined_mask >> 16) & 0xFF;
|
||||
load_mask4 = (combined_mask >> 24) & 0xFF;
|
||||
zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32);
|
||||
zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 40);
|
||||
zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 48);
|
||||
zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 56);
|
||||
zmm[4] = sort_zmm_64bit<vtype>(zmm[4]);
|
||||
zmm[5] = sort_zmm_64bit<vtype>(zmm[5]);
|
||||
zmm[6] = sort_zmm_64bit<vtype>(zmm[6]);
|
||||
zmm[7] = sort_zmm_64bit<vtype>(zmm[7]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[0], zmm[1]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[2], zmm[3]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[4], zmm[5]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[6], zmm[7]);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 4);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm);
|
||||
vtype::storeu(arr, zmm[0]);
|
||||
vtype::storeu(arr + 8, zmm[1]);
|
||||
vtype::storeu(arr + 16, zmm[2]);
|
||||
vtype::storeu(arr + 24, zmm[3]);
|
||||
vtype::mask_storeu(arr + 32, load_mask1, zmm[4]);
|
||||
vtype::mask_storeu(arr + 40, load_mask2, zmm[5]);
|
||||
vtype::mask_storeu(arr + 48, load_mask3, zmm[6]);
|
||||
vtype::mask_storeu(arr + 56, load_mask4, zmm[7]);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) {
|
||||
if (N <= 64) {
|
||||
sort_64_64bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
using opmask_t = typename vtype::opmask_t;
|
||||
zmm_t zmm[16];
|
||||
zmm[0] = vtype::loadu(arr);
|
||||
zmm[1] = vtype::loadu(arr + 8);
|
||||
zmm[2] = vtype::loadu(arr + 16);
|
||||
zmm[3] = vtype::loadu(arr + 24);
|
||||
zmm[4] = vtype::loadu(arr + 32);
|
||||
zmm[5] = vtype::loadu(arr + 40);
|
||||
zmm[6] = vtype::loadu(arr + 48);
|
||||
zmm[7] = vtype::loadu(arr + 56);
|
||||
zmm[0] = sort_zmm_64bit<vtype>(zmm[0]);
|
||||
zmm[1] = sort_zmm_64bit<vtype>(zmm[1]);
|
||||
zmm[2] = sort_zmm_64bit<vtype>(zmm[2]);
|
||||
zmm[3] = sort_zmm_64bit<vtype>(zmm[3]);
|
||||
zmm[4] = sort_zmm_64bit<vtype>(zmm[4]);
|
||||
zmm[5] = sort_zmm_64bit<vtype>(zmm[5]);
|
||||
zmm[6] = sort_zmm_64bit<vtype>(zmm[6]);
|
||||
zmm[7] = sort_zmm_64bit<vtype>(zmm[7]);
|
||||
opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF;
|
||||
opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF;
|
||||
opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF;
|
||||
opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF;
|
||||
if (N != 128) {
|
||||
uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
|
||||
load_mask1 = (combined_mask)&0xFF;
|
||||
load_mask2 = (combined_mask >> 8) & 0xFF;
|
||||
load_mask3 = (combined_mask >> 16) & 0xFF;
|
||||
load_mask4 = (combined_mask >> 24) & 0xFF;
|
||||
load_mask5 = (combined_mask >> 32) & 0xFF;
|
||||
load_mask6 = (combined_mask >> 40) & 0xFF;
|
||||
load_mask7 = (combined_mask >> 48) & 0xFF;
|
||||
load_mask8 = (combined_mask >> 56) & 0xFF;
|
||||
}
|
||||
zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
|
||||
zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 72);
|
||||
zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 80);
|
||||
zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 88);
|
||||
zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 96);
|
||||
zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 104);
|
||||
zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 112);
|
||||
zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 120);
|
||||
zmm[8] = sort_zmm_64bit<vtype>(zmm[8]);
|
||||
zmm[9] = sort_zmm_64bit<vtype>(zmm[9]);
|
||||
zmm[10] = sort_zmm_64bit<vtype>(zmm[10]);
|
||||
zmm[11] = sort_zmm_64bit<vtype>(zmm[11]);
|
||||
zmm[12] = sort_zmm_64bit<vtype>(zmm[12]);
|
||||
zmm[13] = sort_zmm_64bit<vtype>(zmm[13]);
|
||||
zmm[14] = sort_zmm_64bit<vtype>(zmm[14]);
|
||||
zmm[15] = sort_zmm_64bit<vtype>(zmm[15]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[0], zmm[1]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[2], zmm[3]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[4], zmm[5]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[6], zmm[7]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[8], zmm[9]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[10], zmm[11]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[12], zmm[13]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[14], zmm[15]);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 4);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 8);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 12);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 8);
|
||||
bitonic_merge_sixteen_zmm_64bit<vtype>(zmm);
|
||||
vtype::storeu(arr, zmm[0]);
|
||||
vtype::storeu(arr + 8, zmm[1]);
|
||||
vtype::storeu(arr + 16, zmm[2]);
|
||||
vtype::storeu(arr + 24, zmm[3]);
|
||||
vtype::storeu(arr + 32, zmm[4]);
|
||||
vtype::storeu(arr + 40, zmm[5]);
|
||||
vtype::storeu(arr + 48, zmm[6]);
|
||||
vtype::storeu(arr + 56, zmm[7]);
|
||||
vtype::mask_storeu(arr + 64, load_mask1, zmm[8]);
|
||||
vtype::mask_storeu(arr + 72, load_mask2, zmm[9]);
|
||||
vtype::mask_storeu(arr + 80, load_mask3, zmm[10]);
|
||||
vtype::mask_storeu(arr + 88, load_mask4, zmm[11]);
|
||||
vtype::mask_storeu(arr + 96, load_mask5, zmm[12]);
|
||||
vtype::mask_storeu(arr + 104, load_mask6, zmm[13]);
|
||||
vtype::mask_storeu(arr + 112, load_mask7, zmm[14]);
|
||||
vtype::mask_storeu(arr + 120, load_mask8, zmm[15]);
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) {
|
||||
if (N <= 128) {
|
||||
sort_128_64bit<vtype>(arr, N);
|
||||
return;
|
||||
}
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
using opmask_t = typename vtype::opmask_t;
|
||||
zmm_t zmm[32];
|
||||
zmm[0] = vtype::loadu(arr);
|
||||
zmm[1] = vtype::loadu(arr + 8);
|
||||
zmm[2] = vtype::loadu(arr + 16);
|
||||
zmm[3] = vtype::loadu(arr + 24);
|
||||
zmm[4] = vtype::loadu(arr + 32);
|
||||
zmm[5] = vtype::loadu(arr + 40);
|
||||
zmm[6] = vtype::loadu(arr + 48);
|
||||
zmm[7] = vtype::loadu(arr + 56);
|
||||
zmm[8] = vtype::loadu(arr + 64);
|
||||
zmm[9] = vtype::loadu(arr + 72);
|
||||
zmm[10] = vtype::loadu(arr + 80);
|
||||
zmm[11] = vtype::loadu(arr + 88);
|
||||
zmm[12] = vtype::loadu(arr + 96);
|
||||
zmm[13] = vtype::loadu(arr + 104);
|
||||
zmm[14] = vtype::loadu(arr + 112);
|
||||
zmm[15] = vtype::loadu(arr + 120);
|
||||
zmm[0] = sort_zmm_64bit<vtype>(zmm[0]);
|
||||
zmm[1] = sort_zmm_64bit<vtype>(zmm[1]);
|
||||
zmm[2] = sort_zmm_64bit<vtype>(zmm[2]);
|
||||
zmm[3] = sort_zmm_64bit<vtype>(zmm[3]);
|
||||
zmm[4] = sort_zmm_64bit<vtype>(zmm[4]);
|
||||
zmm[5] = sort_zmm_64bit<vtype>(zmm[5]);
|
||||
zmm[6] = sort_zmm_64bit<vtype>(zmm[6]);
|
||||
zmm[7] = sort_zmm_64bit<vtype>(zmm[7]);
|
||||
zmm[8] = sort_zmm_64bit<vtype>(zmm[8]);
|
||||
zmm[9] = sort_zmm_64bit<vtype>(zmm[9]);
|
||||
zmm[10] = sort_zmm_64bit<vtype>(zmm[10]);
|
||||
zmm[11] = sort_zmm_64bit<vtype>(zmm[11]);
|
||||
zmm[12] = sort_zmm_64bit<vtype>(zmm[12]);
|
||||
zmm[13] = sort_zmm_64bit<vtype>(zmm[13]);
|
||||
zmm[14] = sort_zmm_64bit<vtype>(zmm[14]);
|
||||
zmm[15] = sort_zmm_64bit<vtype>(zmm[15]);
|
||||
opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF;
|
||||
opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF;
|
||||
opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF;
|
||||
opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF;
|
||||
opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF;
|
||||
opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF;
|
||||
opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF;
|
||||
opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF;
|
||||
if (N != 256) {
|
||||
uint64_t combined_mask;
|
||||
if (N < 192) {
|
||||
combined_mask = (0x1ull << (N - 128)) - 0x1ull;
|
||||
load_mask1 = (combined_mask)&0xFF;
|
||||
load_mask2 = (combined_mask >> 8) & 0xFF;
|
||||
load_mask3 = (combined_mask >> 16) & 0xFF;
|
||||
load_mask4 = (combined_mask >> 24) & 0xFF;
|
||||
load_mask5 = (combined_mask >> 32) & 0xFF;
|
||||
load_mask6 = (combined_mask >> 40) & 0xFF;
|
||||
load_mask7 = (combined_mask >> 48) & 0xFF;
|
||||
load_mask8 = (combined_mask >> 56) & 0xFF;
|
||||
load_mask9 = 0x00;
|
||||
load_mask10 = 0x0;
|
||||
load_mask11 = 0x00;
|
||||
load_mask12 = 0x00;
|
||||
load_mask13 = 0x00;
|
||||
load_mask14 = 0x00;
|
||||
load_mask15 = 0x00;
|
||||
load_mask16 = 0x00;
|
||||
} else {
|
||||
combined_mask = (0x1ull << (N - 192)) - 0x1ull;
|
||||
load_mask9 = (combined_mask)&0xFF;
|
||||
load_mask10 = (combined_mask >> 8) & 0xFF;
|
||||
load_mask11 = (combined_mask >> 16) & 0xFF;
|
||||
load_mask12 = (combined_mask >> 24) & 0xFF;
|
||||
load_mask13 = (combined_mask >> 32) & 0xFF;
|
||||
load_mask14 = (combined_mask >> 40) & 0xFF;
|
||||
load_mask15 = (combined_mask >> 48) & 0xFF;
|
||||
load_mask16 = (combined_mask >> 56) & 0xFF;
|
||||
}
|
||||
}
|
||||
zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128);
|
||||
zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136);
|
||||
zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144);
|
||||
zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152);
|
||||
zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160);
|
||||
zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168);
|
||||
zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176);
|
||||
zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184);
|
||||
if (N < 192) {
|
||||
zmm[24] = vtype::zmm_max();
|
||||
zmm[25] = vtype::zmm_max();
|
||||
zmm[26] = vtype::zmm_max();
|
||||
zmm[27] = vtype::zmm_max();
|
||||
zmm[28] = vtype::zmm_max();
|
||||
zmm[29] = vtype::zmm_max();
|
||||
zmm[30] = vtype::zmm_max();
|
||||
zmm[31] = vtype::zmm_max();
|
||||
} else {
|
||||
zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192);
|
||||
zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200);
|
||||
zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208);
|
||||
zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216);
|
||||
zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224);
|
||||
zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232);
|
||||
zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240);
|
||||
zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248);
|
||||
}
|
||||
zmm[16] = sort_zmm_64bit<vtype>(zmm[16]);
|
||||
zmm[17] = sort_zmm_64bit<vtype>(zmm[17]);
|
||||
zmm[18] = sort_zmm_64bit<vtype>(zmm[18]);
|
||||
zmm[19] = sort_zmm_64bit<vtype>(zmm[19]);
|
||||
zmm[20] = sort_zmm_64bit<vtype>(zmm[20]);
|
||||
zmm[21] = sort_zmm_64bit<vtype>(zmm[21]);
|
||||
zmm[22] = sort_zmm_64bit<vtype>(zmm[22]);
|
||||
zmm[23] = sort_zmm_64bit<vtype>(zmm[23]);
|
||||
zmm[24] = sort_zmm_64bit<vtype>(zmm[24]);
|
||||
zmm[25] = sort_zmm_64bit<vtype>(zmm[25]);
|
||||
zmm[26] = sort_zmm_64bit<vtype>(zmm[26]);
|
||||
zmm[27] = sort_zmm_64bit<vtype>(zmm[27]);
|
||||
zmm[28] = sort_zmm_64bit<vtype>(zmm[28]);
|
||||
zmm[29] = sort_zmm_64bit<vtype>(zmm[29]);
|
||||
zmm[30] = sort_zmm_64bit<vtype>(zmm[30]);
|
||||
zmm[31] = sort_zmm_64bit<vtype>(zmm[31]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[0], zmm[1]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[2], zmm[3]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[4], zmm[5]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[6], zmm[7]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[8], zmm[9]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[10], zmm[11]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[12], zmm[13]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[14], zmm[15]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[16], zmm[17]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[18], zmm[19]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[20], zmm[21]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[22], zmm[23]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[24], zmm[25]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[26], zmm[27]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[28], zmm[29]);
|
||||
bitonic_merge_two_zmm_64bit<vtype>(zmm[30], zmm[31]);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 4);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 8);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 12);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 16);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 20);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 24);
|
||||
bitonic_merge_four_zmm_64bit<vtype>(zmm + 28);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 8);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 16);
|
||||
bitonic_merge_eight_zmm_64bit<vtype>(zmm + 24);
|
||||
bitonic_merge_sixteen_zmm_64bit<vtype>(zmm);
|
||||
bitonic_merge_sixteen_zmm_64bit<vtype>(zmm + 16);
|
||||
bitonic_merge_32_zmm_64bit<vtype>(zmm);
|
||||
vtype::storeu(arr, zmm[0]);
|
||||
vtype::storeu(arr + 8, zmm[1]);
|
||||
vtype::storeu(arr + 16, zmm[2]);
|
||||
vtype::storeu(arr + 24, zmm[3]);
|
||||
vtype::storeu(arr + 32, zmm[4]);
|
||||
vtype::storeu(arr + 40, zmm[5]);
|
||||
vtype::storeu(arr + 48, zmm[6]);
|
||||
vtype::storeu(arr + 56, zmm[7]);
|
||||
vtype::storeu(arr + 64, zmm[8]);
|
||||
vtype::storeu(arr + 72, zmm[9]);
|
||||
vtype::storeu(arr + 80, zmm[10]);
|
||||
vtype::storeu(arr + 88, zmm[11]);
|
||||
vtype::storeu(arr + 96, zmm[12]);
|
||||
vtype::storeu(arr + 104, zmm[13]);
|
||||
vtype::storeu(arr + 112, zmm[14]);
|
||||
vtype::storeu(arr + 120, zmm[15]);
|
||||
vtype::mask_storeu(arr + 128, load_mask1, zmm[16]);
|
||||
vtype::mask_storeu(arr + 136, load_mask2, zmm[17]);
|
||||
vtype::mask_storeu(arr + 144, load_mask3, zmm[18]);
|
||||
vtype::mask_storeu(arr + 152, load_mask4, zmm[19]);
|
||||
vtype::mask_storeu(arr + 160, load_mask5, zmm[20]);
|
||||
vtype::mask_storeu(arr + 168, load_mask6, zmm[21]);
|
||||
vtype::mask_storeu(arr + 176, load_mask7, zmm[22]);
|
||||
vtype::mask_storeu(arr + 184, load_mask8, zmm[23]);
|
||||
if (N > 192) {
|
||||
vtype::mask_storeu(arr + 192, load_mask9, zmm[24]);
|
||||
vtype::mask_storeu(arr + 200, load_mask10, zmm[25]);
|
||||
vtype::mask_storeu(arr + 208, load_mask11, zmm[26]);
|
||||
vtype::mask_storeu(arr + 216, load_mask12, zmm[27]);
|
||||
vtype::mask_storeu(arr + 224, load_mask13, zmm[28]);
|
||||
vtype::mask_storeu(arr + 232, load_mask14, zmm[29]);
|
||||
vtype::mask_storeu(arr + 240, load_mask15, zmm[30]);
|
||||
vtype::mask_storeu(arr + 248, load_mask16, zmm[31]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename vtype, typename type_t>
|
||||
static void qsort_64bit_(type_t *arr, int64_t left, int64_t right,
|
||||
int64_t max_iters) {
|
||||
/*
|
||||
* Resort to std::sort if quicksort isnt making any progress
|
||||
*/
|
||||
if (max_iters <= 0) {
|
||||
std::sort(arr + left, arr + right + 1);
|
||||
return;
|
||||
}
|
||||
/*
|
||||
* Base case: use bitonic networks to sort arrays <= 128
|
||||
*/
|
||||
if (right + 1 - left <= 256) {
|
||||
sort_256_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
|
||||
return;
|
||||
}
|
||||
|
||||
type_t pivot = get_pivot_scalar<type_t>(arr, left, right);
|
||||
type_t smallest = vtype::type_max();
|
||||
type_t biggest = vtype::type_min();
|
||||
int64_t pivot_index = partition_avx512_unrolled<vtype, 8>(
|
||||
arr, left, right + 1, pivot, &smallest, &biggest, false);
|
||||
if (pivot != smallest)
|
||||
qsort_64bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
|
||||
if (pivot != biggest)
|
||||
qsort_64bit_<vtype>(arr, pivot_index, right, max_iters - 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
void inline avx512_qsort<int64_t>(int64_t *arr, int64_t fromIndex, int64_t toIndex) {
|
||||
int64_t arrsize = toIndex - fromIndex;
|
||||
if (arrsize > 1) {
|
||||
qsort_64bit_<zmm_vector<int64_t>, int64_t>(arr, fromIndex, toIndex - 1,
|
||||
2 * (int64_t)log2(arrsize));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void inline avx512_qsort<double>(double *arr, int64_t fromIndex, int64_t toIndex) {
|
||||
int64_t arrsize = toIndex - fromIndex;
|
||||
if (arrsize > 1) {
|
||||
qsort_64bit_<zmm_vector<double>, double>(arr, fromIndex, toIndex - 1,
|
||||
2 * (int64_t)log2(arrsize));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // AVX512_QSORT_64BIT
|
474
src/java.base/linux/native/libsimdsort/avx512-common-qsort.h
Normal file
474
src/java.base/linux/native/libsimdsort/avx512-common-qsort.h
Normal file
@ -0,0 +1,474 @@
|
||||
/*
|
||||
* Copyright (c) 2021, 2023, Intel Corporation. All rights reserved.
|
||||
* Copyright (c) 2021 Serge Sans Paille. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*
|
||||
*/
|
||||
|
||||
// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort)
|
||||
#ifndef AVX512_QSORT_COMMON
|
||||
#define AVX512_QSORT_COMMON
|
||||
|
||||
/*
|
||||
* Quicksort using AVX-512. The ideas and code are based on these two research
|
||||
* papers [1] and [2]. On a high level, the idea is to vectorize quicksort
|
||||
* partitioning using AVX-512 compressstore instructions. If the array size is
|
||||
* < 128, then use Bitonic sorting network implemented on 512-bit registers.
|
||||
* The precise network definitions depend on the dtype and are defined in
|
||||
* separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and
|
||||
* avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting
|
||||
* network. The core implementations of the vectorized qsort functions
|
||||
* avx512_qsort<T>(T*, int64_t) are modified versions of avx2 quicksort
|
||||
* presented in the paper [2] and source code associated with that paper [3].
|
||||
*
|
||||
* [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types
|
||||
* https://drops.dagstuhl.de/opus/volltexte/2021/13775/
|
||||
*
|
||||
* [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel
|
||||
* Skylake https://arxiv.org/pdf/1704.08579.pdf
|
||||
*
|
||||
* [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier:
|
||||
* MIT
|
||||
*
|
||||
* [4]
|
||||
* http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030
|
||||
*
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <immintrin.h>
|
||||
#include <limits>
|
||||
|
||||
#define X86_SIMD_SORT_INFINITY std::numeric_limits<double>::infinity()
|
||||
#define X86_SIMD_SORT_INFINITYF std::numeric_limits<float>::infinity()
|
||||
#define X86_SIMD_SORT_INFINITYH 0x7c00
|
||||
#define X86_SIMD_SORT_NEGINFINITYH 0xfc00
|
||||
#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits<uint16_t>::max()
|
||||
#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits<int16_t>::max()
|
||||
#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits<int16_t>::min()
|
||||
#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits<uint32_t>::max()
|
||||
#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits<int32_t>::max()
|
||||
#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits<int32_t>::min()
|
||||
#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits<uint64_t>::max()
|
||||
#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits<int64_t>::max()
|
||||
#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits<int64_t>::min()
|
||||
#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY)
|
||||
#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64)
|
||||
#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64)
|
||||
#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF)
|
||||
#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32)
|
||||
#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32)
|
||||
#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH)
|
||||
#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH)
|
||||
#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16)
|
||||
#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16)
|
||||
#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define X86_SIMD_SORT_INLINE static inline
|
||||
#define X86_SIMD_SORT_FINLINE static __forceinline
|
||||
#elif defined(__CYGWIN__)
|
||||
/*
|
||||
* Force inline in cygwin to work around a compiler bug. See
|
||||
* https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584
|
||||
*/
|
||||
#define X86_SIMD_SORT_INLINE static __attribute__((always_inline))
|
||||
#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline))
|
||||
#elif defined(__GNUC__)
|
||||
#define X86_SIMD_SORT_INLINE static inline
|
||||
#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline))
|
||||
#else
|
||||
#define X86_SIMD_SORT_INLINE static
|
||||
#define X86_SIMD_SORT_FINLINE static
|
||||
#endif
|
||||
|
||||
#define LIKELY(x) __builtin_expect((x), 1)
|
||||
#define UNLIKELY(x) __builtin_expect((x), 0)
|
||||
|
||||
template <typename type>
|
||||
struct zmm_vector;
|
||||
|
||||
template <typename type>
|
||||
struct ymm_vector;
|
||||
|
||||
// Regular quicksort routines:
|
||||
template <typename T>
|
||||
void avx512_qsort(T *arr, int64_t arrsize);
|
||||
|
||||
template <typename T>
|
||||
void inline avx512_qsort(T *arr, int64_t from_index, int64_t to_index);
|
||||
|
||||
template <typename T>
|
||||
bool is_a_nan(T elem) {
|
||||
return std::isnan(elem);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
X86_SIMD_SORT_INLINE T get_pivot_scalar(T *arr, const int64_t left, const int64_t right) {
|
||||
// median of 8 equally spaced elements
|
||||
int64_t NUM_ELEMENTS = 8;
|
||||
int64_t MID = NUM_ELEMENTS / 2;
|
||||
int64_t size = (right - left) / NUM_ELEMENTS;
|
||||
T temp[NUM_ELEMENTS];
|
||||
for (int64_t i = 0; i < NUM_ELEMENTS; i++) temp[i] = arr[left + (i * size)];
|
||||
std::sort(temp, temp + NUM_ELEMENTS);
|
||||
return temp[MID];
|
||||
}
|
||||
|
||||
template <typename vtype, typename T = typename vtype::type_t>
|
||||
bool comparison_func_ge(const T &a, const T &b) {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
template <typename vtype, typename T = typename vtype::type_t>
|
||||
bool comparison_func_gt(const T &a, const T &b) {
|
||||
return a <= b;
|
||||
}
|
||||
|
||||
/*
|
||||
* COEX == Compare and Exchange two registers by swapping min and max values
|
||||
*/
|
||||
template <typename vtype, typename mm_t>
|
||||
static void COEX(mm_t &a, mm_t &b) {
|
||||
mm_t temp = a;
|
||||
a = vtype::min(a, b);
|
||||
b = vtype::max(temp, b);
|
||||
}
|
||||
template <typename vtype, typename zmm_t = typename vtype::zmm_t,
|
||||
typename opmask_t = typename vtype::opmask_t>
|
||||
static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) {
|
||||
zmm_t min = vtype::min(in2, in1);
|
||||
zmm_t max = vtype::max(in2, in1);
|
||||
return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max
|
||||
}
|
||||
/*
|
||||
* Parition one ZMM register based on the pivot and returns the
|
||||
* number of elements that are greater than or equal to the pivot.
|
||||
*/
|
||||
template <typename vtype, typename type_t, typename zmm_t>
|
||||
static inline int32_t partition_vec(type_t *arr, int64_t left, int64_t right,
|
||||
const zmm_t curr_vec, const zmm_t pivot_vec,
|
||||
zmm_t *smallest_vec, zmm_t *biggest_vec, bool use_gt) {
|
||||
/* which elements are larger than or equal to the pivot */
|
||||
typename vtype::opmask_t mask;
|
||||
if (use_gt) mask = vtype::gt(curr_vec, pivot_vec);
|
||||
else mask = vtype::ge(curr_vec, pivot_vec);
|
||||
//mask = vtype::ge(curr_vec, pivot_vec);
|
||||
int32_t amount_ge_pivot = _mm_popcnt_u32((int32_t)mask);
|
||||
vtype::mask_compressstoreu(arr + left, vtype::knot_opmask(mask),
|
||||
curr_vec);
|
||||
vtype::mask_compressstoreu(arr + right - amount_ge_pivot, mask,
|
||||
curr_vec);
|
||||
*smallest_vec = vtype::min(curr_vec, *smallest_vec);
|
||||
*biggest_vec = vtype::max(curr_vec, *biggest_vec);
|
||||
return amount_ge_pivot;
|
||||
}
|
||||
/*
|
||||
* Parition an array based on the pivot and returns the index of the
|
||||
* first element that is greater than or equal to the pivot.
|
||||
*/
|
||||
template <typename vtype, typename type_t>
|
||||
static inline int64_t partition_avx512(type_t *arr, int64_t left, int64_t right,
|
||||
type_t pivot, type_t *smallest,
|
||||
type_t *biggest, bool use_gt) {
|
||||
auto comparison_func = use_gt ? comparison_func_gt<vtype> : comparison_func_ge<vtype>;
|
||||
/* make array length divisible by vtype::numlanes , shortening the array */
|
||||
for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) {
|
||||
*smallest = std::min(*smallest, arr[left], comparison_func);
|
||||
*biggest = std::max(*biggest, arr[left], comparison_func);
|
||||
if (!comparison_func(arr[left], pivot)) {
|
||||
std::swap(arr[left], arr[--right]);
|
||||
} else {
|
||||
++left;
|
||||
}
|
||||
}
|
||||
|
||||
if (left == right)
|
||||
return left; /* less than vtype::numlanes elements in the array */
|
||||
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
zmm_t pivot_vec = vtype::set1(pivot);
|
||||
zmm_t min_vec = vtype::set1(*smallest);
|
||||
zmm_t max_vec = vtype::set1(*biggest);
|
||||
|
||||
if (right - left == vtype::numlanes) {
|
||||
zmm_t vec = vtype::loadu(arr + left);
|
||||
int32_t amount_ge_pivot =
|
||||
partition_vec<vtype>(arr, left, left + vtype::numlanes, vec,
|
||||
pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
*smallest = vtype::reducemin(min_vec);
|
||||
*biggest = vtype::reducemax(max_vec);
|
||||
return left + (vtype::numlanes - amount_ge_pivot);
|
||||
}
|
||||
|
||||
// first and last vtype::numlanes values are partitioned at the end
|
||||
zmm_t vec_left = vtype::loadu(arr + left);
|
||||
zmm_t vec_right = vtype::loadu(arr + (right - vtype::numlanes));
|
||||
// store points of the vectors
|
||||
int64_t r_store = right - vtype::numlanes;
|
||||
int64_t l_store = left;
|
||||
// indices for loading the elements
|
||||
left += vtype::numlanes;
|
||||
right -= vtype::numlanes;
|
||||
while (right - left != 0) {
|
||||
zmm_t curr_vec;
|
||||
/*
|
||||
* if fewer elements are stored on the right side of the array,
|
||||
* then next elements are loaded from the right side,
|
||||
* otherwise from the left side
|
||||
*/
|
||||
if ((r_store + vtype::numlanes) - right < left - l_store) {
|
||||
right -= vtype::numlanes;
|
||||
curr_vec = vtype::loadu(arr + right);
|
||||
} else {
|
||||
curr_vec = vtype::loadu(arr + left);
|
||||
left += vtype::numlanes;
|
||||
}
|
||||
// partition the current vector and save it on both sides of the array
|
||||
int32_t amount_ge_pivot =
|
||||
partition_vec<vtype>(arr, l_store, r_store + vtype::numlanes,
|
||||
curr_vec, pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
;
|
||||
r_store -= amount_ge_pivot;
|
||||
l_store += (vtype::numlanes - amount_ge_pivot);
|
||||
}
|
||||
|
||||
/* partition and save vec_left and vec_right */
|
||||
int32_t amount_ge_pivot =
|
||||
partition_vec<vtype>(arr, l_store, r_store + vtype::numlanes, vec_left,
|
||||
pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
l_store += (vtype::numlanes - amount_ge_pivot);
|
||||
amount_ge_pivot =
|
||||
partition_vec<vtype>(arr, l_store, l_store + vtype::numlanes, vec_right,
|
||||
pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
l_store += (vtype::numlanes - amount_ge_pivot);
|
||||
*smallest = vtype::reducemin(min_vec);
|
||||
*biggest = vtype::reducemax(max_vec);
|
||||
return l_store;
|
||||
}
|
||||
|
||||
template <typename vtype, int num_unroll,
|
||||
typename type_t = typename vtype::type_t>
|
||||
static inline int64_t partition_avx512_unrolled(type_t *arr, int64_t left,
|
||||
int64_t right, type_t pivot,
|
||||
type_t *smallest,
|
||||
type_t *biggest, bool use_gt) {
|
||||
if (right - left <= 2 * num_unroll * vtype::numlanes) {
|
||||
return partition_avx512<vtype>(arr, left, right, pivot, smallest,
|
||||
biggest, use_gt);
|
||||
}
|
||||
|
||||
auto comparison_func = use_gt ? comparison_func_gt<vtype> : comparison_func_ge<vtype>;
|
||||
/* make array length divisible by 8*vtype::numlanes , shortening the array
|
||||
*/
|
||||
for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0;
|
||||
--i) {
|
||||
*smallest = std::min(*smallest, arr[left], comparison_func);
|
||||
*biggest = std::max(*biggest, arr[left], comparison_func);
|
||||
if (!comparison_func(arr[left], pivot)) {
|
||||
std::swap(arr[left], arr[--right]);
|
||||
} else {
|
||||
++left;
|
||||
}
|
||||
}
|
||||
|
||||
if (left == right)
|
||||
return left; /* less than vtype::numlanes elements in the array */
|
||||
|
||||
using zmm_t = typename vtype::zmm_t;
|
||||
zmm_t pivot_vec = vtype::set1(pivot);
|
||||
zmm_t min_vec = vtype::set1(*smallest);
|
||||
zmm_t max_vec = vtype::set1(*biggest);
|
||||
|
||||
// We will now have atleast 16 registers worth of data to process:
|
||||
// left and right vtype::numlanes values are partitioned at the end
|
||||
zmm_t vec_left[num_unroll], vec_right[num_unroll];
|
||||
#pragma GCC unroll 8
|
||||
for (int ii = 0; ii < num_unroll; ++ii) {
|
||||
vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii);
|
||||
vec_right[ii] =
|
||||
vtype::loadu(arr + (right - vtype::numlanes * (num_unroll - ii)));
|
||||
}
|
||||
// store points of the vectors
|
||||
int64_t r_store = right - vtype::numlanes;
|
||||
int64_t l_store = left;
|
||||
// indices for loading the elements
|
||||
left += num_unroll * vtype::numlanes;
|
||||
right -= num_unroll * vtype::numlanes;
|
||||
while (right - left != 0) {
|
||||
zmm_t curr_vec[num_unroll];
|
||||
/*
|
||||
* if fewer elements are stored on the right side of the array,
|
||||
* then next elements are loaded from the right side,
|
||||
* otherwise from the left side
|
||||
*/
|
||||
if ((r_store + vtype::numlanes) - right < left - l_store) {
|
||||
right -= num_unroll * vtype::numlanes;
|
||||
#pragma GCC unroll 8
|
||||
for (int ii = 0; ii < num_unroll; ++ii) {
|
||||
curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes);
|
||||
}
|
||||
} else {
|
||||
#pragma GCC unroll 8
|
||||
for (int ii = 0; ii < num_unroll; ++ii) {
|
||||
curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes);
|
||||
}
|
||||
left += num_unroll * vtype::numlanes;
|
||||
}
|
||||
// partition the current vector and save it on both sides of the array
|
||||
#pragma GCC unroll 8
|
||||
for (int ii = 0; ii < num_unroll; ++ii) {
|
||||
int32_t amount_ge_pivot = partition_vec<vtype>(
|
||||
arr, l_store, r_store + vtype::numlanes, curr_vec[ii],
|
||||
pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
l_store += (vtype::numlanes - amount_ge_pivot);
|
||||
r_store -= amount_ge_pivot;
|
||||
}
|
||||
}
|
||||
|
||||
/* partition and save vec_left[8] and vec_right[8] */
|
||||
#pragma GCC unroll 8
|
||||
for (int ii = 0; ii < num_unroll; ++ii) {
|
||||
int32_t amount_ge_pivot =
|
||||
partition_vec<vtype>(arr, l_store, r_store + vtype::numlanes,
|
||||
vec_left[ii], pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
l_store += (vtype::numlanes - amount_ge_pivot);
|
||||
r_store -= amount_ge_pivot;
|
||||
}
|
||||
#pragma GCC unroll 8
|
||||
for (int ii = 0; ii < num_unroll; ++ii) {
|
||||
int32_t amount_ge_pivot =
|
||||
partition_vec<vtype>(arr, l_store, r_store + vtype::numlanes,
|
||||
vec_right[ii], pivot_vec, &min_vec, &max_vec, use_gt);
|
||||
l_store += (vtype::numlanes - amount_ge_pivot);
|
||||
r_store -= amount_ge_pivot;
|
||||
}
|
||||
*smallest = vtype::reducemin(min_vec);
|
||||
*biggest = vtype::reducemax(max_vec);
|
||||
return l_store;
|
||||
}
|
||||
|
||||
// to_index (exclusive)
|
||||
template <typename vtype, typename type_t>
|
||||
static int64_t vectorized_partition(type_t *arr, int64_t from_index, int64_t to_index, type_t pivot, bool use_gt) {
|
||||
type_t smallest = vtype::type_max();
|
||||
type_t biggest = vtype::type_min();
|
||||
int64_t pivot_index = partition_avx512_unrolled<vtype, 2>(
|
||||
arr, from_index, to_index, pivot, &smallest, &biggest, use_gt);
|
||||
return pivot_index;
|
||||
}
|
||||
|
||||
// partitioning functions
|
||||
template <typename T>
|
||||
void avx512_dual_pivot_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2){
|
||||
const T pivot1 = arr[index_pivot1];
|
||||
const T pivot2 = arr[index_pivot2];
|
||||
|
||||
const int64_t low = from_index;
|
||||
const int64_t high = to_index;
|
||||
const int64_t start = low + 1;
|
||||
const int64_t end = high - 1;
|
||||
|
||||
|
||||
std::swap(arr[index_pivot1], arr[low]);
|
||||
std::swap(arr[index_pivot2], arr[end]);
|
||||
|
||||
|
||||
const int64_t pivot_index2 = vectorized_partition<zmm_vector<T>, T>(arr, start, end, pivot2, true); // use_gt = true
|
||||
std::swap(arr[end], arr[pivot_index2]);
|
||||
int64_t upper = pivot_index2;
|
||||
|
||||
// if all other elements are greater than pivot2 (and pivot1), no need to do further partitioning
|
||||
if (upper == start) {
|
||||
pivot_indices[0] = low;
|
||||
pivot_indices[1] = upper;
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t pivot_index1 = vectorized_partition<zmm_vector<T>, T>(arr, start, upper, pivot1, false); // use_ge (use_gt = false)
|
||||
int64_t lower = pivot_index1 - 1;
|
||||
std::swap(arr[low], arr[lower]);
|
||||
|
||||
pivot_indices[0] = lower;
|
||||
pivot_indices[1] = upper;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void avx512_single_pivot_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot){
|
||||
const T pivot = arr[index_pivot];
|
||||
|
||||
const int64_t low = from_index;
|
||||
const int64_t high = to_index;
|
||||
const int64_t end = high - 1;
|
||||
|
||||
|
||||
const int64_t pivot_index1 = vectorized_partition<zmm_vector<T>, T>(arr, low, high, pivot, false); // use_gt = false (use_ge)
|
||||
int64_t lower = pivot_index1;
|
||||
|
||||
const int64_t pivot_index2 = vectorized_partition<zmm_vector<T>, T>(arr, pivot_index1, high, pivot, true); // use_gt = true
|
||||
int64_t upper = pivot_index2;
|
||||
|
||||
pivot_indices[0] = lower;
|
||||
pivot_indices[1] = upper;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void inline avx512_fast_partition(T *arr, int64_t from_index, int64_t to_index, int32_t *pivot_indices, int64_t index_pivot1, int64_t index_pivot2) {
|
||||
if (index_pivot1 != index_pivot2) {
|
||||
avx512_dual_pivot_partition<T>(arr, from_index, to_index, pivot_indices, index_pivot1, index_pivot2);
|
||||
}
|
||||
else {
|
||||
avx512_single_pivot_partition<T>(arr, from_index, to_index, pivot_indices, index_pivot1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void inline insertion_sort(T *arr, int32_t from_index, int32_t to_index) {
|
||||
for (int i, k = from_index; ++k < to_index; ) {
|
||||
T ai = arr[i = k];
|
||||
|
||||
if (ai < arr[i - 1]) {
|
||||
while (--i >= from_index && ai < arr[i]) {
|
||||
arr[i + 1] = arr[i];
|
||||
}
|
||||
arr[i + 1] = ai;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void inline avx512_fast_sort(T *arr, int64_t from_index, int64_t to_index, const int32_t INS_SORT_THRESHOLD) {
|
||||
int32_t size = to_index - from_index;
|
||||
|
||||
if (size <= INS_SORT_THRESHOLD) {
|
||||
insertion_sort<T>(arr, from_index, to_index);
|
||||
}
|
||||
else {
|
||||
avx512_qsort<T>(arr, from_index, to_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif // AVX512_QSORT_COMMON
|
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Intel Corporation. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma GCC target("avx512dq", "avx512f")
|
||||
#include "avx512-32bit-qsort.hpp"
|
||||
#include "avx512-64bit-qsort.hpp"
|
||||
#include "classfile_constants.h"
|
||||
|
||||
#define DLL_PUBLIC __attribute__((visibility("default")))
|
||||
#define INSERTION_SORT_THRESHOLD_32BIT 16
|
||||
#define INSERTION_SORT_THRESHOLD_64BIT 20
|
||||
|
||||
extern "C" {
|
||||
|
||||
DLL_PUBLIC void avx512_sort(void *array, int elem_type, int32_t from_index, int32_t to_index) {
|
||||
switch(elem_type) {
|
||||
case JVM_T_INT:
|
||||
avx512_fast_sort<int32_t>((int32_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT);
|
||||
break;
|
||||
case JVM_T_LONG:
|
||||
avx512_fast_sort<int64_t>((int64_t*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT);
|
||||
break;
|
||||
case JVM_T_FLOAT:
|
||||
avx512_fast_sort<float>((float*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_32BIT);
|
||||
break;
|
||||
case JVM_T_DOUBLE:
|
||||
avx512_fast_sort<double>((double*)array, from_index, to_index, INSERTION_SORT_THRESHOLD_64BIT);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
DLL_PUBLIC void avx512_partition(void *array, int elem_type, int32_t from_index, int32_t to_index, int32_t *pivot_indices, int32_t index_pivot1, int32_t index_pivot2) {
|
||||
switch(elem_type) {
|
||||
case JVM_T_INT:
|
||||
avx512_fast_partition<int32_t>((int32_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2);
|
||||
break;
|
||||
case JVM_T_LONG:
|
||||
avx512_fast_partition<int64_t>((int64_t*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2);
|
||||
break;
|
||||
case JVM_T_FLOAT:
|
||||
avx512_fast_partition<float>((float*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2);
|
||||
break;
|
||||
case JVM_T_DOUBLE:
|
||||
avx512_fast_partition<double>((double*)array, from_index, to_index, pivot_indices, index_pivot1, index_pivot2);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2009, 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2009, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
@ -26,7 +26,8 @@
|
||||
* @compile/module=java.base java/util/SortingHelper.java
|
||||
* @bug 6880672 6896573 6899694 6976036 7013585 7018258 8003981 8226297
|
||||
* @build Sorting
|
||||
* @run main Sorting -shortrun
|
||||
* @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:DisableIntrinsic=_arraySort,_arrayPartition Sorting -shortrun
|
||||
* @run main/othervm -XX:-TieredCompilation -XX:CompileCommand=CompileThresholdScaling,java.util.DualPivotQuicksort::sort,0.0001 Sorting -shortrun
|
||||
* @summary Exercise Arrays.sort, Arrays.parallelSort
|
||||
*
|
||||
* @author Vladimir Yaroslavskiy
|
||||
|
163
test/micro/org/openjdk/bench/java/util/ArraysSort.java
Normal file
163
test/micro/org/openjdk/bench/java/util/ArraysSort.java
Normal file
@ -0,0 +1,163 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
package org.openjdk.bench.java.lang;
|
||||
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||
import org.openjdk.jmh.annotations.Fork;
|
||||
import org.openjdk.jmh.annotations.Measurement;
|
||||
import org.openjdk.jmh.annotations.Mode;
|
||||
import org.openjdk.jmh.annotations.OperationsPerInvocation;
|
||||
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.Setup;
|
||||
import org.openjdk.jmh.annotations.State;
|
||||
import org.openjdk.jmh.annotations.Level;
|
||||
import org.openjdk.jmh.annotations.Warmup;
|
||||
import org.openjdk.jmh.infra.Blackhole;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.lang.invoke.MethodHandle;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.lang.reflect.Method;
|
||||
|
||||
/**
|
||||
* Performance test of Arrays.sort() methods
|
||||
*/
|
||||
@Fork(value=1, jvmArgsAppend={"-XX:CompileThreshold=1", "-XX:-TieredCompilation"})
|
||||
@BenchmarkMode(Mode.AverageTime)
|
||||
@OutputTimeUnit(TimeUnit.MICROSECONDS)
|
||||
@State(Scope.Thread)
|
||||
@Warmup(iterations = 3, time=5)
|
||||
@Measurement(iterations = 3, time=3)
|
||||
public class ArraysSort {
|
||||
|
||||
@Param({"10","25","50","75","100", "1000", "10000", "100000", "1000000"})
|
||||
private int size;
|
||||
|
||||
private int[] ints_unsorted;
|
||||
private long[] longs_unsorted;
|
||||
private float[] floats_unsorted;
|
||||
private double[] doubles_unsorted;
|
||||
|
||||
private int[] ints_sorted;
|
||||
private long[] longs_sorted;
|
||||
private float[] floats_sorted;
|
||||
private double[] doubles_sorted;
|
||||
|
||||
|
||||
public void initialize() {
|
||||
Random rnd = new Random(42);
|
||||
|
||||
ints_unsorted = new int[size];
|
||||
longs_unsorted = new long[size];
|
||||
floats_unsorted = new float[size];
|
||||
doubles_unsorted = new double[size];
|
||||
|
||||
int[] intSpecialCases = {Integer.MIN_VALUE, Integer.MAX_VALUE};
|
||||
long[] longSpecialCases = {Long.MIN_VALUE, Long.MAX_VALUE};
|
||||
float[] floatSpecialCases = {+0.0f, -0.0f, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, Float.NaN};
|
||||
double[] doubleSpecialCases = {+0.0, -0.0, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NaN};
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
ints_unsorted[i] = rnd.nextInt();
|
||||
longs_unsorted[i] = rnd.nextLong();
|
||||
if (i % 10 != 0) {
|
||||
ints_unsorted[i] = rnd.nextInt();
|
||||
longs_unsorted[i] = rnd.nextLong();
|
||||
floats_unsorted[i] = rnd.nextFloat();
|
||||
doubles_unsorted[i] = rnd.nextDouble();
|
||||
} else {
|
||||
ints_unsorted[i] = intSpecialCases[rnd.nextInt(intSpecialCases.length)];
|
||||
longs_unsorted[i] = longSpecialCases[rnd.nextInt(longSpecialCases.length)];
|
||||
floats_unsorted[i] = floatSpecialCases[rnd.nextInt(floatSpecialCases.length)];
|
||||
doubles_unsorted[i] = doubleSpecialCases[rnd.nextInt(doubleSpecialCases.length)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Setup
|
||||
public void setup() throws UnsupportedEncodingException, ClassNotFoundException, NoSuchMethodException, Throwable {
|
||||
initialize();
|
||||
}
|
||||
|
||||
@Setup(Level.Invocation)
|
||||
public void clear() {
|
||||
ints_sorted = ints_unsorted.clone();
|
||||
longs_sorted = longs_unsorted.clone();
|
||||
floats_sorted = floats_unsorted.clone();
|
||||
doubles_sorted = doubles_unsorted.clone();
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int[] intSort() throws Throwable {
|
||||
Arrays.sort(ints_sorted);
|
||||
return ints_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int[] intParallelSort() throws Throwable {
|
||||
Arrays.parallelSort(ints_sorted);
|
||||
return ints_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long[] longSort() throws Throwable {
|
||||
Arrays.sort(longs_sorted);
|
||||
return longs_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long[] longParallelSort() throws Throwable {
|
||||
Arrays.parallelSort(longs_sorted);
|
||||
return longs_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public float[] floatSort() throws Throwable {
|
||||
Arrays.sort(floats_sorted);
|
||||
return floats_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public float[] floatParallelSort() throws Throwable {
|
||||
Arrays.parallelSort(floats_sorted);
|
||||
return floats_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public double[] doubleSort() throws Throwable {
|
||||
Arrays.sort(doubles_sorted);
|
||||
return doubles_sorted;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public double[] doubleParallelSort() throws Throwable {
|
||||
Arrays.parallelSort(doubles_sorted);
|
||||
return doubles_sorted;
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user