Skip to content

Commit

Permalink
Use thread factory for virtual threads
Browse files Browse the repository at this point in the history
This change makes ThreadPool use the supplied ThreadFactory also for virtual threads, which means they will have correct names etc
  • Loading branch information
cfredri4 authored and belaban committed Nov 20, 2024
1 parent d19f241 commit f23836f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 61 deletions.
50 changes: 19 additions & 31 deletions src/org/jgroups/util/ThreadCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,48 +31,37 @@ public static boolean hasVirtualThreads() {
}

public static Thread createThread(Runnable r, String name, boolean daemon, boolean virtual) {
if(!virtual || CREATE_VTHREAD == null) {
Thread t=new Thread(r, name);
Thread t=null;
if(virtual)
t=newVirtualThread(r);
if(t == null) {
t=new Thread(r);
t.setDaemon(daemon);
return t;
}

// Thread.ofVirtual().unstarted()
try {
Object of=OF_VIRTUAL.invoke();
Thread t=(Thread)CREATE_VTHREAD.invokeWithArguments(of, r);
t.setName(name);
return t;
}
catch(Throwable t) {
}
t.setName(name);
return t;
}

// Thread.newThread(String name, int characteristics, Runnable task) in JDKs 15 & 16
try {
return (Thread)CREATE_VTHREAD.invokeExact(name, 1, r);
protected static Thread newVirtualThread(Runnable r) {
if(CREATE_VTHREAD != null) {
// Thread.ofVirtual().unstarted()
try {
Object of=OF_VIRTUAL.invoke();
return (Thread)CREATE_VTHREAD.invokeWithArguments(of, r);
}
catch(Throwable t) {
}
}
catch(Throwable ex) {
}
return new Thread(r, name);
return null;
}


protected static MethodHandle getCreateVThreadHandle() {
MethodType type=MethodType.methodType(Thread.class, Runnable.class);
try {
return LOOKUP.findVirtual(OF_VIRTUAL_CLASS, "unstarted", type);
}
catch(Exception ex) {
LOG.debug("%s.unstarted() not found, trying Thread.newThread() (jdk 15/16)", OF_VIRTUAL_NAME);
}

// try Thread.newThread(String name, int characteristics, Runnable task) in JDKs 15 & 16
type=MethodType.methodType(Thread.class, String.class, int.class, Runnable.class);
try {
return LOOKUP.findStatic(Thread.class, "newThread", type);
}
catch(Exception ex) {
LOG.debug("Thread.newThread() not found, falling back to regular threads");
LOG.debug("%s.unstarted() not found, falling back to regular threads", OF_VIRTUAL_NAME);
}
return null;
}
Expand All @@ -96,5 +85,4 @@ protected static MethodHandle getOfVirtualHandle() {
return null;
}
}

}
55 changes: 25 additions & 30 deletions src/org/jgroups/util/ThreadPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* @since 5.2
*/
public class ThreadPool implements Lifecycle {
private static final MethodHandle EXECUTORS_NEW_VIRTUAL_THREAD_FACTORY=getNewVirtualThreadFactoryHandle();
private static final MethodHandle EXECUTORS_NEW_THREAD_PER_TASK_EXECUTOR=getNewThreadPerTaskExecutorHandle();
protected Executor thread_pool;
protected Log log;
protected ThreadFactory thread_factory;
Expand Down Expand Up @@ -192,7 +192,6 @@ public void resetStats() {
num_rejected_msgs.reset();
}


@Override
public void init() throws Exception {
if(log == null)
Expand Down Expand Up @@ -258,44 +257,40 @@ public String toString() {


protected static ExecutorService createThreadPool(int min_threads, int max_threads, long keep_alive_time,
String rejection_policy,
BlockingQueue<Runnable> queue, final ThreadFactory factory,
Log log) {
if(!factory.useVirtualThreads() || EXECUTORS_NEW_VIRTUAL_THREAD_FACTORY == null) {
ThreadPoolExecutor pool=new ThreadPoolExecutor(min_threads, max_threads, keep_alive_time,
TimeUnit.MILLISECONDS, queue, factory);
String rejection_policy, BlockingQueue<Runnable> queue,
final ThreadFactory factory, Log log) {
ExecutorService pool=null;
if(factory.useVirtualThreads())
pool=newVirtualThreadPool(factory);
if(pool == null) {
RejectedExecutionHandler handler=Util.parseRejectionPolicy(rejection_policy);
pool.setRejectedExecutionHandler(new ShutdownRejectedExecutionHandler(handler));
pool=new ThreadPoolExecutor(min_threads, max_threads, keep_alive_time,
TimeUnit.MILLISECONDS, queue, factory, handler);
if(log != null)
log.debug("thread pool min/max/keep-alive (ms): %d/%d/%d", min_threads, max_threads, keep_alive_time);
return pool;
}

try {
return (ExecutorService)EXECUTORS_NEW_VIRTUAL_THREAD_FACTORY.invokeExact();
}
catch(Throwable t) {
throw new IllegalStateException(String.format("failed to create virtual thread pool: %s", t));
}
return pool;
}

protected static MethodHandle getNewVirtualThreadFactoryHandle() {
MethodType type=MethodType.methodType(ExecutorService.class);
String[] names={
"newVirtualThreadPerTaskExecutor", // jdk 18-21
"newVirtualThreadExecutor", // jdk 17
"newUnboundedVirtualThreadExecutor" // jdk 15 & 16
};

MethodHandles.Lookup LOOKUP=MethodHandles.publicLookup();
for(int i=0; i < names.length; i++) {

protected static ExecutorService newVirtualThreadPool(final ThreadFactory factory) {
if(EXECUTORS_NEW_THREAD_PER_TASK_EXECUTOR != null) {
try {
return LOOKUP.findStatic(Executors.class, names[i], type);
return (ExecutorService)EXECUTORS_NEW_THREAD_PER_TASK_EXECUTOR.invokeExact((java.util.concurrent.ThreadFactory)factory);
}
catch(Exception e) {
catch(Throwable t) {
}
}
return null;
}

protected static MethodHandle getNewThreadPerTaskExecutorHandle() {
MethodHandles.Lookup LOOKUP=MethodHandles.publicLookup();
MethodType type=MethodType.methodType(ExecutorService.class, java.util.concurrent.ThreadFactory.class);
try {
return LOOKUP.findStatic(Executors.class, "newThreadPerTaskExecutor", type);
}
catch(Exception t) {
}
return null;
}
}

0 comments on commit f23836f

Please sign in to comment.