8199868: Support JNI critical functions in object pinning API

Pin/unpin incoming array arguments of critical native JNI call

Reviewed-by: shade, adinn
This commit is contained in:
Zhengyu Gu 2018-07-31 13:12:06 -04:00
parent 9d25c65fda
commit b71f3e7104
6 changed files with 493 additions and 2 deletions

@ -1,5 +1,5 @@
/*
* Copyright (c) 2003, 2017, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2003, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -1434,6 +1434,64 @@ static void save_or_restore_arguments(MacroAssembler* masm,
}
}
// Pin object, return pinned object or null in rax
static void gen_pin_object(MacroAssembler* masm,
VMRegPair reg) {
__ block_comment("gen_pin_object {");
// rax always contains oop, either incoming or
// pinned.
Register tmp_reg = rax;
Label is_null;
VMRegPair tmp;
VMRegPair in_reg = reg;
tmp.set_ptr(tmp_reg->as_VMReg());
if (reg.first()->is_stack()) {
// Load the arg up from the stack
move_ptr(masm, reg, tmp);
reg = tmp;
} else {
__ movptr(rax, reg.first()->as_Register());
}
__ testptr(reg.first()->as_Register(), reg.first()->as_Register());
__ jccb(Assembler::equal, is_null);
if (reg.first()->as_Register() != c_rarg1) {
__ movptr(c_rarg1, reg.first()->as_Register());
}
__ call_VM_leaf(
CAST_FROM_FN_PTR(address, SharedRuntime::pin_object),
r15_thread, c_rarg1);
__ bind(is_null);
__ block_comment("} gen_pin_object");
}
// Unpin object
static void gen_unpin_object(MacroAssembler* masm,
VMRegPair reg) {
__ block_comment("gen_unpin_object {");
Label is_null;
if (reg.first()->is_stack()) {
__ movptr(c_rarg1, Address(rbp, reg2offset_in(reg.first())));
} else if (reg.first()->as_Register() != c_rarg1) {
__ movptr(c_rarg1, reg.first()->as_Register());
}
__ testptr(c_rarg1, c_rarg1);
__ jccb(Assembler::equal, is_null);
__ call_VM_leaf(
CAST_FROM_FN_PTR(address, SharedRuntime::unpin_object),
r15_thread, c_rarg1);
__ bind(is_null);
__ block_comment("} gen_unpin_object");
}
// Check GCLocker::needs_gc and enter the runtime if it's true. This
// keeps a new JNI critical region from starting until a GC has been
@ -2129,7 +2187,7 @@ nmethod* SharedRuntime::generate_native_wrapper(MacroAssembler* masm,
const Register oop_handle_reg = r14;
if (is_critical_native) {
if (is_critical_native && !Universe::heap()->supports_object_pinning()) {
check_needs_gc_for_critical_native(masm, stack_slots, total_c_args, total_in_args,
oop_handle_offset, oop_maps, in_regs, in_sig_bt);
}
@ -2186,6 +2244,11 @@ nmethod* SharedRuntime::generate_native_wrapper(MacroAssembler* masm,
// the incoming and outgoing registers are offset upwards and for
// critical natives they are offset down.
GrowableArray<int> arg_order(2 * total_in_args);
// Inbound arguments that need to be pinned for critical natives
GrowableArray<int> pinned_args(total_in_args);
// Current stack slot for storing register based array argument
int pinned_slot = oop_handle_offset;
VMRegPair tmp_vmreg;
tmp_vmreg.set2(rbx->as_VMReg());
@ -2233,6 +2296,23 @@ nmethod* SharedRuntime::generate_native_wrapper(MacroAssembler* masm,
switch (in_sig_bt[i]) {
case T_ARRAY:
if (is_critical_native) {
// pin before unpack
if (Universe::heap()->supports_object_pinning()) {
save_args(masm, total_c_args, 0, out_regs);
gen_pin_object(masm, in_regs[i]);
pinned_args.append(i);
restore_args(masm, total_c_args, 0, out_regs);
// rax has pinned array
VMRegPair result_reg;
result_reg.set_ptr(rax->as_VMReg());
move_ptr(masm, result_reg, in_regs[i]);
if (!in_regs[i].first()->is_stack()) {
assert(pinned_slot <= stack_slots, "overflow");
move_ptr(masm, result_reg, VMRegImpl::stack2reg(pinned_slot));
pinned_slot += VMRegImpl::slots_per_word;
}
}
unpack_array_argument(masm, in_regs[i], in_elem_bt[i], out_regs[c_arg + 1], out_regs[c_arg]);
c_arg++;
#ifdef ASSERT
@ -2449,6 +2529,24 @@ nmethod* SharedRuntime::generate_native_wrapper(MacroAssembler* masm,
default : ShouldNotReachHere();
}
// unpin pinned arguments
pinned_slot = oop_handle_offset;
if (pinned_args.length() > 0) {
// save return value that may be overwritten otherwise.
save_native_result(masm, ret_type, stack_slots);
for (int index = 0; index < pinned_args.length(); index ++) {
int i = pinned_args.at(index);
assert(pinned_slot <= stack_slots, "overflow");
if (!in_regs[i].first()->is_stack()) {
int offset = pinned_slot * VMRegImpl::stack_slot_size;
__ movq(in_regs[i].first()->as_Register(), Address(rsp, offset));
pinned_slot += VMRegImpl::slots_per_word;
}
gen_unpin_object(masm, in_regs[i]);
}
restore_native_result(masm, ret_type, stack_slots);
}
// Switch thread to "native transition" state before reading the synchronization state.
// This additional state is necessary because reading and testing the synchronization
// state is not atomic w.r.t. GC, as this scenario demonstrates:

@ -2863,6 +2863,22 @@ JRT_ENTRY_NO_ASYNC(void, SharedRuntime::block_for_jni_critical(JavaThread* threa
GCLocker::unlock_critical(thread);
JRT_END
JRT_LEAF(oopDesc*, SharedRuntime::pin_object(JavaThread* thread, oopDesc* obj))
assert(Universe::heap()->supports_object_pinning(), "Why we are here?");
assert(obj != NULL, "Should not be null");
oop o(obj);
o = Universe::heap()->pin_object(thread, o);
assert(o != NULL, "Should not be null");
return o;
JRT_END
JRT_LEAF(void, SharedRuntime::unpin_object(JavaThread* thread, oopDesc* obj))
assert(Universe::heap()->supports_object_pinning(), "Why we are here?");
assert(obj != NULL, "Should not be null");
oop o(obj);
Universe::heap()->unpin_object(thread, o);
JRT_END
// -------------------------------------------------------------------------
// Java-Java calling convention
// (what you use when Java calls Java)

@ -486,6 +486,10 @@ class SharedRuntime: AllStatic {
// Block before entering a JNI critical method
static void block_for_jni_critical(JavaThread* thread);
// Pin/Unpin object
static oopDesc* pin_object(JavaThread* thread, oopDesc* obj);
static void unpin_object(JavaThread* thread, oopDesc* obj);
// A compiled caller has just called the interpreter, but compiled code
// exists. Patch the caller so he no longer calls into the interpreter.
static void fixup_callers_callsite(Method* moop, address ret_pc);

@ -0,0 +1,52 @@
/*
* Copyright (c) 2018, Red Hat, Inc. and/or its affiliates.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*
*/
/*
* @test CriticalNativeStress
* @key gc
* @bug 8199868
* @requires (os.arch =="x86_64" | os.arch == "amd64") & (vm.bits == "64") & vm.gc.Epsilon & !vm.graal.enabled
* @summary test argument unpacking nmethod wrapper of critical native method
* @run main/othervm/native -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -Xcomp -Xmx256M -XX:+CriticalJNINatives CriticalNativeArgs
*/
public class CriticalNativeArgs {
static {
System.loadLibrary("CriticalNative");
}
static native boolean isNull(int[] a);
public static void main(String[] args) {
int[] arr = new int[2];
if (isNull(arr)) {
throw new RuntimeException("Should not be null");
}
if (!isNull(null)) {
throw new RuntimeException("Should be null");
}
}
}

