Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion packages/libs/restate-sdk/src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,9 @@ export type RestatePromise<T> = Promise<T> & {
* If this mapper returns a value, this value will be used to resolve the returned {@link RestatePromise}.
* If the mapper throws a {@link TerminalError}, this error will be used to reject the returned {@link RestatePromise}.
*/
map<U>(mapper: (value?: T, failure?: TerminalError) => U): RestatePromise<U>;
map<U>(
mapper: (value?: T, failure?: TerminalError) => U | Promise<U>
): RestatePromise<U>;
};

/**
Expand Down
51 changes: 34 additions & 17 deletions packages/libs/restate-sdk/src/promises.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ export abstract class InternalRestatePromise<T> implements RestatePromise<T> {
abstract finally(onfinally: (() => void) | undefined | null): Promise<T>;

abstract map<U>(
mapper: (value?: T, failure?: TerminalError) => U
mapper: (value?: T, failure?: TerminalError) => U | Promise<U>
): RestatePromise<U>;
abstract orTimeout(millis: Duration | number): RestatePromise<T>;

Expand Down Expand Up @@ -170,7 +170,9 @@ abstract class BaseRestatePromise<T> extends InternalRestatePromise<T> {
) as RestatePromise<T>;
}

map<U>(mapper: (value?: T, failure?: TerminalError) => U): RestatePromise<U> {
map<U>(
mapper: (value?: T, failure?: TerminalError) => U | Promise<U>
): RestatePromise<U> {
return new MappedRestatePromise(this[RESTATE_CTX_SYMBOL], this, mapper);
}

Expand Down Expand Up @@ -338,15 +340,15 @@ export class MappedRestatePromise<T, U> extends BaseRestatePromise<U> {
constructor(
ctx: ContextImpl,
readonly inner: InternalRestatePromise<T>,
mapper: (value?: T, failure?: TerminalError) => U
mapper: (value?: T, failure?: TerminalError) => U | Promise<U>
) {
super(ctx);
this.publicPromiseMapper = (value?: T, failure?: TerminalError) => {
this.publicPromiseMapper = async (value?: T, failure?: TerminalError) => {
try {
return Promise.resolve(mapper(value, failure));
return await mapper(value, failure);
} catch (e) {
if (e instanceof TerminalError) {
return Promise.reject(e);
throw e;
} else {
ctx.abortAttempt(e);
return pendingPromise();
Expand Down Expand Up @@ -382,30 +384,40 @@ export class MappedRestatePromise<T, U> extends BaseRestatePromise<U> {
}

export class ConstRestatePromise<T> extends InternalRestatePromise<T> {
private _constPromise?: Promise<T>;

private constructor(
private readonly constPromise: Promise<T>,
// Factory for the underlying promise. Called at most once, memoized in
// `_constPromise`. This lets `map` be lazy: the mapper is only invoked
// when someone actually awaits the result (via then/catch/finally/publicPromise),
// matching the contract documented on RestatePromise.map.
private readonly promiseFactory: () => Promise<T>,
private readonly settled: boolean
) {
super();
}

private get constPromise(): Promise<T> {
return (this._constPromise ??= this.promiseFactory());
}

static resolve<T>(value: T): ConstRestatePromise<Awaited<T>> {
return new ConstRestatePromise(Promise.resolve(value), true);
return new ConstRestatePromise(() => Promise.resolve(value), true);
}

static reject<T = never>(reason: TerminalError): ConstRestatePromise<T> {
return new ConstRestatePromise<T>(Promise.reject(reason), true);
return new ConstRestatePromise<T>(() => Promise.reject(reason), true);
}

static pending<T>(): ConstRestatePromise<T> {
return new ConstRestatePromise<T>(pendingPromise(), false);
return new ConstRestatePromise<T>(() => pendingPromise<T>(), false);
}

static fromPromise<T>(
promise: Promise<T>,
settled: boolean
): ConstRestatePromise<T> {
return new ConstRestatePromise(promise, settled);
return new ConstRestatePromise(() => promise, settled);
}

// --- Promise methods
Expand Down Expand Up @@ -434,12 +446,17 @@ export class ConstRestatePromise<T> extends InternalRestatePromise<T> {
return ConstRestatePromise.reject(new TimeoutError());
}

map<U>(mapper: (value?: T, failure?: TerminalError) => U): RestatePromise<U> {
return ConstRestatePromise.fromPromise(
this.constPromise.then(
(value) => mapper(value, undefined),
(reason) => mapper(undefined, reason as TerminalError)
),
map<U>(
mapper: (value?: T, failure?: TerminalError) => U | Promise<U>
): RestatePromise<U> {
if (!this.settled) return this as unknown as RestatePromise<U>;
const selfConstPromise = this.constPromise;
return new ConstRestatePromise<U>(
() =>
selfConstPromise.then(
(value) => mapper(value, undefined),
(reason) => mapper(undefined, reason as TerminalError)
),
this.settled
);
}
Expand Down
90 changes: 90 additions & 0 deletions packages/tests/restate-e2e-services/src/promise_combinators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import * as restate from "@restatedev/restate-sdk";
import { REGISTRY } from "./services.js";
import { setTimeout } from "node:timers/promises";

const promiseCombinators = restate.service({
name: "PromiseCombinators",
Expand Down Expand Up @@ -158,6 +159,95 @@ const promiseCombinators = restate.service({
return "unexpected";
});
},

// --- Async map on ConstRestatePromise ---

resolveAsyncMap: async (
_ctx: restate.Context,
value: string
): Promise<string> => {
// async mapper on a resolved const promise
return RestatePromise.resolve(value).map(async (v) => {
return `mapped:${v ?? ""}`;
});
},

rejectAsyncMapRecover: async (
_ctx: restate.Context,
message: string
): Promise<string> => {
// async mapper recovers from a rejected const promise
return RestatePromise.reject<string>(
new restate.TerminalError(message)
).map(async (_v, err) => {
return `recovered:${err?.message ?? ""}`;
});
},

resolveAsyncMapChained: async (
_ctx: restate.Context,
value: string
): Promise<string> => {
// chained async maps on a resolved const promise
return RestatePromise.resolve(value)
.map(async (v) => `${v ?? ""}-a`)
.map(async (v) => `${v ?? ""}-b`)
.map(async (v) => `${v ?? ""}-c`);
},

resolveAsyncMapWithCtxRun: async (
ctx: restate.Context,
value: string
): Promise<string> => {
// async mapper that performs a ctx.run inside — verifies determinism:
// the ctx.run must be journaled exactly once across replays even though
// the mapper is a microtask-deferred async closure.
return RestatePromise.resolve(value).map(async (v) => {
const suffix = await ctx.run("append", () => "ran");
return `${v ?? ""}-${suffix}`;
});
},

resolveAsyncMapThrows: async (
_ctx: restate.Context,
input: { value: string; errorMessage: string }
): Promise<string> => {
// async mapper throws TerminalError — must propagate as rejection
return RestatePromise.resolve(input.value).map(async () => {
throw new restate.TerminalError(input.errorMessage);
});
},

resolveAsyncMapOrTimeout: async (
_ctx: restate.Context,
value: string
): Promise<string> => {
// resolve().map(async).orTimeout() — mapped promise inherits settled=true,
// so orTimeout returns `this` and the async mapper still runs to completion.
return RestatePromise.resolve(value)
.map(async (v) => `mapped:${v ?? ""}`)
.orTimeout(1);
},

allSettledAsyncMapWithCtxRun: async (
ctx: restate.Context,
values: string[]
): Promise<string[]> => {
// Build N const RestatePromises, each with an async mapper that calls ctx.run,
// then await them together via RestatePromise.allSettled.
// Verifies: (a) mappers fire lazily (only when allSettled consumes them),
// (b) each ctx.run is journaled deterministically, (c) results come back in order.
const promises = values.map((v, i) =>
RestatePromise.resolve(v).map(async (inner) => {
const suffix = await ctx.run(`run-${i}`, async () => {
await setTimeout(Math.random() * 1000);
return `ran-${i}`;
});
return `${inner ?? ""}:${suffix}`;
})
);
return RestatePromise.all(promises);
},
},
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,54 @@ describe("PromiseCombinators", () => {
const result = await client.raceEmptyOrTimeoutMapped();
expect(result).toBe("timeout");
});

// --- Async map on ConstRestatePromise ---

it("resolve().map(async) returns the mapped value", async () => {
const result = await client.resolveAsyncMap("hello");
expect(result).toBe("mapped:hello");
});

it("reject().map(async) can recover from rejection", async () => {
const result = await client.rejectAsyncMapRecover("boom");
expect(result).toBe("recovered:boom");
});

it("resolve().map(async).map(async).map(async) chains correctly", async () => {
const result = await client.resolveAsyncMapChained("start");
expect(result).toBe("start-a-b-c");
});

it("resolve().map(async) can perform ctx.run inside the mapper", async () => {
const result = await client.resolveAsyncMapWithCtxRun("val");
expect(result).toBe("val-ran");
});

it("resolve().map(async) propagates TerminalError thrown in mapper", async () => {
await expect(
client.resolveAsyncMapThrows({ value: "x", errorMessage: "mapper fail" })
).rejects.toThrow("mapper fail");
});

it("resolve().map(async).orTimeout() returns the mapped value", async () => {
const result = await client.resolveAsyncMapOrTimeout("hello");
expect(result).toBe("mapped:hello");
});

it("allSettled over many resolve().map(async ctx.run) preserves order and journals each run", async () => {
const result = await client.allSettledAsyncMapWithCtxRun([
"a",
"b",
"c",
"d",
"e",
]);
expect(result).toEqual([
"a:ran-0",
"b:ran-1",
"c:ran-2",
"d:ran-3",
"e:ran-4",
]);
});
});
Loading