diff --git a/docs/.vitepress/config.ts b/docs/.vitepress/config.ts index 78a4ca1..6927616 100644 --- a/docs/.vitepress/config.ts +++ b/docs/.vitepress/config.ts @@ -36,6 +36,7 @@ const docsSidebar = [ { text: 'List Query', link: '/guide/list-query' }, { text: 'Lambda Deployment', link: '/guide/lambda' }, { text: 'Testing', link: '/guide/testing' }, + { text: 'Row Level Security', link: '/guide/rls' }, ], }, { diff --git a/docs/guide/rls.md b/docs/guide/rls.md new file mode 100644 index 0000000..3ff3479 --- /dev/null +++ b/docs/guide/rls.md @@ -0,0 +1,155 @@ +--- +title: Row Level Security +--- + +# Row Level Security (RLS) + +Glasswork ships first-class helpers for PostgreSQL Row Level Security using Prisma Client Extensions. This guide shows how to scope every request to a tenant, provide an admin escape hatch, and test RLS behavior without extra boilerplate. + +## What you get + +- Per-request Prisma clients that set `SET LOCAL` session variables before each query. +- Hono middleware to place tenant context on the request. +- Awilix provider to resolve a scoped Prisma client (`tenantPrisma` by default). +- Admin/bypass client for system operations. +- Testing helpers to seed data per tenant and run code with scoped clients. + +## Defaults + +- Session variables: + - `app.tenant_id` + - `app.user_id` + - `app.user_role` + - `app.bypass_rls` (used by the admin client and `seedTenant`) +- Transaction wrapping is on by default so `SET LOCAL` stays scoped to the query. + +## Database setup (PostgreSQL) + +Enable RLS and add policies that read the session variables: + +```sql +ALTER TABLE projects ENABLE ROW LEVEL SECURITY; +ALTER TABLE projects FORCE ROW LEVEL SECURITY; + +CREATE POLICY tenant_isolation ON projects + FOR ALL + USING (tenant_id = current_setting('app.tenant_id', true)::uuid); + +CREATE POLICY admin_delete ON projects + FOR DELETE + USING ( + tenant_id = current_setting('app.tenant_id', true)::uuid + AND current_setting('app.user_role', true) = 'admin' + ); +``` + +Prisma schema convention (example): + +```prisma +model Project { + id String @id @default(cuid()) + name String + tenantId String + tenant Tenant @relation(fields: [tenantId], references: [id]) + createdBy String + + @@index([tenantId]) +} +``` + +## Glasswork integration + +1) **Add middleware** to extract tenant info (uses `c.get('auth')` by default): + +```ts +import { rlsMiddleware } from 'glasswork'; + +app.use(rlsMiddleware()); +``` + +2) **Register the scoped Prisma provider** in your module: + +```ts +import { createRLSProvider } from 'glasswork'; + +export const AppModule = defineModule({ + name: 'app', + providers: [ + PrismaService, // base client on prismaService.client + createRLSProvider(), // exposes tenantPrisma (scoped) + ProjectService, + ], +}); +``` + +3) **Inject the scoped client** in services: + +```ts +class ProjectService { + constructor(private readonly tenantPrisma: PrismaClient) {} + + async findAll() { + return this.tenantPrisma.project.findMany(); + } +} +``` + +### Customizing tokens and variables + +```ts +createRLSProvider({ + provide: 'scopedPrisma', + clientToken: 'prisma', // if you register the base client directly + clientProperty: undefined, // set to undefined when the token is the client itself + contextToken: 'tenantContext', + config: { + useTransaction: true, + sessionVariables: { + tenantId: 'myapp.tenant_id', + userId: 'myapp.user_id', + role: 'myapp.user_role', + }, + }, +}); +``` + +### Admin / bypass client + +```ts +import { createAdminClient } from 'glasswork'; + +const adminPrisma = createAdminClient(prisma); +await adminPrisma.project.deleteMany(); // runs with app.bypass_rls = true +``` + +## Testing utilities + +- `withTenant(prisma, tenantContext | tenantId, fn, options?)` — runs `fn` with a scoped client. +- `seedTenant(prisma, tenantId, seedFn, options?)` — sets `app.bypass_rls` and `app.tenant_id` inside a transaction, then executes `seedFn`. + +Example: + +```ts +import { seedTenant, withTenant } from 'glasswork'; + +await seedTenant(prisma, 'tenant-1', async (tx) => { + await tx.project.create({ data: { id: 'p1', name: 'One' } }); +}); + +await withTenant(prisma, 'tenant-1', async (tenantPrisma) => { + const projects = await tenantPrisma.project.findMany(); + expect(projects).toHaveLength(1); +}); +``` + +## Performance notes + +- Wrapping each query in a transaction adds a small overhead; keep it enabled unless you manage session variables per connection yourself. +- For bulk operations, batch work inside a single `prisma.$transaction` to set variables once. + +## CLI status + +A `glasswork generate rls` helper is planned but not shipped yet. Until then: +- Keep tenant fields consistent (`tenantId` with an index). +- Generate policies manually using the SQL snippets above. +- If you need automation, mirror the `formatSetStatement` pattern to build your own migration scripts. diff --git a/src/hono.d.ts b/src/hono.d.ts index eb2c863..d58d628 100644 --- a/src/hono.d.ts +++ b/src/hono.d.ts @@ -1,6 +1,7 @@ /** * Augment Hono's context with Glasswork-specific variables */ +import type { TenantContext } from './rls/types.js'; import type { OpenAPIResponseHook } from './types.js'; export interface Session { @@ -13,5 +14,6 @@ declare module 'hono' { interface ContextVariableMap { session?: Session; openapiResponseHooks?: OpenAPIResponseHook[]; + tenantContext?: TenantContext; } } diff --git a/src/index.ts b/src/index.ts index 3a94b4a..7649a97 100644 --- a/src/index.ts +++ b/src/index.ts @@ -144,7 +144,6 @@ export { // Middleware export { createRateLimitMiddleware } from './middleware/rate-limit.js'; - // OpenAPI export { defaultOpenAPIComponents } from './openapi/defaults.js'; export { configureOpenAPI } from './openapi/openapi.js'; @@ -156,6 +155,25 @@ export { paginationHeadersProcessor, responseHeadersProcessor, } from './openapi/openapi-processors.js'; +// RLS +export { + type AdminClientOptions, + createAdminClient, + createRLSClient, + createRLSProvider, + type RLSConfig, + type RLSMiddlewareOptions, + type RLSProviderOptions, + rlsMiddleware, + type SeedTenantOptions, + type SessionVariableNames, + seedTenant, + type TenantContext, + type TenantContextExtractor, + type TenantRole, + type WithTenantOptions, + withTenant, +} from './rls/index.js'; // Utilities export { deepMerge } from './utils/deep-merge.js'; diff --git a/src/rls/client.ts b/src/rls/client.ts new file mode 100644 index 0000000..67bef9a --- /dev/null +++ b/src/rls/client.ts @@ -0,0 +1,151 @@ +import type { PrismaClient } from '@prisma/client'; +import type { RLSConfig, SessionVariableNames, TenantContext } from './types.js'; +import { assertTenantContext, formatSetStatement } from './utils.js'; + +const DEFAULT_SESSION_VARIABLES: SessionVariableNames = { + tenantId: 'app.tenant_id', + userId: 'app.user_id', + role: 'app.user_role', + bypass: 'app.bypass_rls', +}; + +const DEFAULT_CONFIG: RLSConfig = { + sessionVariables: DEFAULT_SESSION_VARIABLES, + useTransaction: true, +}; + +type RawExecutor = { + $executeRawUnsafe: (query: string) => Promise; +}; + +type OperationInvoker = (args: unknown) => unknown; + +/** + * Create a Prisma client extension that sets RLS session variables + * for every query. + */ +export function createRLSClient( + prisma: TClient, + context: TenantContext, + config?: Partial +): TClient { + const mergedConfig = buildConfig(config); + const tenantContext = assertTenantContext(context, 'tenantContext'); + const statements = createStatements(tenantContext, mergedConfig.sessionVariables); + + return prisma.$extends({ + name: 'rls', + query: { + $allOperations: async ({ model, operation, args, query }) => { + if (mergedConfig.useTransaction) { + return prisma.$transaction(async (tx) => { + await applyStatements(tx, statements); + const operationFn = findOperation(tx, model, operation); + if (operationFn) { + return operationFn(args); + } + return query(args); + }); + } + + await applyStatements(prisma, statements); + return query(args); + }, + }, + }) as TClient; +} + +export interface AdminClientOptions { + bypassVariable?: string; + useTransaction?: boolean; +} + +/** + * Create a Prisma client that sets a bypass flag for administrative operations. + */ +export function createAdminClient( + prisma: TClient, + options: AdminClientOptions = {} +): TClient { + const bypassVariable = + options.bypassVariable ?? DEFAULT_SESSION_VARIABLES.bypass ?? 'app.bypass_rls'; + const useTransaction = options.useTransaction ?? true; + const statements = [formatSetStatement(bypassVariable, 'true')]; + + return prisma.$extends({ + name: 'rls-admin-bypass', + query: { + $allOperations: async ({ model, operation, args, query }) => { + if (useTransaction) { + return prisma.$transaction(async (tx) => { + await applyStatements(tx, statements); + const operationFn = findOperation(tx, model, operation); + if (operationFn) { + return operationFn(args); + } + return query(args); + }); + } + + await applyStatements(prisma, statements); + return query(args); + }, + }, + }) as TClient; +} + +function buildConfig(config?: Partial): RLSConfig { + return { + sessionVariables: { + ...DEFAULT_SESSION_VARIABLES, + ...config?.sessionVariables, + }, + useTransaction: config?.useTransaction ?? DEFAULT_CONFIG.useTransaction, + }; +} + +function createStatements( + context: TenantContext, + sessionVariables: SessionVariableNames +): string[] { + return [ + formatSetStatement(sessionVariables.tenantId, context.tenantId), + formatSetStatement(sessionVariables.userId, context.userId), + formatSetStatement(sessionVariables.role, context.role), + ]; +} + +async function applyStatements(target: RawExecutor, statements: string[]): Promise { + for (const statement of statements) { + await target.$executeRawUnsafe(statement); + } +} + +function findOperation( + client: unknown, + model: string | undefined, + operation: string +): OperationInvoker | undefined { + if (!client || typeof client !== 'object') { + return undefined; + } + + const scope = model ? (client as Record)[model] : client; + + if (!scope || typeof scope !== 'object') { + return undefined; + } + + const candidate = (scope as Record)[operation]; + return typeof candidate === 'function' ? (candidate as OperationInvoker) : undefined; +} + +/** + * @internal Exported for test utilities. + */ +export const __private__ = { + buildConfig, + createStatements, + applyStatements, + findOperation, +}; diff --git a/src/rls/index.ts b/src/rls/index.ts new file mode 100644 index 0000000..611093e --- /dev/null +++ b/src/rls/index.ts @@ -0,0 +1,19 @@ +export { + type AdminClientOptions, + createAdminClient, + createRLSClient, +} from './client.js'; +export { rlsMiddleware } from './middleware.js'; +export { createRLSProvider } from './provider.js'; +export { seedTenant, withTenant } from './testing.js'; +export type { + RLSConfig, + RLSMiddlewareOptions, + RLSProviderOptions, + SeedTenantOptions, + SessionVariableNames, + TenantContext, + TenantContextExtractor, + TenantRole, + WithTenantOptions, +} from './types.js'; diff --git a/src/rls/middleware.ts b/src/rls/middleware.ts new file mode 100644 index 0000000..d844fa1 --- /dev/null +++ b/src/rls/middleware.ts @@ -0,0 +1,43 @@ +import type { MiddlewareHandler } from 'hono'; +import type { RLSMiddlewareOptions, TenantContext } from './types.js'; +import { assertTenantContext, isTenantContext } from './utils.js'; + +function defaultExtractor(context: Parameters[0]): TenantContext | undefined { + const auth = context.get('auth'); + if (isTenantContext(auth)) { + return auth; + } + + return undefined; +} + +/** + * Hono middleware that extracts tenant context and stores it for DI. + */ +export function rlsMiddleware(options: RLSMiddlewareOptions = {}): MiddlewareHandler { + const { + contextKey = 'tenantContext', + extractTenant = defaultExtractor, + allowUnauthenticated = true, + } = options; + + return async (context, next) => { + const existing = context.get(contextKey); + if (existing) { + return next(); + } + + const tenantContext = await Promise.resolve(extractTenant(context)); + + if (!tenantContext) { + if (allowUnauthenticated) { + return next(); + } + throw new Error('Tenant context is required but was not found'); + } + + const validated = assertTenantContext(tenantContext, contextKey); + context.set(contextKey, validated); + return next(); + }; +} diff --git a/src/rls/provider.ts b/src/rls/provider.ts new file mode 100644 index 0000000..a83800a --- /dev/null +++ b/src/rls/provider.ts @@ -0,0 +1,69 @@ +import type { PrismaClient } from '@prisma/client'; +import type { ProviderConfig } from '../core/types.js'; +import { createRLSClient } from './client.js'; +import type { RLSProviderOptions } from './types.js'; +import { assertTenantContext } from './utils.js'; + +/** + * Create an Awilix provider that returns a tenant-scoped Prisma client. + */ +export function createRLSProvider(options: RLSProviderOptions = {}): ProviderConfig { + const { + provide = 'tenantPrisma', + clientToken = 'prismaService', + clientProperty = 'client', + contextToken = 'tenantContext', + config, + } = options; + + return { + provide, + useFactory: (dependencies) => { + const prisma = resolvePrismaClient(dependencies[clientToken], clientProperty, clientToken); + const tenantContext = assertTenantContext(dependencies[contextToken], contextToken); + return createRLSClient(prisma, tenantContext, config); + }, + inject: [clientToken, contextToken], + scope: 'SCOPED', + }; +} + +function resolvePrismaClient( + candidate: unknown, + clientProperty: string | undefined, + token: string +): PrismaClient { + if (clientProperty && candidate && typeof candidate === 'object') { + const value = (candidate as Record)[clientProperty]; + if (isPrismaLike(value)) { + return value; + } + } + + if (isPrismaLike(candidate)) { + return candidate; + } + + throw new Error(`Dependency "${token}" does not expose a Prisma client`); +} + +function isPrismaLike(value: unknown): value is PrismaClient { + return ( + typeof value === 'object' && + value !== null && + '$extends' in (value as Record) && + typeof (value as Record).$extends === 'function' && + '$transaction' in (value as Record) && + typeof (value as Record).$transaction === 'function' && + '$executeRawUnsafe' in (value as Record) && + typeof (value as Record).$executeRawUnsafe === 'function' + ); +} + +/** + * @internal Exported for testing. + */ +export const __private__ = { + isPrismaLike, + resolvePrismaClient, +}; diff --git a/src/rls/testing.ts b/src/rls/testing.ts new file mode 100644 index 0000000..7715123 --- /dev/null +++ b/src/rls/testing.ts @@ -0,0 +1,39 @@ +import type { PrismaClient } from '@prisma/client'; +import { createRLSClient } from './client.js'; +import type { SeedTenantOptions, TenantContext, WithTenantOptions } from './types.js'; +import { formatSetStatement } from './utils.js'; + +/** + * Execute a function with a tenant-scoped Prisma client. + */ +export async function withTenant( + prisma: PrismaClient, + tenant: TenantContext | string, + callback: (client: PrismaClient) => Promise, + options: WithTenantOptions = {} +): Promise { + const tenantContext: TenantContext = + typeof tenant === 'string' ? { tenantId: tenant, userId: 'test-user', role: 'admin' } : tenant; + + const scopedClient = createRLSClient(prisma, tenantContext, options.config); + return callback(scopedClient); +} + +/** + * Seed data for a specific tenant using a bypass client and session variables. + */ +export async function seedTenant( + prisma: PrismaClient, + tenantId: string, + seed: (client: PrismaClient) => Promise, + options: SeedTenantOptions = {} +): Promise { + const bypassVariable = options.bypassVariable ?? 'app.bypass_rls'; + const tenantVariable = options.tenantVariable ?? 'app.tenant_id'; + + await prisma.$transaction(async (tx) => { + await tx.$executeRawUnsafe(formatSetStatement(bypassVariable, 'true')); + await tx.$executeRawUnsafe(formatSetStatement(tenantVariable, tenantId)); + await seed(tx); + }); +} diff --git a/src/rls/types.ts b/src/rls/types.ts new file mode 100644 index 0000000..f8e55f7 --- /dev/null +++ b/src/rls/types.ts @@ -0,0 +1,78 @@ +import type { Context } from 'hono'; + +export type TenantRole = 'admin' | 'member' | 'viewer' | (string & {}); + +/** + * Tenant context extracted from authentication/session state. + */ +export interface TenantContext { + tenantId: string; + userId: string; + role: TenantRole; +} + +/** + * Names of PostgreSQL session variables used for RLS. + */ +export interface SessionVariableNames { + tenantId: string; + userId: string; + role: string; + bypass?: string; +} + +/** + * Configuration for the RLS Prisma extension. + */ +export interface RLSConfig { + sessionVariables: SessionVariableNames; + useTransaction: boolean; +} + +/** + * Options for building an Awilix provider that scopes Prisma per request. + */ +export interface RLSProviderOptions { + /** + * Token to register in the container (default: "tenantPrisma"). + */ + provide?: string; + /** + * Token that resolves to the base Prisma client or a service exposing it. + * Defaults to "prismaService". + */ + clientToken?: string; + /** + * Property on the injected service that contains the Prisma client. + * Defaults to "client". Set to undefined if the token is the Prisma client itself. + */ + clientProperty?: string; + /** + * Token that contains the current tenant context. + * Defaults to "tenantContext". + */ + contextToken?: string; + /** + * Optional overrides for the RLS configuration. + */ + config?: Partial; +} + +export type TenantContextExtractor = ( + context: Context +) => TenantContext | undefined | Promise; + +export interface RLSMiddlewareOptions { + contextKey?: string; + extractTenant?: TenantContextExtractor; + allowUnauthenticated?: boolean; +} + +export interface WithTenantOptions { + config?: Partial; +} + +export interface SeedTenantOptions { + bypassVariable?: string; + tenantVariable?: string; +} diff --git a/src/rls/utils.ts b/src/rls/utils.ts new file mode 100644 index 0000000..3739b15 --- /dev/null +++ b/src/rls/utils.ts @@ -0,0 +1,58 @@ +import type { TenantContext } from './types.js'; + +/** + * Escape a SQL literal by doubling single quotes. + * This prevents SQL injection in SET LOCAL statements. + */ +export function escapeLiteral(value: string): string { + return value.replaceAll("'", "''"); +} + +/** + * Escape a SQL identifier by doubling double quotes. + * This keeps session variable names safe to interpolate. + */ +export function escapeIdentifier(identifier: string): string { + if (!identifier.trim()) { + throw new Error('Session variable name cannot be empty'); + } + + return identifier.replaceAll('"', '""'); +} + +/** + * Build a SET LOCAL statement for a session variable. + */ +export function formatSetStatement(variableName: string, value: string): string { + return `SET LOCAL "${escapeIdentifier(variableName)}" = '${escapeLiteral(value)}'`; +} + +/** + * Lightweight runtime guard for TenantContext. + */ +export function isTenantContext(value: unknown): value is TenantContext { + if (!value || typeof value !== 'object') { + return false; + } + + const ctx = value as Record; + return ( + typeof ctx.tenantId === 'string' && + ctx.tenantId.length > 0 && + typeof ctx.userId === 'string' && + ctx.userId.length > 0 && + typeof ctx.role === 'string' && + ctx.role.length > 0 + ); +} + +/** + * Ensure a tenant context is available, throwing a clear error if missing. + */ +export function assertTenantContext(value: unknown, label: string): TenantContext { + if (!isTenantContext(value)) { + throw new Error(`Tenant context "${label}" is required for RLS`); + } + + return value; +} diff --git a/test/rls/rls.spec.ts b/test/rls/rls.spec.ts new file mode 100644 index 0000000..987bd78 --- /dev/null +++ b/test/rls/rls.spec.ts @@ -0,0 +1,299 @@ +import type { PrismaClient } from '@prisma/client'; +import { describe, expect, it, vi } from 'vitest'; +import * as rlsClient from '../../src/rls/client.js'; +import { + createAdminClient, + createRLSClient, + createRLSProvider, + rlsMiddleware, + seedTenant, + withTenant, +} from '../../src/rls/index.js'; +import type { RLSProviderOptions, TenantContext } from '../../src/rls/types.js'; + +type MockOperation = ReturnType; + +interface MockTransactionClient { + $executeRawUnsafe: MockOperation; + project: { + findMany: MockOperation; + deleteMany: MockOperation; + }; +} + +interface MockPrismaClient { + project: { + findMany: MockOperation; + deleteMany: MockOperation; + }; + $executeRawUnsafe: MockOperation; + $transaction: MockOperation; + $extends: (extension: { + query: { + $allOperations: (input: { + model?: string; + operation: string; + args: unknown; + query: (args: unknown) => Promise; + }) => Promise; + }; + }) => unknown; +} + +function createMockPrisma() { + const baseQuery = vi.fn().mockResolvedValue('base-query'); + const tx: MockTransactionClient = { + $executeRawUnsafe: vi.fn().mockResolvedValue(undefined), + project: { + findMany: vi.fn().mockResolvedValue('tx-result'), + deleteMany: vi.fn().mockResolvedValue('tx-deleted'), + }, + }; + + const prisma: MockPrismaClient = { + project: { + findMany: vi.fn().mockResolvedValue('base-result'), + deleteMany: vi.fn().mockResolvedValue('base-deleted'), + }, + $executeRawUnsafe: vi.fn().mockResolvedValue(undefined), + $transaction: vi.fn(async (callback: (client: MockTransactionClient) => unknown) => + callback(tx) + ), + $extends: (extension) => { + const callOperation = (operation: string, args: unknown, model?: string) => + extension.query.$allOperations({ + model: model ?? 'project', + operation, + args, + query: baseQuery, + }); + + return { + project: { + findMany: (args?: unknown) => callOperation('findMany', args), + deleteMany: (args?: unknown) => callOperation('deleteMany', args), + }, + }; + }, + }; + + return { prisma, tx, baseQuery }; +} + +describe('createRLSClient', () => { + it('wraps queries in a transaction and sets session variables', async () => { + const { prisma, tx } = createMockPrisma(); + + const client = createRLSClient(prisma as unknown as PrismaClient, { + tenantId: 'tenant-1', + userId: 'user-1', + role: 'member', + }); + + const extended = client as unknown as { + project: { findMany: (args?: unknown) => Promise }; + }; + + const result = await extended.project.findMany({ where: { id: 1 } }); + + expect(prisma.$transaction).toHaveBeenCalledTimes(1); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 1, + 'SET LOCAL "app.tenant_id" = \'tenant-1\'' + ); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith(2, 'SET LOCAL "app.user_id" = \'user-1\''); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 3, + 'SET LOCAL "app.user_role" = \'member\'' + ); + expect(tx.project.findMany).toHaveBeenCalledWith({ where: { id: 1 } }); + expect(result).toBe('tx-result'); + }); + + it('supports disabling transaction wrapping', async () => { + const { prisma, baseQuery } = createMockPrisma(); + + const client = createRLSClient( + prisma as unknown as PrismaClient, + { + tenantId: 'tenant-1', + userId: 'user-1', + role: 'member', + }, + { useTransaction: false } + ); + + const extended = client as unknown as { + project: { findMany: (args?: unknown) => Promise }; + }; + + const result = await extended.project.findMany({ take: 5 }); + + expect(prisma.$transaction).not.toHaveBeenCalled(); + expect(prisma.$executeRawUnsafe).toHaveBeenCalledTimes(3); + expect(baseQuery).toHaveBeenCalledWith({ take: 5 }); + expect(result).toBe('base-query'); + }); + + it('escapes tenant values to prevent SQL injection', async () => { + const { prisma, tx } = createMockPrisma(); + + const client = createRLSClient(prisma as unknown as PrismaClient, { + tenantId: "tenant-'1", + userId: "user-'1", + role: "role-'1", + }); + + const extended = client as unknown as { + project: { deleteMany: (args?: unknown) => Promise }; + }; + + await extended.project.deleteMany({ where: { id: 1 } }); + + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 1, + "SET LOCAL \"app.tenant_id\" = 'tenant-''1'" + ); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 2, + "SET LOCAL \"app.user_id\" = 'user-''1'" + ); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 3, + "SET LOCAL \"app.user_role\" = 'role-''1'" + ); + }); +}); + +describe('createAdminClient', () => { + it('sets bypass flag before executing queries', async () => { + const { prisma, tx } = createMockPrisma(); + + const client = createAdminClient(prisma as unknown as PrismaClient); + const extended = client as unknown as { + project: { findMany: (args?: unknown) => Promise }; + }; + + await extended.project.findMany(); + + expect(tx.$executeRawUnsafe).toHaveBeenCalledWith('SET LOCAL "app.bypass_rls" = \'true\''); + }); +}); + +describe('createRLSProvider', () => { + it('builds a scoped provider that resolves a Prisma client', () => { + const { prisma } = createMockPrisma(); + const tenantContext: TenantContext = { tenantId: 't', userId: 'u', role: 'admin' }; + const providerOptions: RLSProviderOptions = {}; + const provider = createRLSProvider(providerOptions); + const scopedClient = { scoped: true }; + + const spy = vi + .spyOn(rlsClient, 'createRLSClient') + // biome-ignore lint/suspicious/noExplicitAny: vi spy return type + .mockReturnValue(scopedClient as any); + + const resolved = provider.useFactory?.({ + prismaService: prisma, + tenantContext, + } as Record); + + expect(provider.scope).toBe('SCOPED'); + expect(provider.inject).toEqual(['prismaService', 'tenantContext']); + expect(spy).toHaveBeenCalledWith(prisma, tenantContext, undefined); + expect(resolved).toBe(scopedClient); + + spy.mockRestore(); + }); +}); + +describe('rlsMiddleware', () => { + it('stores tenant context from auth data', async () => { + const auth: TenantContext = { tenantId: 't', userId: 'u', role: 'member' }; + const set = vi.fn(); + const context = { + get: vi.fn((key: string) => (key === 'auth' ? auth : undefined)), + set, + }; + const next = vi.fn(); + + await rlsMiddleware()(context as unknown as Parameters[0], next); + + expect(set).toHaveBeenCalledWith('tenantContext', auth); + expect(next).toHaveBeenCalled(); + }); + + it('allows requests without tenant when configured', async () => { + const set = vi.fn(); + const context = { + get: vi.fn(() => undefined), + set, + }; + const next = vi.fn(); + + await rlsMiddleware({ allowUnauthenticated: true })( + context as unknown as Parameters[0], + next + ); + + expect(set).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalled(); + }); +}); + +describe('testing utilities', () => { + it('withTenant creates scoped client and passes it to callback', async () => { + const { prisma } = createMockPrisma(); + const scopedClient = { scoped: true }; + const spy = vi + .spyOn(rlsClient, 'createRLSClient') + // biome-ignore lint/suspicious/noExplicitAny: vi spy return type + .mockReturnValue(scopedClient as any); + + const result = await withTenant( + prisma as unknown as PrismaClient, + 'tenant-42', + async (client) => { + expect(client).toBe(scopedClient); + return 'ok'; + } + ); + + expect(result).toBe('ok'); + expect(spy).toHaveBeenCalledWith( + prisma, + { tenantId: 'tenant-42', userId: 'test-user', role: 'admin' }, + undefined + ); + + spy.mockRestore(); + }); + + it('seedTenant sets bypass and tenant variables and runs seed callback', async () => { + const tx = { + $executeRawUnsafe: vi.fn().mockResolvedValue(undefined), + project: { + create: vi.fn().mockResolvedValue(undefined), + }, + }; + + const prisma = { + $transaction: vi.fn(async (callback: (client: typeof tx) => unknown) => callback(tx)), + }; + + await seedTenant(prisma as unknown as PrismaClient, 'tenant-seed', async (client) => { + await client.project.create({ data: { id: 1 } }); + }); + + expect(prisma.$transaction).toHaveBeenCalledTimes(1); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 1, + 'SET LOCAL "app.bypass_rls" = \'true\'' + ); + expect(tx.$executeRawUnsafe).toHaveBeenNthCalledWith( + 2, + 'SET LOCAL "app.tenant_id" = \'tenant-seed\'' + ); + expect(tx.project.create).toHaveBeenCalledWith({ data: { id: 1 } }); + }); +});