@ -0,0 +1,197 @@
/*
* Copyright (c) 2018, Red Hat, Inc. and/or its affiliates.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*
*/
import java.util.Random;
/*
* @test CriticalNativeStress
* @key gc
* @bug 8199868
* @requires (os.arch =="x86_64" | os.arch == "amd64") & (vm.bits == "64") & vm.gc.Epsilon & !vm.graal.enabled
* @summary test argument pinning by nmethod wrapper of critical native method
* @run main/othervm/native -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -Xcomp -Xmx1G -XX:+CriticalJNINatives CriticalNativeStress
*/
public class CriticalNativeStress {
private static Random rand = new Random();
static {
System.loadLibrary("CriticalNative");
}
// CYCLES and THREAD_PER_CASE are used to tune the tests for different GC settings,
// so that they can execrise enough GC cycles and not OOM
private static int CYCLES = Integer.getInteger("cycles", 3);
private static int THREAD_PER_CASE = Integer.getInteger("threadPerCase", 1);
static native long sum1(long[] a);
// More than 6 parameters
static native long sum2(long a1, int[] a2, int[] a3, long[] a4, int[] a5);
static long sum(long[] a) {
long sum = 0;
for (int index = 0; index < a.length; index ++) {
sum += a[index];
}
return sum;
}
static long sum(int[] a) {
long sum = 0;
for (int index = 0; index < a.length; index ++) {
sum += a[index];
}
return sum;
}
private static volatile String garbage_array[];
// GC potentially moves arrays passed to critical native methods
// if they are not pinned correctly.
// Create enough garbages to exercise GC cycles, verify
// the arrays are pinned correctly.
static void create_garbage(int len) {
len = Math.max(len, 1024);
String array[] = new String[len];
for (int index = 0; index < len; index ++) {
array[index] = "String " + index;
}
garbage_array = array;
}
// Two test cases with different method signatures:
// Tests generate arbitrary length of arrays with
// arbitrary values, then calcuate sum of the array
// elements with critical native JNI methods and java
// methods, and compare the results for correctness.
static void run_test_case1() {
// Create testing arary with arbitrary length and
// values
int length = rand.nextInt(50) + 1;
long[] arr = new long[length];
for (int index = 0; index < length; index ++) {
arr[index] = rand.nextLong() % 1002;
}
// Generate garbages to trigger GCs
for (int index = 0; index < length; index ++) {
create_garbage(index);
}
// Compare results for correctness.
long native_sum = sum1(arr);
long java_sum = sum(arr);
if (native_sum != java_sum) {
StringBuffer sb = new StringBuffer("Sums do not match: native = ")
.append(native_sum).append(" java = ").append(java_sum);
throw new RuntimeException(sb.toString());
}
}
static void run_test_case2() {
// Create testing arary with arbitrary length and
// values
int index;
long a1 = rand.nextLong() % 1025;
int a2_length = rand.nextInt(50) + 1;
int[] a2 = new int[a2_length];
for (index = 0; index < a2_length; index ++) {
a2[index] = rand.nextInt(106);
}
int a3_length = rand.nextInt(150) + 1;
int[] a3 = new int[a3_length];
for (index = 0; index < a3_length; index ++) {
a3[index] = rand.nextInt(3333);
}
int a4_length = rand.nextInt(200) + 1;
long[] a4 = new long[a4_length];
for (index = 0; index < a4_length; index ++) {
a4[index] = rand.nextLong() % 122;
}
int a5_length = rand.nextInt(350) + 1;
int[] a5 = new int[a5_length];
for (index = 0; index < a5_length; index ++) {
a5[index] = rand.nextInt(333);
}
// Generate garbages to trigger GCs
for (index = 0; index < a1; index ++) {
create_garbage(index);
}
// Compare results for correctness.
long native_sum = sum2(a1, a2, a3, a4, a5);
long java_sum = a1 + sum(a2) + sum(a3) + sum(a4) + sum(a5);
if (native_sum != java_sum) {
StringBuffer sb = new StringBuffer("Sums do not match: native = ")
.append(native_sum).append(" java = ").append(java_sum);
throw new RuntimeException(sb.toString());
}
}
static class Case1Runner extends Thread {
public Case1Runner() {
start();
}
public void run() {
for (int index = 0; index < CYCLES; index ++) {
run_test_case1();
}
}
}
static class Case2Runner extends Thread {
public Case2Runner() {
start();
}
public void run() {
for (int index = 0; index < CYCLES; index ++) {
run_test_case2();
}
}
}
public static void main(String[] args) {
Thread[] thrs = new Thread[THREAD_PER_CASE * 2];
for (int index = 0; index < thrs.length; index = index + 2) {
thrs[index] = new Case1Runner();
thrs[index + 1] = new Case2Runner();
}
for (int index = 0; index < thrs.length; index ++) {
try {
thrs[index].join();
} catch (Exception e) {
e.printStackTrace();
}
}
}
}

