Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/probabilistic-transition-kernels.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@hashintel/petrinaut": minor
---

Add probability distribution support to transition kernels (`Distribution.Gaussian`, `Distribution.Uniform`)
Original file line number Diff line number Diff line change
@@ -1,236 +1,13 @@
import ts from "typescript";

import type { SDCPN } from "../../core/types/sdcpn";
import {
createLanguageServiceHost,
type VirtualFile,
} from "./create-language-service-host";
import { getItemFilePath } from "./file-paths";
import { createLanguageServiceHost } from "./create-language-service-host";
import { generateVirtualFiles } from "./generate-virtual-files";

export type SDCPNLanguageService = ts.LanguageService & {
updateFileContent: (fileName: string, content: string) => void;
};

/**
* Sanitizes a color ID to be a valid TypeScript identifier.
* Removes all characters that are not valid suffixes for TypeScript identifiers
* (keeps only letters, digits, and underscores).
*/
function sanitizeColorId(colorId: string): string {
return colorId.replace(/[^a-zA-Z0-9_]/g, "");
}

/**
* Maps SDCPN element types to TypeScript types
*/
function toTsType(type: "real" | "integer" | "boolean"): string {
return type === "boolean" ? "boolean" : "number";
}

/**
* Generates virtual files for all SDCPN entities
*/
function generateVirtualFiles(sdcpn: SDCPN): Map<string, VirtualFile> {
const files = new Map<string, VirtualFile>();

// Build lookup maps for places and types
const placeById = new Map(sdcpn.places.map((place) => [place.id, place]));
const colorById = new Map(sdcpn.types.map((color) => [color.id, color]));

// Generate parameters type definition
const parametersProperties = sdcpn.parameters
.map((param) => ` "${param.variableName}": ${toTsType(param.type)};`)
.join("\n");

files.set(getItemFilePath("parameters-defs"), {
content: `export type Parameters = {\n${parametersProperties}\n};`,
});

// Generate type definitions for each color
for (const color of sdcpn.types) {
const sanitizedColorId = sanitizeColorId(color.id);
const properties = color.elements
.map((el) => ` ${el.name}: ${toTsType(el.type)};`)
.join("\n");

files.set(getItemFilePath("color-defs", { colorId: color.id }), {
content: `export type Color_${sanitizedColorId} = {\n${properties}\n}`,
});
}

// Generate files for each differential equation
for (const de of sdcpn.differentialEquations) {
const sanitizedColorId = sanitizeColorId(de.colorId);
const deDefsPath = getItemFilePath("differential-equation-defs", {
id: de.id,
});
const deCodePath = getItemFilePath("differential-equation-code", {
id: de.id,
});
const parametersDefsPath = getItemFilePath("parameters-defs");
const colorDefsPath = getItemFilePath("color-defs", {
colorId: de.colorId,
});

// Type definitions file
files.set(deDefsPath, {
content: [
`import type { Parameters } from "${parametersDefsPath}";`,
`import type { Color_${sanitizedColorId} } from "${colorDefsPath}";`,
``,
`type Tokens = Array<Color_${sanitizedColorId}>;`,
`export type Dynamics = (fn: (tokens: Tokens, parameters: Parameters) => Tokens) => void;`,
].join("\n"),
});

// User code file with injected declarations
files.set(deCodePath, {
prefix: [
`import type { Dynamics } from "${deDefsPath}";`,
// TODO: Directly wrap user code in Dynamics call to remove need for user to write it.
`declare const Dynamics: Dynamics;`,
"",
].join("\n"),
content: de.code,
});
}

// Generate files for each transition
for (const transition of sdcpn.transitions) {
const parametersDefsPath = getItemFilePath("parameters-defs");
const lambdaDefsPath = getItemFilePath("transition-lambda-defs", {
transitionId: transition.id,
});
const lambdaCodePath = getItemFilePath("transition-lambda-code", {
transitionId: transition.id,
});
const kernelDefsPath = getItemFilePath("transition-kernel-defs", {
transitionId: transition.id,
});
const kernelCodePath = getItemFilePath("transition-kernel-code", {
transitionId: transition.id,
});

// Build input type: { [placeName]: [Token, Token, ...] } based on input arcs
const inputTypeImports: string[] = [];
const inputTypeProperties: string[] = [];

for (const arc of transition.inputArcs) {
const place = placeById.get(arc.placeId);
if (!place?.colorId) {
continue;
}
const color = colorById.get(place.colorId);
if (!color) {
continue;
}

const sanitizedColorId = sanitizeColorId(color.id);
const colorDefsPath = getItemFilePath("color-defs", {
colorId: color.id,
});
// Only add import if not already present (multiple arcs may share the same color)
const importStatement = `import type { Color_${sanitizedColorId} } from "${colorDefsPath}";`;
if (!inputTypeImports.includes(importStatement)) {
inputTypeImports.push(importStatement);
}
const tokenTuple = Array.from({ length: arc.weight })
.fill(`Color_${sanitizedColorId}`)
.join(", ");
inputTypeProperties.push(` "${place.name}": [${tokenTuple}];`);
}

// Build output type: { [placeName]: [Token, Token, ...] } based on output arcs
const outputTypeImports: string[] = [];
const outputTypeProperties: string[] = [];

for (const arc of transition.outputArcs) {
const place = placeById.get(arc.placeId);
if (!place?.colorId) {
continue;
}
const color = colorById.get(place.colorId);
if (!color) {
continue;
}

const sanitizedColorId = sanitizeColorId(color.id);
const colorDefsPath = getItemFilePath("color-defs", {
colorId: color.id,
});
// Only add import if not already present from input arcs or previous output arcs
const importStatement = `import type { Color_${sanitizedColorId} } from "${colorDefsPath}";`;
if (
!inputTypeImports.includes(importStatement) &&
!outputTypeImports.includes(importStatement)
) {
outputTypeImports.push(importStatement);
}
const tokenTuple = Array.from({ length: arc.weight })
.fill(`Color_${sanitizedColorId}`)
.join(", ");
outputTypeProperties.push(` "${place.name}": [${tokenTuple}];`);
}

const allImports = [...inputTypeImports, ...outputTypeImports];
const inputType =
inputTypeProperties.length > 0
? `{\n${inputTypeProperties.join("\n")}\n}`
: "Record<string, never>";
const outputType =
outputTypeProperties.length > 0
? `{\n${outputTypeProperties.join("\n")}\n}`
: "Record<string, never>";
const lambdaReturnType =
transition.lambdaType === "predicate" ? "boolean" : "number";

// Lambda definitions file
files.set(lambdaDefsPath, {
content: [
`import type { Parameters } from "${parametersDefsPath}";`,
...allImports,
``,
`export type Input = ${inputType};`,
`export type Lambda = (fn: (input: Input, parameters: Parameters) => ${lambdaReturnType}) => void;`,
].join("\n"),
});

// Lambda code file
files.set(lambdaCodePath, {
prefix: [
`import type { Lambda } from "${lambdaDefsPath}";`,
`declare const Lambda: Lambda;`,
"",
].join("\n"),
content: transition.lambdaCode,
});

// TransitionKernel definitions file
files.set(kernelDefsPath, {
content: [
`import type { Parameters } from "${parametersDefsPath}";`,
...allImports,
``,
`export type Input = ${inputType};`,
`export type Output = ${outputType};`,
`export type TransitionKernel = (fn: (input: Input, parameters: Parameters) => Output) => void;`,
].join("\n"),
});

// TransitionKernel code file
files.set(kernelCodePath, {
prefix: [
`import type { TransitionKernel } from "${kernelDefsPath}";`,
`declare const TransitionKernel: TransitionKernel;`,
"",
].join("\n"),
content: transition.transitionKernelCode,
});
}

return files;
}

/**
* Adjusts diagnostic positions to account for injected prefix
*/
Expand Down
5 changes: 5 additions & 0 deletions libs/@hashintel/petrinaut/src/checker/lib/file-paths.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

export type SDCPNFileType =
| "sdcpn-lib-defs"
| "parameters-defs"
| "color-defs"
| "differential-equation-defs"
Expand All @@ -14,6 +15,7 @@ export type SDCPNFileType =
| "transition-kernel-code";

type FilePathParams = {
"sdcpn-lib-defs": Record<string, never>;
"parameters-defs": Record<string, never>;
"color-defs": { colorId: string };
"differential-equation-defs": { id: string };
Expand All @@ -40,6 +42,9 @@ export const getItemFilePath = <T extends SDCPNFileType>(
const params = args[0];

switch (fileType) {
case "sdcpn-lib-defs":
return "/sdcpn-lib.d.ts";

case "parameters-defs":
return "/parameters/defs.d.ts";

Expand Down
Loading
Loading