Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions .claude/skills/add-sdk-conformance-test/SKILL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
---
name: add-sdk-conformance-test
description: Add a new test to the SDK conformance test suite. Use when the user wants to add a new sdk test, conformance test, or test a new SDK feature across all SDK implementations.
user-invocable: true
---

# Adding a New SDK Conformance Test

The `sdk-tests` module is a **conformance tool** — it defines contracts that SDK implementations must satisfy and test runners that verify them. It contains NO implementation code.

## Architecture

- **Contracts** (`sdk-tests/src/main/kotlin/dev/restate/sdktesting/contracts/`) — Kotlin interfaces (`@Service`, `@VirtualObject`, `@Workflow`) that each SDK implements. These define the wire API (service name, handler names, JSON field names).
- **Tests** (`sdk-tests/src/main/kotlin/dev/restate/sdktesting/tests/`) — JUnit 5 test classes that drive contracts through the Restate ingress client.

**Never add implementation code to sdk-tests.** Only interfaces in contracts, only test logic in tests.

## Step 1: Expand the contracts (if needed)

Edit the relevant contract interface only if strictly needed. The main ones:

- `VirtualObjectCommandInterpreter` — interpreter for combinator/signal/awakeable tests; the workhorse for most feature tests
- `TestUtilsService` — utility handlers (cancel, signal resolve/reject, etc.)

**Contract rules:**
- Data classes → `@Serializable`; sealed hierarchies → `@SerialName("camelCase")` discriminator
- Handler inputs must be a single type — wrap multiple fields in a `@Serializable` data class
- `@Handler` for exclusive handlers, `@Shared` for shared handlers

### VirtualObjectCommandInterpreter — key types

**AwaitableCommand** (sub-operations that can be composed):
- `CreateAwakeable(awakeableKey)`, `CreateSignal(signalName)`, `Sleep(timeoutMillis)`, `RunReturns(value)`, `RunThrowTerminalException(reason)`

**Command** (top-level interpreter steps):
- `AwaitOne(command)` — await a single sub-operation
- `AwaitAny(commands)` — first to complete (race); throws if winner failed
- `AwaitAnySuccessful(commands)` — first successful or all failed (legacy)
- `AwaitFirstCompleted(commands)` — first to complete (race)
- `AwaitFirstSucceededOrAllFailed(commands)` — first success, or throws if all fail
- `AwaitAllSucceededOrFirstFailed(commands)` — all succeed → pipe-joined `"v0|v1"`; throws on first fail
- `AwaitAllCompleted(commands)` — all settle → pipe-joined `"ok:v0|err:reason|ok:v2"`

**TestUtilsService:**
- `resolveSignal(ResolveSignalRequest(invocationId, signalName, value))`
- `rejectSignal(RejectSignalRequest(invocationId, signalName, reason))`

## Step 2: Write the test

### Test class boilerplate

```kotlin
class MyFeature {
companion object {
@RegisterExtension
val deployerExt: RestateDeployerExtension = RestateDeployerExtension {
withServiceSpec(
ServiceSpec.defaultBuilder()
.withServices(VirtualObjectCommandInterpreter::class, TestUtilsService::class))
}
}

@Test
@DisplayName("Human-readable description")
@Execution(ExecutionMode.CONCURRENT)
fun myTest(@InjectClient ingressClient: Client) = runTest {
// ...
}
}
```

### Client patterns

```kotlin
// Build clients
val voClient = ingressClient.toVirtualObject<VirtualObjectCommandInterpreter>(UUID.randomUUID().toString())
val svcClient = ingressClient.toService<TestUtilsService>()

// Call and get result immediately
val result = voClient.request { myHandler(req) }.options(idempotentCallOptions).call().response

// Send (non-blocking) + attach later — REQUIRED for signals
val sendResponse = voClient.request { myHandler(req) }.options(idempotentCallOptions).send()
val invocationId = sendResponse.invocationId()
// ... resolve/reject signals ...
val result = sendResponse.attachSuspend().response

// Expect a terminal error
assertThat(runCatching { sendResponse.attachSuspend().response }.exceptionOrNull())
.message().contains("expected substring")

// Poll until condition (awakeables only — not needed for signals)
await withAlias "description" untilAsserted {
assertThat(voClient.request { hasAwakeable("key") }.call().response).isTrue()
}
```

### Awakeable vs Signal patterns

**Awakeables** — identified by a unique runtime ID stored in VirtualObject state:
- Send the `interpretCommands` call and poll `hasAwakeable(key)` before resolving
- Resolve/reject via `interpreterClient.request { resolveAwakeable(ResolveAwakeable(key, value)) }`