@ -0,0 +1,124 @@
/*
* Copyright (c) 2018, Red Hat, Inc. and/or its affiliates.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*
*/
#include "jni.h"
JNIEXPORT jlong JNICALL JavaCritical_CriticalNativeStress_sum1
(jint length, jlong* a) {
jlong sum = 0;
jint index;
for (index = 0; index < length; index ++) {
sum += a[index];
}
return sum;
}
JNIEXPORT jlong JNICALL JavaCritical_CriticalNativeStress_sum2
(jlong a1, jint a2_length, jint* a2, jint a4_length, jint* a4, jint a6_length, jlong* a6, jint a8_length, jint* a8) {
jlong sum = a1;
jint index;
for (index = 0; index < a2_length; index ++) {
sum += a2[index];
}
for (index = 0; index < a4_length; index ++) {
sum += a4[index];
}
for (index = 0; index < a6_length; index ++) {
sum += a6[index];
}
for (index = 0; index < a8_length; index ++) {
sum += a8[index];
}
return sum;
}
JNIEXPORT jlong JNICALL Java_CriticalNativeStress_sum1
(JNIEnv *env, jclass jclazz, jlongArray a) {
jlong sum = 0;
jsize len = (*env)->GetArrayLength(env, a);
jsize index;
jlong* arr = (jlong*)(*env)->GetPrimitiveArrayCritical(env, a, 0);
for (index = 0; index < len; index ++) {
sum += arr[index];
}
(*env)->ReleasePrimitiveArrayCritical(env, a, arr, 0);
return sum;
}
JNIEXPORT jlong JNICALL Java_CriticalNativeStress_sum2
(JNIEnv *env, jclass jclazz, jlong a1, jintArray a2, jintArray a3, jlongArray a4, jintArray a5) {
jlong sum = a1;
jsize index;
jsize len = (*env)->GetArrayLength(env, a2);
jint* a2_arr = (jint*)(*env)->GetPrimitiveArrayCritical(env, a2, 0);
for (index = 0; index < len; index ++) {
sum += a2_arr[index];
}
(*env)->ReleasePrimitiveArrayCritical(env, a2, a2_arr, 0);
len = (*env)->GetArrayLength(env, a3);
jint* a3_arr = (jint*)(*env)->GetPrimitiveArrayCritical(env, a3, 0);
for (index = 0; index < len; index ++) {
sum += a3_arr[index];
}
(*env)->ReleasePrimitiveArrayCritical(env, a3, a3_arr, 0);
len = (*env)->GetArrayLength(env, a4);
jlong* a4_arr = (jlong*)(*env)->GetPrimitiveArrayCritical(env, a4, 0);
for (index = 0; index < len; index ++) {
sum += a4_arr[index];
}
(*env)->ReleasePrimitiveArrayCritical(env, a4, a4_arr, 0);
len = (*env)->GetArrayLength(env, a5);
jint* a5_arr = (jint*)(*env)->GetPrimitiveArrayCritical(env, a5, 0);
for (index = 0; index < len; index ++) {
sum += a5_arr[index];
}
(*env)->ReleasePrimitiveArrayCritical(env, a5, a5_arr, 0);
return sum;
}
JNIEXPORT jboolean JNICALL JavaCritical_CriticalNativeArgs_isNull
(jint length, jint* a) {
return (a == NULL) && (length == 0);
}
JNIEXPORT jboolean JNICALL Java_CriticalNativeArgs_isNull
(JNIEnv *env, jclass jclazz, jintArray a) {
jboolean is_null;
jsize len = (*env)->GetArrayLength(env, a);
jint* arr = (jint*)(*env)->GetPrimitiveArrayCritical(env, a, 0);
is_null = (arr == NULL) && (len == 0);
(*env)->ReleasePrimitiveArrayCritical(env, a, arr, 0);
return is_null;
}