223 lines
7.6 KiB
C++
Raw Normal View History

/*
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
#include <jni.h>
#include <jvmti.h>
#include <jvmti_common.hpp>
#include <atomic>
// JVMTI environment obtained in Agent_OnLoad(); used by all native methods below.
static jvmtiEnv *jvmti = nullptr;
// Number of failed checks recorded by verify_thread_state(); read back by getErrorCount().
static jint error_count = 0;
/*
 * Agent entry point. Obtains the JVMTI environment into the file-level
 * 'jvmti' pointer and requests the capabilities this test relies on:
 * virtual thread support, SuspendThread/ResumeThread and InterruptThread.
 * Returns JNI_OK on success, JNI_ERR if either step fails.
 */
extern "C" JNIEXPORT jint JNICALL
Agent_OnLoad(JavaVM *vm, char *options, void *reserved) {
  jint env_status = vm->GetEnv((void **)&jvmti, JVMTI_VERSION);
  if (env_status != JNI_OK) {
    LOG("Could not initialize JVMTI\n");
    return JNI_ERR;
  }

  // Zero the whole struct (including unnamed reserved bit-fields) before
  // setting the individual capability bits.
  jvmtiCapabilities wanted;
  memset(&wanted, 0, sizeof(wanted));
  wanted.can_support_virtual_threads = 1;
  wanted.can_suspend = 1;
  wanted.can_signal_thread = 1;

  jvmtiError add_status = jvmti->AddCapabilities(&wanted);
  if (add_status == JVMTI_ERROR_NONE) {
    return JNI_OK;
  }
  LOG("JVMTI AddCapabilities error: %d\n", add_status);
  return JNI_ERR;
}
/*
 * Tries up to max_retries times to catch 'vthread' suspended in a WAITING
 * state. Each attempt sleeps 100 ms, suspends the thread and inspects its
 * JVMTI state; if the WAITING bit is not set the thread is resumed and the
 * next attempt follows.
 * Returns JNI_TRUE with the thread left suspended on success, or JNI_FALSE
 * (thread resumed) when every attempt failed.
 */
extern "C" JNIEXPORT jboolean JNICALL
Java_GetThreadStateMountedTest_trySuspendInWaitingState(JNIEnv* jni, jclass clazz, jthread vthread) {
  const int max_retries = 10;
  int attempt = 0;
  while (attempt < max_retries) {
    sleep_ms(100); // pause before every attempt, including the first
    LOG("suspend vthread (%d)\n", attempt);
    suspend_thread(jvmti, jni, vthread);

    jint observed = get_thread_state(jvmti, jni, vthread);
    bool is_waiting = (observed & JVMTI_THREAD_STATE_WAITING) != 0;
    if (is_waiting) {
      LOG("suspended in WAITING state\n");
      return JNI_TRUE; // thread intentionally left suspended for the caller
    }

    LOG("suspended vthread is not waiting: state = %x (%s)\n", observed, TranslateState(observed));
    LOG("resume vthread\n");
    resume_thread(jvmti, jni, vthread);
    ++attempt;
  }
  LOG("ERROR: failed to suspend in WAITING state in %d tries\n", max_retries);
  return JNI_FALSE;
}
static void verify_thread_state(const char *name, JNIEnv* jni,
jthread thread, jint expected_strong, jint expected_weak)
{
jint state = get_thread_state(jvmti, jni, thread);
LOG("%s state(%x): %s\n", name, state, TranslateState(state));
bool failed = false;
// check 1: all expected_strong bits are set
jint actual_strong = state & expected_strong;
if (actual_strong != expected_strong) {
failed = true;
jint missed = expected_strong - actual_strong;
LOG(" ERROR: some mandatory bits are not set (%x): %s\n",
missed, TranslateState(missed));
}
// check 2: no bits other than (expected_strong | expected_weak) are set
jint actual_full = state & (expected_strong | expected_weak);
if (actual_full != state) {
failed = true;
jint unexpected = state - actual_full;
LOG(" ERROR: some unexpected bits are set (%x): %s\n",
unexpected, TranslateState(unexpected));
}
// check 3: expected_weak checks
if (expected_weak != 0) {
// check 3a: at least 1 bit from expected_weak is set
if ((state & expected_weak) == 0) {
failed = true;
LOG(" ERROR: no expected 'weak' bits are set\n");
}
// check 3b: not all expected_weak bits are set
if ((state & expected_weak) == expected_weak) {
failed = true;
LOG(" ERROR: all expected 'weak' bits are set\n");
}
}
if (failed) {
LOG(" expected 'strong' state (%x): %s\n", expected_strong, TranslateState(expected_strong));
LOG(" expected 'weak' state (%x): %s\n", expected_weak, TranslateState(expected_weak));
error_count++;
}
}
/*
 * Checks the JVMTI states of 'vthread' and its carrier thread while each of
 * them is suspended and resumed in turn. Optionally, it also checks the
 * states after InterruptThread is applied to the suspended virtual thread.
 *   vthread:              the virtual thread under test
 *   is_vthread_suspended: true if the caller has already suspended vthread
 *   test_interrupt:       also interrupt vthread while it is suspended
 *   expected_strong/weak: expected vthread state bits, interpreted as in
 *                         verify_thread_state(); JVMTI_THREAD_STATE_ALIVE is
 *                         added to expected_strong here
 * Mismatches are counted in error_count. vthread is resumed before returning.
 */