**Signals** — identified by invocation ID + name; no pre-registration needed:
- `.send()` → get `invocationId()` → resolve/reject via `TestUtilsService.resolveSignal/rejectSignal` → `attachSuspend()`
- No polling required — signals can be sent before or after the handler starts waiting

## Step 3: Verify it compiles

```bash
./gradlew :sdk-tests:compileKotlin
```

## Step 4: Run against a local SDK image

Build the SDK Docker image (example for TypeScript SDK):

```bash
# From the sdk-typescript repo root
podman build -t e2e-ts:local -f packages/tests/restate-e2e-services/Dockerfile .
```

Run just the new test class:

```bash
./gradlew :sdk-tests:run --args='run --sequential --image-pull-policy=CACHED --test-suite=default --test-name=MyFeature --service-container-image=localhost/e2e-ts:local'
```

## Step 5: Update SDK implementations

After adding a new contract or command type, you must update each SDK's test service implementation. The TypeScript SDK (the reference implementation) lives in `sdk-typescript`:
- Main: `packages/tests/restate-e2e-services/src/virtual_object_command_interpreter.ts` and `test_utils.ts`
- Gen: `packages/libs/restate-sdk-gen/test-services/src/vo-command-interpreter.ts` and `test-utils.ts`

Use the `update-sdk-test-contracts` skill in the sdk-typescript repo for guidance on the implementation patterns.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ services/node-services/restatedev-restate-sdk-*
test_report
.kotlin

.claude/
.claude/settings.local.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ abstract class BaseRestateDeployerExtension : ParameterResolver {
}

private fun resolveIngressClient(extensionContext: ExtensionContext): Client {
return Client.connect(resolveIngressURI(extensionContext).toString())
return LoggingClient(Client.connect(resolveIngressURI(extensionContext).toString()))
}

