diff --git a/src/main/java/com/autonomouslogic/commons/concurrent/NotVirtualThreadException.java b/src/main/java/com/autonomouslogic/commons/concurrent/NotVirtualThreadException.java new file mode 100644 index 0000000..dad1b7a --- /dev/null +++ b/src/main/java/com/autonomouslogic/commons/concurrent/NotVirtualThreadException.java @@ -0,0 +1,22 @@ +package com.autonomouslogic.commons.concurrent; + +/** + * Exception thrown when an operation requires a virtual thread but is executed on a platform thread. + */ +public class NotVirtualThreadException extends RuntimeException { + public NotVirtualThreadException() { + super("Current thread is not a virtual thread"); + } + + public NotVirtualThreadException(String message) { + super(message); + } + + public NotVirtualThreadException(String message, Throwable cause) { + super(message, cause); + } + + public NotVirtualThreadException(Throwable cause) { + super(cause); + } +} diff --git a/src/main/java/com/autonomouslogic/commons/concurrent/VirtualThreads.java b/src/main/java/com/autonomouslogic/commons/concurrent/VirtualThreads.java index 2ad60ac..6f2fa16 100644 --- a/src/main/java/com/autonomouslogic/commons/concurrent/VirtualThreads.java +++ b/src/main/java/com/autonomouslogic/commons/concurrent/VirtualThreads.java @@ -9,6 +9,7 @@ import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; @@ -262,6 +263,73 @@ public static void runAll(@NonNull Stream inputs, @NonNull Consumer ac runAll(inputs.map(input -> (Runnable) () -> action.accept(input)).iterator(), maxConcurrency); } + /** + * Checks if the current thread is a virtual thread. + * + * @return true if the current thread is a virtual thread, false otherwise + */ + public static boolean isVirtual() { + return Thread.currentThread().isVirtual(); + } + + /** + * Asserts that the current thread is a virtual thread. + * + * @throws NotVirtualThreadException if the current thread is not a virtual thread + */ + public static void checkIsVirtual() { + if (!isVirtual()) { + throw new NotVirtualThreadException(); + } + } + + /** + * Executes a task on a virtual thread if the current thread is not virtual. + * If the current thread is already virtual, the task is executed immediately. + * + * @param task the task to execute + * @throws InterruptedException if the current thread is interrupted while waiting for the task to complete + */ + public static void onVirtualThread(@NonNull Runnable task) throws InterruptedException { + if (isVirtual()) { + task.run(); + } else { + var thread = Thread.ofVirtual().start(task); + thread.join(); + } + } + + /** + * Executes a callable task on a virtual thread if the current thread is not virtual. + * If the current thread is already virtual, the task is executed immediately and the result is returned. + * + * @param the type of the result + * @param task the callable task to execute + * @return the result of the task + * @throws InterruptedException if the current thread is interrupted while waiting for the task to complete + * @throws Exception if the task throws an exception + */ + public static T onVirtualThread(@NonNull Callable task) throws InterruptedException, Exception { + if (isVirtual()) { + return task.call(); + } else { + var result = new AtomicReference(); + var exception = new AtomicReference(); + var thread = Thread.ofVirtual().start(() -> { + try { + result.set(task.call()); + } catch (Exception e) { + exception.set(e); + } + }); + thread.join(); + if (exception.get() != null) { + throw exception.get(); + } + return result.get(); + } + } + private static final class Result { private final int index; private final T value; diff --git a/src/test/java/com/autonomouslogic/commons/concurrent/VirtualThreadsTest.java b/src/test/java/com/autonomouslogic/commons/concurrent/VirtualThreadsTest.java index 3da6ac9..c418e49 100644 --- a/src/test/java/com/autonomouslogic/commons/concurrent/VirtualThreadsTest.java +++ b/src/test/java/com/autonomouslogic/commons/concurrent/VirtualThreadsTest.java @@ -450,4 +450,115 @@ void shouldProcessInputsUsingIterable() throws Exception { assertEquals(5, processed.get()); } } + + @Nested + class ThreadTypeTests { + @Test + void isVirtualShouldReturnFalseForPlatformThread() { + assertTrue(!VirtualThreads.isVirtual()); + } + + @Test + void isVirtualShouldReturnTrueForVirtualThread() throws Exception { + var result = new AtomicInteger(); + var virtualThread = Thread.ofVirtual().start(() -> { + result.set(VirtualThreads.isVirtual() ? 1 : 0); + }); + + virtualThread.join(); + + assertEquals(1, result.get()); + } + + @Test + void checkIsVirtualShouldThrowForPlatformThread() { + assertThrows(NotVirtualThreadException.class, VirtualThreads::checkIsVirtual); + } + + @Test + void checkIsVirtualShouldNotThrowForVirtualThread() throws Exception { + var exceptionThrown = new AtomicInteger(); + var virtualThread = Thread.ofVirtual().start(() -> { + try { + VirtualThreads.checkIsVirtual(); + } catch (NotVirtualThreadException e) { + exceptionThrown.set(1); + } + }); + + virtualThread.join(); + + assertEquals(0, exceptionThrown.get()); + } + } + + @Nested + class OnVirtualThreadTests { + @Test + void runnableShouldExecuteImmediatelyOnVirtualThread() throws Exception { + var executed = new AtomicInteger(); + var virtualThread = Thread.ofVirtual().start(() -> { + try { + VirtualThreads.onVirtualThread(() -> executed.set(1)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + virtualThread.join(); + + assertEquals(1, executed.get()); + } + + @Test + void runnableShouldCreateNewVirtualThreadFromPlatformThread() throws Exception { + var executed = new AtomicInteger(); + var threadId = new AtomicInteger(); + + VirtualThreads.onVirtualThread(() -> { + executed.set(1); + threadId.set((int) Thread.currentThread().threadId()); + }); + + assertEquals(1, executed.get()); + assertTrue(Thread.currentThread().threadId() != threadId.get()); + } + + @Test + void callableShouldReturnResultOnVirtualThread() throws Exception { + var result = new AtomicInteger(); + var virtualThread = Thread.ofVirtual().start(() -> { + try { + var value = VirtualThreads.onVirtualThread(() -> 42); + result.set(value); + } catch (Exception e) { + Thread.currentThread().interrupt(); + } + }); + + virtualThread.join(); + + assertEquals(42, result.get()); + } + + @Test + void callableShouldReturnResultFromCreatedVirtualThread() throws Exception { + var result = VirtualThreads.onVirtualThread(() -> 123); + + assertEquals(123, result); + } + + @Test + void callableShouldPropagateException() throws Exception { + var failureMessage = "Test failure"; + + var exception = assertThrows( + RuntimeException.class, + () -> VirtualThreads.onVirtualThread(() -> { + throw new RuntimeException(failureMessage); + })); + + assertEquals(failureMessage, exception.getMessage()); + } + } }