Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.autonomouslogic.commons.concurrent;

/**
* Exception thrown when an operation requires a virtual thread but is executed on a platform thread.
*/
public class NotVirtualThreadException extends RuntimeException {
public NotVirtualThreadException() {
super("Current thread is not a virtual thread");
}

public NotVirtualThreadException(String message) {
super(message);
}

public NotVirtualThreadException(String message, Throwable cause) {
super(message, cause);
}

public NotVirtualThreadException(Throwable cause) {
super(cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
Expand Down Expand Up @@ -262,6 +263,73 @@ public static <T> void runAll(@NonNull Stream<T> inputs, @NonNull Consumer<T> ac
runAll(inputs.map(input -> (Runnable) () -> action.accept(input)).iterator(), maxConcurrency);
}

/**
* Checks if the current thread is a virtual thread.
*
* @return true if the current thread is a virtual thread, false otherwise
*/
public static boolean isVirtual() {
return Thread.currentThread().isVirtual();
}

/**
* Asserts that the current thread is a virtual thread.
*
* @throws NotVirtualThreadException if the current thread is not a virtual thread
*/
public static void checkIsVirtual() {
if (!isVirtual()) {
throw new NotVirtualThreadException();
}
}

/**
* Executes a task on a virtual thread if the current thread is not virtual.
* If the current thread is already virtual, the task is executed immediately.
*
* @param task the task to execute
* @throws InterruptedException if the current thread is interrupted while waiting for the task to complete
*/
public static void onVirtualThread(@NonNull Runnable task) throws InterruptedException {
if (isVirtual()) {
task.run();
} else {
var thread = Thread.ofVirtual().start(task);
thread.join();
}
}

/**
* Executes a callable task on a virtual thread if the current thread is not virtual.
* If the current thread is already virtual, the task is executed immediately and the result is returned.
*
* @param <T> the type of the result
* @param task the callable task to execute
* @return the result of the task
* @throws InterruptedException if the current thread is interrupted while waiting for the task to complete
* @throws Exception if the task throws an exception
*/
public static <T> T onVirtualThread(@NonNull Callable<T> task) throws InterruptedException, Exception {
if (isVirtual()) {
return task.call();
} else {
var result = new AtomicReference<T>();
var exception = new AtomicReference<Exception>();
var thread = Thread.ofVirtual().start(() -> {
try {
result.set(task.call());
} catch (Exception e) {
exception.set(e);
}
});
thread.join();
if (exception.get() != null) {
throw exception.get();
}
return result.get();
}
}

private static final class Result<T> {
private final int index;
private final T value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,4 +450,115 @@ void shouldProcessInputsUsingIterable() throws Exception {
assertEquals(5, processed.get());
}
}

@Nested
class ThreadTypeTests {
@Test
void isVirtualShouldReturnFalseForPlatformThread() {
assertTrue(!VirtualThreads.isVirtual());
}

@Test
void isVirtualShouldReturnTrueForVirtualThread() throws Exception {
var result = new AtomicInteger();
var virtualThread = Thread.ofVirtual().start(() -> {
result.set(VirtualThreads.isVirtual() ? 1 : 0);
});

virtualThread.join();

assertEquals(1, result.get());
}

@Test
void checkIsVirtualShouldThrowForPlatformThread() {
assertThrows(NotVirtualThreadException.class, VirtualThreads::checkIsVirtual);
}

@Test
void checkIsVirtualShouldNotThrowForVirtualThread() throws Exception {
var exceptionThrown = new AtomicInteger();
var virtualThread = Thread.ofVirtual().start(() -> {
try {
VirtualThreads.checkIsVirtual();
} catch (NotVirtualThreadException e) {
exceptionThrown.set(1);
}
});

virtualThread.join();

assertEquals(0, exceptionThrown.get());
}
}

@Nested
class OnVirtualThreadTests {
@Test
void runnableShouldExecuteImmediatelyOnVirtualThread() throws Exception {
var executed = new AtomicInteger();
var virtualThread = Thread.ofVirtual().start(() -> {
try {
VirtualThreads.onVirtualThread(() -> executed.set(1));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});

virtualThread.join();

assertEquals(1, executed.get());
}

@Test
void runnableShouldCreateNewVirtualThreadFromPlatformThread() throws Exception {
var executed = new AtomicInteger();
var threadId = new AtomicInteger();

VirtualThreads.onVirtualThread(() -> {
executed.set(1);
threadId.set((int) Thread.currentThread().threadId());
});

assertEquals(1, executed.get());
assertTrue(Thread.currentThread().threadId() != threadId.get());
}

@Test
void callableShouldReturnResultOnVirtualThread() throws Exception {
var result = new AtomicInteger();
var virtualThread = Thread.ofVirtual().start(() -> {
try {
var value = VirtualThreads.onVirtualThread(() -> 42);
result.set(value);
} catch (Exception e) {
Thread.currentThread().interrupt();
}
});

virtualThread.join();

assertEquals(42, result.get());
}

@Test
void callableShouldReturnResultFromCreatedVirtualThread() throws Exception {
var result = VirtualThreads.onVirtualThread(() -> 123);

assertEquals(123, result);
}

@Test
void callableShouldPropagateException() throws Exception {
var failureMessage = "Test failure";

var exception = assertThrows(
RuntimeException.class,
() -> VirtualThreads.onVirtualThread(() -> {
throw new RuntimeException(failureMessage);
}));

assertEquals(failureMessage, exception.getMessage());
}
}
}
Loading