private fun resolveContainerAddress(
Expand Down
110 changes: 110 additions & 0 deletions infra/src/main/kotlin/dev/restate/sdktesting/infra/LoggingClient.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate SDK Test suite tool,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-test-suite/blob/main/LICENSE
package dev.restate.sdktesting.infra

import dev.restate.client.Client
import dev.restate.client.Response
import dev.restate.client.SendResponse
import dev.restate.common.Request
import dev.restate.serde.TypeTag
import java.time.Duration
import java.util.concurrent.CompletableFuture
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializerOrNull
import org.apache.logging.log4j.LogManager

internal class LoggingClient(private val delegate: Client) : Client {

companion object {
private val LOG = LogManager.getLogger(LoggingClient::class.java)
private val JSON = Json { prettyPrint = true }
}

override fun <Req, Res> callAsync(request: Request<Req, Res>): CompletableFuture<Response<Res>> {
LOG.info("→ CALL {}", formatRequest(request))
return delegate.callAsync(request).whenComplete { response, ex ->
if (ex != null) {
LOG.info("← CALL {} error: {}", request.target, ex.message)
} else {
LOG.info(
"← CALL {} status={} headers={} response={}",
request.target,
response.statusCode(),
response.headers().toLowercaseMap(),
formatBody(response.response()))
}
}
}

override fun <Req, Res> sendAsync(
request: Request<Req, Res>,
delay: Duration?
): CompletableFuture<SendResponse<Res>> {
LOG.info("→ SEND {}", formatRequest(request))
return delegate.sendAsync(request, delay).whenComplete { response, ex ->
if (ex != null) {
LOG.info("← SEND {} error: {}", request.target, ex.message)
} else {
LOG.info(
"← SEND {} status={} invocationId={} sendStatus={}",
request.target,
response.statusCode(),
response.invocationId(),
response.sendStatus())
}
}
}

// Delegate remaining abstract methods unchanged.

override fun awakeableHandle(id: String): Client.AwakeableHandle = delegate.awakeableHandle(id)

override fun <Res> invocationHandle(
invocationId: String,
resTypeTag: TypeTag<Res>
): Client.InvocationHandle<Res> = delegate.invocationHandle(invocationId, resTypeTag)

override fun <Res> idempotentInvocationHandle(
target: dev.restate.common.Target,
idempotencyKey: String,
resTypeTag: TypeTag<Res>
): Client.IdempotentInvocationHandle<Res> =
delegate.idempotentInvocationHandle(target, idempotencyKey, resTypeTag)

override fun <Res> workflowHandle(
workflowName: String,
workflowId: String,
resTypeTag: TypeTag<Res>
): Client.WorkflowHandle<Res> = delegate.workflowHandle(workflowName, workflowId, resTypeTag)

private fun formatRequest(request: Request<*, *>): String = buildString {
append(request.target)
if (request.idempotencyKey != null) append(" idempotency-key=${request.idempotencyKey}")
val headers = request.headers
if (!headers.isNullOrEmpty()) append(" headers=$headers")
append("\n payload: ")
append(formatBody(request.request))
}

@OptIn(ExperimentalSerializationApi::class)
@Suppress("UNCHECKED_CAST")
private fun formatBody(value: Any?): String {
if (value == null || value == Unit) return "(empty)"
if (value is ByteArray) return "[${value.size} bytes]"
return try {
val serializer =
serializerOrNull(value.javaClass) as? KSerializer<Any> ?: return value.toString()
JSON.encodeToString(serializer, value)
} catch (_: Exception) {
value.toString()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class TestSuite(

val testContainersLogger =
builder
.newLogger("org.testcontainers", Level.TRACE)
.newLogger("org.testcontainers", Level.DEBUG)
.add(builder.newAppenderRef("testRunnerLog"))
.add(builder.newAppenderRef("routingAppender"))
.addAttribute("additivity", false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,27 @@
package dev.restate.sdktesting.contracts

import dev.restate.sdk.annotation.*
import kotlinx.serialization.Serializable

/** Collection of various utilities/corner cases scenarios used by tests */
@Service
@Name("TestUtilsService")
interface TestUtilsService {

@Serializable
data class ResolveSignalRequest(
val invocationId: String,
val signalName: String,
val value: String
)

@Serializable
data class RejectSignalRequest(
val invocationId: String,
val signalName: String,
val reason: String
)

/** Just echo */
@Handler suspend fun echo(input: String): String

Expand All @@ -26,9 +42,6 @@ interface TestUtilsService {
/** Just echo */
@Handler @Raw suspend fun rawEcho(@Raw input: ByteArray): ByteArray

/** Create timers and await them all. Durations in milliseconds */
@Handler suspend fun sleepConcurrently(millisDuration: List<Long>)

/**
* Invoke `ctx.run` incrementing a local variable counter (not a restate state key!).
*
Expand All @@ -40,4 +53,10 @@ interface TestUtilsService {

/** Cancel invocation using the context. */
@Handler suspend fun cancelInvocation(invocationId: String)

/** Resolve a named signal on the given invocation. */
@Handler suspend fun resolveSignal(req: ResolveSignalRequest)

/** Reject a named signal on the given invocation. */
@Handler suspend fun rejectSignal(req: RejectSignalRequest)
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,23 @@ interface VirtualObjectCommandInterpreter {
// This is serialized as `{"type": "sleep", ...}`
@Serializable @SerialName("sleep") data class Sleep(val timeoutMillis: Long) : AwaitableCommand

// This is serialized as `{"type": "runReturns", ...}`
// When implementing this, make sure that the run executes some actual work, especially in async
// world (e.g. in ts, something as simple as setTimeout(1) is enough...)
@Serializable
@SerialName("runReturns")
data class RunReturns(val value: String) : AwaitableCommand

// This is serialized as `{"type": "runThrowTerminalException", ...}`
@Serializable
@SerialName("runThrowTerminalException")
data class RunThrowTerminalException(val reason: String) : AwaitableCommand

// This is serialized as `{"type": "createSignal", ...}`
@Serializable
@SerialName("createSignal")
data class CreateSignal(val signalName: String) : AwaitableCommand

@Serializable sealed interface Command

// Returns the index of the one that completed first successfully
Expand All @@ -54,6 +66,26 @@ interface VirtualObjectCommandInterpreter {
@SerialName("awaitAwakeableOrTimeout")
data class AwaitAwakeableOrTimeout(val awakeableKey: String, val timeoutMillis: Long) : Command

// Awaits all commands; returns first succeeded or throws if all failed
@Serializable
@SerialName("awaitFirstSucceededOrAllFailed")
data class AwaitFirstSucceededOrAllFailed(val commands: List<AwaitableCommand>) : Command

// Awaits all commands; returns first completed (throwing if it failed)
@Serializable
@SerialName("awaitFirstCompleted")
data class AwaitFirstCompleted(val commands: List<AwaitableCommand>) : Command

// Awaits all commands; returns pipe-joined values if all succeed, throws on first failure
@Serializable
@SerialName("awaitAllSucceededOrFirstFailed")
data class AwaitAllSucceededOrFirstFailed(val commands: List<AwaitableCommand>) : Command

// Awaits all commands; returns pipe-joined "ok:<value>" or "err:<reason>" for each result
@Serializable
@SerialName("awaitAllCompleted")
data class AwaitAllCompleted(val commands: List<AwaitableCommand>) : Command

@Serializable data class InterpretRequest(val commands: List<Command>)

@Serializable
Expand Down
Loading
Loading