extern "C" JNIEXPORT void JNICALL
Java_GetThreadStateMountedTest_testThread(
    JNIEnv* jni, jclass clazz, jthread vthread, jboolean is_vthread_suspended,
    jboolean test_interrupt,
    jint expected_strong, jint expected_weak)
{
  // The carrier is expected to be alive and waiting indefinitely.
  jint exp_ct_state = JVMTI_THREAD_STATE_ALIVE
                      | JVMTI_THREAD_STATE_WAITING
                      | JVMTI_THREAD_STATE_WAITING_INDEFINITELY;
  jint exp_vt_state = expected_strong
                      | JVMTI_THREAD_STATE_ALIVE;

  jthread cthread = get_carrier_thread(jvmti, jni, vthread);
  if (cthread == nullptr) {
    // JVMTI treats a null jthread as "the current thread". Continuing would
    // make suspend_thread() suspend the calling thread itself, so report the
    // failure and bail out. A null carrier presumably means vthread is not mounted.
    LOG("ERROR: failed to get carrier thread of vthread\n");
    error_count++;
    if (is_vthread_suspended) {
      // Keep the contract that vthread is resumed on return.
      LOG("resume vthread\n");
      resume_thread(jvmti, jni, vthread);
    }
    return;
  }

  verify_thread_state("cthread", jni, cthread,
                      exp_ct_state, 0);
  verify_thread_state("vthread", jni, vthread,
                      exp_vt_state | (is_vthread_suspended ? JVMTI_THREAD_STATE_SUSPENDED : 0),
                      expected_weak);

  // suspend cthread and verify; the vthread's expected state is unchanged
  LOG("suspend cthread\n");
  suspend_thread(jvmti, jni, cthread);
  verify_thread_state("cthread", jni, cthread,
                      exp_ct_state | JVMTI_THREAD_STATE_SUSPENDED, 0);
  verify_thread_state("vthread", jni, vthread,
                      exp_vt_state | (is_vthread_suspended ? JVMTI_THREAD_STATE_SUSPENDED : 0),
                      expected_weak);

  // suspend vthread (unless the caller already did) and verify
  if (!is_vthread_suspended) {
    LOG("suspend vthread\n");
    suspend_thread(jvmti, jni, vthread);
    verify_thread_state("cthread", jni, cthread,
                        exp_ct_state | JVMTI_THREAD_STATE_SUSPENDED, 0);
    verify_thread_state("vthread", jni, vthread,
                        exp_vt_state | JVMTI_THREAD_STATE_SUSPENDED, expected_weak);
  }

  // resume cthread and verify; vthread must still be reported as suspended
  LOG("resume cthread\n");
  resume_thread(jvmti, jni, cthread);
  verify_thread_state("cthread", jni, cthread,
                      exp_ct_state, 0);
  verify_thread_state("vthread", jni, vthread,
                      exp_vt_state | JVMTI_THREAD_STATE_SUSPENDED, expected_weak);

  if (test_interrupt) {
    // interrupt vthread (while it's suspended)
    LOG("interrupt vthread\n");
    check_jvmti_status(jni, jvmti->InterruptThread(vthread), "error in JVMTI InterruptThread");
    verify_thread_state("cthread", jni, cthread,
                        exp_ct_state, 0);
    verify_thread_state("vthread", jni, vthread,
                        exp_vt_state | JVMTI_THREAD_STATE_SUSPENDED | JVMTI_THREAD_STATE_INTERRUPTED,
                        expected_weak);
  }

  // resume vthread
  LOG("resume vthread\n");
  resume_thread(jvmti, jni, vthread);
  // don't verify thread state after InterruptThread and ResumeThread
}
/*
 * Returns the number of failed state checks recorded by verify_thread_state().
 * Declared to return jint, the JNI type for a Java 'int' result, for
 * consistency with the other natives in this file. The Java side presumably
 * declares this method 'native int'.
 */
extern "C" JNIEXPORT jint JNICALL
Java_GetThreadStateMountedTest_getErrorCount(JNIEnv* jni, jclass clazz) {
  return error_count;
}
// Polled by waitInNative() and set to true by endWait() to end that wait loop.
static std::atomic<bool> time_to_exit(false);
/*
 * Makes an upcall from native code into the static Java method
 * runUpcall(Runnable) of 'clazz', passing 'runnable' through.
 * Calls FatalError if the method cannot be resolved.
 */
extern "C" JNIEXPORT void JNICALL
Java_GetThreadStateMountedTest_runFromNative(JNIEnv* jni, jclass clazz, jobject runnable) {
  jmethodID upcall = jni->GetStaticMethodID(clazz, "runUpcall", "(Ljava/lang/Runnable;)V");
  if (upcall != nullptr) {
    jni->CallStaticVoidMethod(clazz, upcall, runnable);
    return;
  }
  jni->FatalError("failed to get runUpcall method");
}
/*
 * Sets the static boolean field waitInNativeReady of 'clazz' to true, then
 * stays in native code, polling every 100 ms, until time_to_exit is set
 * (endWait() in this file sets it). Calls FatalError if the field cannot be found.
 */
extern "C" JNIEXPORT void JNICALL
Java_GetThreadStateMountedTest_waitInNative(JNIEnv* jni, jclass clazz) {
  jfieldID ready_field = jni->GetStaticFieldID(clazz, "waitInNativeReady", "Z");
  if (ready_field == nullptr) {
    jni->FatalError("cannot get waitInNativeReady field");
    return;
  }
  // Notify main thread that we are ready
  jni->SetStaticBooleanField(clazz, ready_field, JNI_TRUE);

  while (!time_to_exit.load()) {
    sleep_ms(100);
  }
}
/*
 * Sets time_to_exit, which ends the polling loop in waitInNative().
 */
extern "C" JNIEXPORT void JNICALL
Java_GetThreadStateMountedTest_endWait(JNIEnv* jni, jclass clazz) {
  time_to_exit.store(true);
}