diff --git a/packages/2-sql/4-lanes/relational-core/src/ast/types.ts b/packages/2-sql/4-lanes/relational-core/src/ast/types.ts index 21882ed32e..055943588f 100644 --- a/packages/2-sql/4-lanes/relational-core/src/ast/types.ts +++ b/packages/2-sql/4-lanes/relational-core/src/ast/types.ts @@ -14,7 +14,12 @@ export type BinaryOp = | 'like' | 'ilike' | 'in' - | 'notIn'; + | 'notIn' + | 'add' + | 'sub' + | 'mul' + | 'div' + | 'mod'; export type AggregateCountFn = 'count'; export type AggregateOpFn = 'sum' | 'avg' | 'min' | 'max'; @@ -52,6 +57,7 @@ export interface ExprVisitor { exists(expr: ExistsExpr): R; nullCheck(expr: NullCheckExpr): R; not(expr: NotExpr): R; + cast(expr: CastExpr): R; literal(expr: LiteralExpr): R; param(expr: ParamRef): R; list(expr: ListExpression): R; @@ -839,6 +845,26 @@ export class BinaryExpr extends Expression { return new BinaryExpr('notIn', left, right); } + static add(left: AnyExpression, right: AnyExpression): BinaryExpr { + return new BinaryExpr('add', left, right); + } + + static sub(left: AnyExpression, right: AnyExpression): BinaryExpr { + return new BinaryExpr('sub', left, right); + } + + static mul(left: AnyExpression, right: AnyExpression): BinaryExpr { + return new BinaryExpr('mul', left, right); + } + + static div(left: AnyExpression, right: AnyExpression): BinaryExpr { + return new BinaryExpr('div', left, right); + } + + static mod(left: AnyExpression, right: AnyExpression): BinaryExpr { + return new BinaryExpr('mod', left, right); + } + override accept(visitor: ExprVisitor): R { return visitor.binary(this); } @@ -1020,6 +1046,35 @@ export class NotExpr extends Expression { } } +export class CastExpr extends Expression { + readonly kind = 'cast' as const; + readonly expr: AnyExpression; + readonly targetCodecId: string; + + constructor(expr: AnyExpression, targetCodecId: string) { + super(); + this.expr = expr; + this.targetCodecId = targetCodecId; + this.freeze(); + } + + static of(expr: AnyExpression, targetCodecId: string): CastExpr { + return new CastExpr(expr, targetCodecId); + } + + override accept(visitor: ExprVisitor): R { + return visitor.cast(this); + } + + override rewrite(rewriter: ExpressionRewriter): AnyExpression { + return new CastExpr(this.expr.rewrite(rewriter), this.targetCodecId); + } + + override fold(folder: ExpressionFolder): T { + return this.expr.fold(folder); + } +} + export class EqColJoinOn extends AstNode { readonly kind = 'eq-col-join-on' as const; readonly left: ColumnRef; @@ -1714,7 +1769,8 @@ export type AnyExpression = | OrExpr | ExistsExpr | NullCheckExpr - | NotExpr; + | NotExpr + | CastExpr; export type AnyInsertOnConflictAction = DoNothingConflictAction | DoUpdateSetConflictAction; export type AnyInsertValue = ColumnRef | ParamRef | DefaultValueExpr; export type AnyOperationArg = AnyExpression | ParamRef | LiteralExpr; diff --git a/packages/2-sql/4-lanes/relational-core/test/ast/kind-discriminants.test.ts b/packages/2-sql/4-lanes/relational-core/test/ast/kind-discriminants.test.ts index e01819ada2..5012e1f995 100644 --- a/packages/2-sql/4-lanes/relational-core/test/ast/kind-discriminants.test.ts +++ b/packages/2-sql/4-lanes/relational-core/test/ast/kind-discriminants.test.ts @@ -3,6 +3,7 @@ import { AggregateExpr, AndExpr, BinaryExpr, + CastExpr, DefaultValueExpr, DeleteAst, DerivedTableSource, @@ -58,6 +59,7 @@ const allKindEntries: Array<[string, { kind: string }]> = [ ['InsertOnConflict', InsertOnConflict.on([col('t', 'id')])], ['DoNothingConflictAction', new DoNothingConflictAction()], ['DoUpdateSetConflictAction', new DoUpdateSetConflictAction({ id: col('t', 'id') })], + ['CastExpr', CastExpr.of(col('t', 'id'), 'pg/text@1')], ]; describe('AST kind discriminants', () => { @@ -90,6 +92,7 @@ describe('AST kind discriminants', () => { ['InsertOnConflict', 'insert-on-conflict'], ['DoNothingConflictAction', 'do-nothing'], ['DoUpdateSetConflictAction', 'do-update-set'], + ['CastExpr', 'cast'], ])('%s has kind "%s"', (className, expectedKind) => { const entry = allKindEntries.find(([name]) => name === className); expect(entry).toBeDefined(); diff --git a/packages/2-sql/4-lanes/sql-builder/src/expression.ts b/packages/2-sql/4-lanes/sql-builder/src/expression.ts index f22297da53..48a2014ec6 100644 --- a/packages/2-sql/4-lanes/sql-builder/src/expression.ts +++ b/packages/2-sql/4-lanes/sql-builder/src/expression.ts @@ -108,6 +108,32 @@ export type BuiltinFunctions[]) => Expression; or: (...ors: ExpressionOrValue[]) => Expression; + add: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => Expression; + sub: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => Expression; + mul: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => Expression; + div: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => Expression; + mod: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => Expression; + + cast: ( + expr: Expression, + target: { codecId: TargetCodecId; nullable: Nullable }, + ) => Expression<{ codecId: TargetCodecId; nullable: Nullable }>; + exists: (subquery: Subquery>) => Expression; notExists: (subquery: Subquery>) => Expression; diff --git a/packages/2-sql/4-lanes/sql-builder/src/runtime/functions.ts b/packages/2-sql/4-lanes/sql-builder/src/runtime/functions.ts index 2220bcdd0b..a8e77aedff 100644 --- a/packages/2-sql/4-lanes/sql-builder/src/runtime/functions.ts +++ b/packages/2-sql/4-lanes/sql-builder/src/runtime/functions.ts @@ -4,6 +4,7 @@ import { type AnyExpression as AstExpression, BinaryExpr, type BinaryOp, + CastExpr, ExistsExpr, ListExpression, LiteralExpr, @@ -76,6 +77,24 @@ function inOrNotIn( return boolExpr(binaryFn(left, SubqueryExpr.of(valuesOrSubquery.buildAst()))); } +function arithmetic( + a: ExpressionOrValue, + b: ExpressionOrValue, + op: BinaryOp, +): ExpressionImpl { + const field = ( + a instanceof ExpressionImpl + ? a.field + : b instanceof ExpressionImpl + ? b.field + : { codecId: 'unknown', nullable: false } + ) as T; + return new ExpressionImpl( + new BinaryExpr(op, resolve(a as ExprOrVal), resolve(b as ExprOrVal)), + field, + ); +} + function numericAgg( fn: 'sum' | 'avg' | 'min' | 'max', expr: Expression, @@ -96,6 +115,32 @@ function createBuiltinFunctions() { lte: (a: ExprOrVal, b: ExprOrVal) => comparison(a, b, 'lte'), and: (...exprs: ExprOrVal[]) => boolExpr(AndExpr.of(exprs.map(resolveToAst))), or: (...exprs: ExprOrVal[]) => boolExpr(OrExpr.of(exprs.map(resolveToAst))), + add: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => arithmetic(a, b, 'add'), + sub: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => arithmetic(a, b, 'sub'), + mul: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => arithmetic(a, b, 'mul'), + div: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => arithmetic(a, b, 'div'), + mod: ( + a: ExpressionOrValue, + b: ExpressionOrValue, + ) => arithmetic(a, b, 'mod'), + cast: ( + expr: Expression, + target: { codecId: TargetCodecId; nullable: Nullable }, + ) => { + return new ExpressionImpl(CastExpr.of(expr.buildAst(), target.codecId), target); + }, exists: (subquery: Subquery>) => boolExpr(ExistsExpr.exists(subquery.buildAst())), notExists: (subquery: Subquery>) => diff --git a/packages/2-sql/4-lanes/sql-builder/test/integration/arithmetic.test.ts b/packages/2-sql/4-lanes/sql-builder/test/integration/arithmetic.test.ts new file mode 100644 index 0000000000..c39d314c8c --- /dev/null +++ b/packages/2-sql/4-lanes/sql-builder/test/integration/arithmetic.test.ts @@ -0,0 +1,83 @@ +import { describe, expect, it } from 'vitest'; +import { setupIntegrationTest } from './setup'; + +describe('integration: arithmetic operations', () => { + const { db, runtime } = setupIntegrationTest(); + + it('add computes column + literal', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('viewsPlus', (f, fns) => fns.add(f.views, 10)) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows).toHaveLength(1); + expect(rows[0]!.viewsPlus).toBe(110); + }); + + it('sub computes column - literal', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('viewsMinus', (f, fns) => fns.sub(f.views, 10)) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows[0]!.viewsMinus).toBe(90); + }); + + it('mul computes column * literal', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('viewsDouble', (f, fns) => fns.mul(f.views, 2)) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows[0]!.viewsDouble).toBe(200); + }); + + it('div computes column / literal', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('viewsHalf', (f, fns) => fns.div(f.views, 2)) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows[0]!.viewsHalf).toBe(50); + }); + + it('mod computes column % literal', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('viewsMod', (f, fns) => fns.mod(f.views, 3)) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows[0]!.viewsMod).toBe(1); + }); + + it('arithmetic in WHERE clause filters correctly', async () => { + const rows = await runtime().execute( + db() + .posts.select('id', 'views') + .where((f, fns) => fns.gt(fns.add(f.views, 50), 150)) + .build(), + ); + expect(rows.every((r) => r.views + 50 > 150)).toBe(true); + }); + + it('nested arithmetic works', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('computed', (f, fns) => fns.add(fns.mul(f.views, 2), 1)) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows[0]!.computed).toBe(201); + }); +}); diff --git a/packages/2-sql/4-lanes/sql-builder/test/integration/cast.test.ts b/packages/2-sql/4-lanes/sql-builder/test/integration/cast.test.ts new file mode 100644 index 0000000000..b6b495002d --- /dev/null +++ b/packages/2-sql/4-lanes/sql-builder/test/integration/cast.test.ts @@ -0,0 +1,28 @@ +import { describe, expect, it } from 'vitest'; +import { setupIntegrationTest } from './setup'; + +describe('integration: type cast', () => { + const { db, runtime } = setupIntegrationTest(); + + it('cast int to text', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .select('idText', (f, fns) => fns.cast(f.id, { codecId: 'pg/text@1', nullable: false })) + .where((f, fns) => fns.eq(f.id, 1)) + .build(), + ); + expect(rows[0]!.idText).toBe('1'); + }); + + it('cast in WHERE clause', async () => { + const rows = await runtime().execute( + db() + .posts.select('id') + .where((f, fns) => fns.eq(fns.cast(f.id, { codecId: 'pg/text@1', nullable: false }), '1')) + .build(), + ); + expect(rows).toHaveLength(1); + expect(rows[0]!.id).toBe(1); + }); +}); diff --git a/packages/2-sql/4-lanes/sql-builder/test/runtime/functions.test.ts b/packages/2-sql/4-lanes/sql-builder/test/runtime/functions.test.ts index d388f6f33f..13a5de392d 100644 --- a/packages/2-sql/4-lanes/sql-builder/test/runtime/functions.test.ts +++ b/packages/2-sql/4-lanes/sql-builder/test/runtime/functions.test.ts @@ -2,6 +2,7 @@ import { AggregateExpr, AndExpr, BinaryExpr, + CastExpr, ColumnRef, ExistsExpr, IdentifierRef, @@ -254,6 +255,80 @@ describe('createAggregateFunctions', () => { }); }); +describe('arithmetic operators', () => { + let fns: ReturnType; + + beforeEach(() => { + fns = createFunctions({}); + }); + + it('add produces BinaryExpr with op add', () => { + const result = fns.add(f().id, 1); + const ast = result.buildAst() as BinaryExpr; + + expect(ast).toBeInstanceOf(BinaryExpr); + expect(ast.op).toBe('add'); + expect(ast.left).toBeInstanceOf(IdentifierRef); + expect(ast.right).toBeInstanceOf(ParamRef); + }); + + it('sub produces BinaryExpr with op sub', () => { + const result = fns.sub(f().id, 1); + expect((result.buildAst() as BinaryExpr).op).toBe('sub'); + }); + + it('mul produces BinaryExpr with op mul', () => { + const result = fns.mul(f().id, 2); + expect((result.buildAst() as BinaryExpr).op).toBe('mul'); + }); + + it('div produces BinaryExpr with op div', () => { + const result = fns.div(f().id, 2); + expect((result.buildAst() as BinaryExpr).op).toBe('div'); + }); + + it('mod produces BinaryExpr with op mod', () => { + const result = fns.mod(f().id, 2); + expect((result.buildAst() as BinaryExpr).op).toBe('mod'); + }); + + it('preserves codec from expression operand', () => { + const result = fns.add(f().id, 1); + expect((result as ExpressionImpl).field).toEqual({ codecId: 'pg/int4@1', nullable: false }); + }); + + it('two expressions produce BinaryExpr with both column refs', () => { + const fields = jf(); + const result = fns.add(fields.users.id, fields.posts.id); + const ast = result.buildAst() as BinaryExpr; + + expect(ast.op).toBe('add'); + expect(ast.left).toBeInstanceOf(ColumnRef); + expect(ast.right).toBeInstanceOf(ColumnRef); + }); +}); + +describe('cast', () => { + let fns: ReturnType; + + beforeEach(() => { + fns = createFunctions({}); + }); + + it('produces CastExpr with target codec id', () => { + const result = fns.cast(f().id, { codecId: 'pg/text@1', nullable: false }); + const ast = result.buildAst() as CastExpr; + + expect(ast).toBeInstanceOf(CastExpr); + expect(ast.targetCodecId).toBe('pg/text@1'); + }); + + it('returns field with target codec and nullable', () => { + const result = fns.cast(f().id, { codecId: 'pg/text@1', nullable: true }); + expect((result as ExpressionImpl).field).toEqual({ codecId: 'pg/text@1', nullable: true }); + }); +}); + describe('extension functions', () => { it('produces OperationExpr from queryOperationTypes', () => { const opTypes = { diff --git a/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts b/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts index 4d44693236..696e79fb77 100644 --- a/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts +++ b/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts @@ -93,6 +93,7 @@ function validateGroupedHavingExpr(expr: AnyExpression): AnyExpression { not(expr) { return new NotExpr(validateGroupedHavingExpr(expr.expr)); }, + cast: rejectHavingExpr, binary(expr) { return new BinaryExpr( expr.op, diff --git a/packages/3-extensions/sql-orm-client/src/where-binding.ts b/packages/3-extensions/sql-orm-client/src/where-binding.ts index 9370711fdd..4af7ae770f 100644 --- a/packages/3-extensions/sql-orm-client/src/where-binding.ts +++ b/packages/3-extensions/sql-orm-client/src/where-binding.ts @@ -4,6 +4,7 @@ import { type AnyExpression, type AnyFromSource, BinaryExpr, + CastExpr, type ColumnRef, DerivedTableSource, ExistsExpr, @@ -84,6 +85,9 @@ function bindWhereExprNode(contract: SqlContract, expr: AnyExpressio not(expr) { return new NotExpr(bindWhereExprNode(contract, expr.expr)); }, + cast(expr) { + return CastExpr.of(bindExpression(contract, expr.expr), expr.targetCodecId); + }, }); } diff --git a/packages/3-targets/6-adapters/postgres/src/core/adapter.ts b/packages/3-targets/6-adapters/postgres/src/core/adapter.ts index d0dfb7bab4..151d002d5c 100644 --- a/packages/3-targets/6-adapters/postgres/src/core/adapter.ts +++ b/packages/3-targets/6-adapters/postgres/src/core/adapter.ts @@ -6,7 +6,9 @@ import { type AnyFromSource, type AnyQueryAst, type BinaryExpr, + type CastExpr, type CodecParamsDescriptor, + type CodecRegistry, type ColumnRef, createCodecRegistry, type DeleteAst, @@ -36,6 +38,12 @@ import type { PostgresAdapterOptions, PostgresContract, PostgresLoweredStatement const VECTOR_CODEC_ID = 'pg/vector@1' as const; +interface RenderCtx { + readonly contract?: PostgresContract; + readonly pim?: ParamIndexMap; + readonly codecs: CodecRegistry; +} + function getCodecParamCast(codecId: string | undefined): string | undefined { if (codecId === VECTOR_CODEC_ID) { return 'vector'; @@ -128,21 +136,27 @@ class PostgresAdapterImpl params.push(ref.value); } + const ctx: RenderCtx = { + contract: context.contract, + pim: paramIndexMap, + codecs: this.codecRegistry, + }; + let sql: string; const node = ast; switch (node.kind) { case 'select': - sql = renderSelect(node, context.contract, paramIndexMap); + sql = renderSelect(node, ctx); break; case 'insert': - sql = renderInsert(node, context.contract, paramIndexMap); + sql = renderInsert(node, ctx); break; case 'update': - sql = renderUpdate(node, context.contract, paramIndexMap); + sql = renderUpdate(node, ctx); break; case 'delete': - sql = renderDelete(node, context.contract, paramIndexMap); + sql = renderDelete(node, ctx); break; // v8 ignore next 4 default: @@ -158,27 +172,26 @@ class PostgresAdapterImpl } } -function renderSelect(ast: SelectAst, contract?: PostgresContract, pim?: ParamIndexMap): string { - const selectClause = `SELECT ${renderDistinctPrefix(ast.distinct, ast.distinctOn, contract, pim)}${renderProjection( +function renderSelect(ast: SelectAst, ctx: RenderCtx): string { + const selectClause = `SELECT ${renderDistinctPrefix(ast.distinct, ast.distinctOn, ctx)}${renderProjection( ast.projection, - contract, - pim, + ctx, )}`; - const fromClause = `FROM ${renderSource(ast.from, contract, pim)}`; + const fromClause = `FROM ${renderSource(ast.from, ctx)}`; const joinsClause = ast.joins?.length - ? ast.joins.map((join) => renderJoin(join, contract, pim)).join(' ') + ? ast.joins.map((join) => renderJoin(join, ctx)).join(' ') : ''; - const whereClause = ast.where ? `WHERE ${renderWhere(ast.where, contract, pim)}` : ''; + const whereClause = ast.where ? `WHERE ${renderWhere(ast.where, ctx)}` : ''; const groupByClause = ast.groupBy?.length - ? `GROUP BY ${ast.groupBy.map((expr) => renderExpr(expr, contract, pim)).join(', ')}` + ? `GROUP BY ${ast.groupBy.map((expr) => renderExpr(expr, ctx)).join(', ')}` : ''; - const havingClause = ast.having ? `HAVING ${renderWhere(ast.having, contract, pim)}` : ''; + const havingClause = ast.having ? `HAVING ${renderWhere(ast.having, ctx)}` : ''; const orderClause = ast.orderBy?.length ? `ORDER BY ${ast.orderBy .map((order) => { - const expr = renderExpr(order.expr, contract, pim); + const expr = renderExpr(order.expr, ctx); return `${expr} ${order.dir.toUpperCase()}`; }) .join(', ')}` @@ -202,18 +215,14 @@ function renderSelect(ast: SelectAst, contract?: PostgresContract, pim?: ParamIn return clauses.trim(); } -function renderProjection( - projection: ReadonlyArray, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { +function renderProjection(projection: ReadonlyArray, ctx: RenderCtx): string { return projection .map((item) => { const alias = quoteIdentifier(item.alias); if (item.expr.kind === 'literal') { return `${renderLiteral(item.expr)} AS ${alias}`; } - return `${renderExpr(item.expr, contract, pim)} AS ${alias}`; + return `${renderExpr(item.expr, ctx)} AS ${alias}`; }) .join(', '); } @@ -221,11 +230,10 @@ function renderProjection( function renderDistinctPrefix( distinct: true | undefined, distinctOn: ReadonlyArray | undefined, - contract?: PostgresContract, - pim?: ParamIndexMap, + ctx: RenderCtx, ): string { if (distinctOn && distinctOn.length > 0) { - const rendered = distinctOn.map((expr) => renderExpr(expr, contract, pim)).join(', '); + const rendered = distinctOn.map((expr) => renderExpr(expr, ctx)).join(', '); return `DISTINCT ON (${rendered}) `; } if (distinct) { @@ -234,11 +242,7 @@ function renderDistinctPrefix( return ''; } -function renderSource( - source: AnyFromSource, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { +function renderSource(source: AnyFromSource, ctx: RenderCtx): string { const node = source; switch (node.kind) { case 'table-source': { @@ -249,7 +253,7 @@ function renderSource( return `${table} AS ${quoteIdentifier(node.alias)}`; } case 'derived-table-source': - return `(${renderSelect(node.query, contract, pim)}) AS ${quoteIdentifier(node.alias)}`; + return `(${renderSelect(node.query, ctx)}) AS ${quoteIdentifier(node.alias)}`; // v8 ignore next 4 default: throw new Error( @@ -264,35 +268,38 @@ function assertScalarSubquery(query: SelectAst): void { } } -function renderSubqueryExpr( - expr: SubqueryExpr, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { +function renderSubqueryExpr(expr: SubqueryExpr, ctx: RenderCtx): string { assertScalarSubquery(expr.query); - return `(${renderSelect(expr.query, contract, pim)})`; + return `(${renderSelect(expr.query, ctx)})`; } -function renderWhere( - expr: AnyExpression, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { - return renderExpr(expr, contract, pim); +function renderWhere(expr: AnyExpression, ctx: RenderCtx): string { + return renderExpr(expr, ctx); } -function renderNullCheck( - expr: NullCheckExpr, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { - const rendered = renderExpr(expr.expr, contract, pim); +function renderNullCheck(expr: NullCheckExpr, ctx: RenderCtx): string { + const rendered = renderExpr(expr.expr, ctx); const renderedExpr = expr.expr.kind === 'operation' || expr.expr.kind === 'subquery' ? `(${rendered})` : rendered; return expr.isNull ? `${renderedExpr} IS NULL` : `${renderedExpr} IS NOT NULL`; } -function renderBinary(expr: BinaryExpr, contract?: PostgresContract, pim?: ParamIndexMap): string { +function resolveNativeType(codecId: string, codecs: CodecRegistry): string { + const codec = codecs.get(codecId); + const nativeType = codec?.meta?.db?.sql?.postgres?.nativeType; + if (!nativeType) { + throw new Error(`Unknown codec ID for cast: ${codecId}`); + } + return nativeType; +} + +function renderCast(expr: CastExpr, ctx: RenderCtx): string { + const inner = renderExpr(expr.expr, ctx); + const nativeType = resolveNativeType(expr.targetCodecId, ctx.codecs); + return `(${inner})::${nativeType}`; +} + +function renderBinary(expr: BinaryExpr, ctx: RenderCtx): string { if (expr.right.kind === 'list' && expr.right.values.length === 0) { if (expr.op === 'in') { return 'FALSE'; @@ -303,7 +310,7 @@ function renderBinary(expr: BinaryExpr, contract?: PostgresContract, pim?: Param } const leftExpr = expr.left; - const left = renderExpr(leftExpr, contract, pim); + const left = renderExpr(leftExpr, ctx); const leftRendered = leftExpr.kind === 'operation' || leftExpr.kind === 'subquery' ? `(${left})` : left; @@ -311,7 +318,7 @@ function renderBinary(expr: BinaryExpr, contract?: PostgresContract, pim?: Param let right: string; switch (rightNode.kind) { case 'list': - right = renderListLiteral(rightNode, pim); + right = renderListLiteral(rightNode, ctx); break; case 'literal': right = renderLiteral(rightNode); @@ -320,10 +327,10 @@ function renderBinary(expr: BinaryExpr, contract?: PostgresContract, pim?: Param right = renderColumn(rightNode); break; case 'param-ref': - right = renderParamRef(rightNode, pim); + right = renderParamRef(rightNode, ctx.pim); break; default: - right = renderExpr(rightNode, contract, pim); + right = renderExpr(rightNode, ctx); break; } @@ -338,20 +345,27 @@ function renderBinary(expr: BinaryExpr, contract?: PostgresContract, pim?: Param ilike: 'ILIKE', in: 'IN', notIn: 'NOT IN', + add: '+', + sub: '-', + mul: '*', + div: '/', + mod: '%', }; - return `${leftRendered} ${operatorMap[expr.op]} ${right}`; + const arithmeticOps: ReadonlySet = new Set(['add', 'sub', 'mul', 'div', 'mod']); + const result = `${leftRendered} ${operatorMap[expr.op]} ${right}`; + return arithmeticOps.has(expr.op) ? `(${result})` : result; } -function renderListLiteral(expr: ListExpression, pim?: ParamIndexMap): string { +function renderListLiteral(expr: ListExpression, ctx: RenderCtx): string { if (expr.values.length === 0) { return '(NULL)'; } const values = expr.values .map((v) => { - if (v.kind === 'param-ref') return renderParamRef(v, pim); + if (v.kind === 'param-ref') return renderParamRef(v, ctx.pim); if (v.kind === 'literal') return renderLiteral(v); - return renderExpr(v, undefined, pim); + return renderExpr(v, ctx); }) .join(', '); return `(${values})`; @@ -364,62 +378,44 @@ function renderColumn(ref: ColumnRef): string { return `${quoteIdentifier(ref.table)}.${quoteIdentifier(ref.column)}`; } -function renderAggregateExpr( - expr: AggregateExpr, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { +function renderAggregateExpr(expr: AggregateExpr, ctx: RenderCtx): string { const fn = expr.fn.toUpperCase(); if (!expr.expr) { return `${fn}(*)`; } - return `${fn}(${renderExpr(expr.expr, contract, pim)})`; + return `${fn}(${renderExpr(expr.expr, ctx)})`; } -function renderJsonObjectExpr( - expr: JsonObjectExpr, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { +function renderJsonObjectExpr(expr: JsonObjectExpr, ctx: RenderCtx): string { const args = expr.entries .flatMap((entry): [string, string] => { const key = `'${escapeLiteral(entry.key)}'`; if (entry.value.kind === 'literal') { return [key, renderLiteral(entry.value)]; } - return [key, renderExpr(entry.value, contract, pim)]; + return [key, renderExpr(entry.value, ctx)]; }) .join(', '); return `json_build_object(${args})`; } -function renderOrderByItems( - items: ReadonlyArray, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { - return items - .map((item) => `${renderExpr(item.expr, contract, pim)} ${item.dir.toUpperCase()}`) - .join(', '); +function renderOrderByItems(items: ReadonlyArray, ctx: RenderCtx): string { + return items.map((item) => `${renderExpr(item.expr, ctx)} ${item.dir.toUpperCase()}`).join(', '); } -function renderJsonArrayAggExpr( - expr: JsonArrayAggExpr, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { +function renderJsonArrayAggExpr(expr: JsonArrayAggExpr, ctx: RenderCtx): string { const aggregateOrderBy = expr.orderBy && expr.orderBy.length > 0 - ? ` ORDER BY ${renderOrderByItems(expr.orderBy, contract, pim)}` + ? ` ORDER BY ${renderOrderByItems(expr.orderBy, ctx)}` : ''; - const aggregated = `json_agg(${renderExpr(expr.expr, contract, pim)}${aggregateOrderBy})`; + const aggregated = `json_agg(${renderExpr(expr.expr, ctx)}${aggregateOrderBy})`; if (expr.onEmpty === 'emptyArray') { return `coalesce(${aggregated}, json_build_array())`; } return aggregated; } -function renderExpr(expr: AnyExpression, contract?: PostgresContract, pim?: ParamIndexMap): string { +function renderExpr(expr: AnyExpression, ctx: RenderCtx): string { const node = expr; switch (node.kind) { case 'column-ref': @@ -427,42 +423,44 @@ function renderExpr(expr: AnyExpression, contract?: PostgresContract, pim?: Para case 'identifier-ref': return quoteIdentifier(node.name); case 'operation': - return renderOperation(node, contract, pim); + return renderOperation(node, ctx); case 'subquery': - return renderSubqueryExpr(node, contract, pim); + return renderSubqueryExpr(node, ctx); case 'aggregate': - return renderAggregateExpr(node, contract, pim); + return renderAggregateExpr(node, ctx); case 'json-object': - return renderJsonObjectExpr(node, contract, pim); + return renderJsonObjectExpr(node, ctx); case 'json-array-agg': - return renderJsonArrayAggExpr(node, contract, pim); + return renderJsonArrayAggExpr(node, ctx); case 'binary': - return renderBinary(node, contract, pim); + return renderBinary(node, ctx); case 'and': if (node.exprs.length === 0) { return 'TRUE'; } - return `(${node.exprs.map((part) => renderExpr(part, contract, pim)).join(' AND ')})`; + return `(${node.exprs.map((part) => renderExpr(part, ctx)).join(' AND ')})`; case 'or': if (node.exprs.length === 0) { return 'FALSE'; } - return `(${node.exprs.map((part) => renderExpr(part, contract, pim)).join(' OR ')})`; + return `(${node.exprs.map((part) => renderExpr(part, ctx)).join(' OR ')})`; case 'exists': { const notKeyword = node.notExists ? 'NOT ' : ''; - const subquery = renderSelect(node.subquery, contract, pim); + const subquery = renderSelect(node.subquery, ctx); return `${notKeyword}EXISTS (${subquery})`; } case 'null-check': - return renderNullCheck(node, contract, pim); + return renderNullCheck(node, ctx); case 'not': - return `NOT (${renderExpr(node.expr, contract, pim)})`; + return `NOT (${renderExpr(node.expr, ctx)})`; + case 'cast': + return renderCast(node, ctx); case 'param-ref': - return renderParamRef(node, pim); + return renderParamRef(node, ctx.pim); case 'literal': return renderLiteral(node); case 'list': - return renderListLiteral(node, pim); + return renderListLiteral(node, ctx); // v8 ignore next 4 default: throw new Error( @@ -508,14 +506,10 @@ function renderLiteral(expr: LiteralExpr): string { return `'${escapeLiteral(json)}'`; } -function renderOperation( - expr: OperationExpr, - contract?: PostgresContract, - pim?: ParamIndexMap, -): string { - const self = renderExpr(expr.self, contract, pim); +function renderOperation(expr: OperationExpr, ctx: RenderCtx): string { + const self = renderExpr(expr.self, ctx); const args = expr.args.map((arg) => { - return renderExpr(arg, contract, pim); + return renderExpr(arg, ctx); }); let result = expr.lowering.template; @@ -527,21 +521,21 @@ function renderOperation( return result; } -function renderJoin(join: JoinAst, contract?: PostgresContract, pim?: ParamIndexMap): string { +function renderJoin(join: JoinAst, ctx: RenderCtx): string { const joinType = join.joinType.toUpperCase(); const lateral = join.lateral ? 'LATERAL ' : ''; - const source = renderSource(join.source, contract, pim); - const onClause = renderJoinOn(join.on, contract, pim); + const source = renderSource(join.source, ctx); + const onClause = renderJoinOn(join.on, ctx); return `${joinType} JOIN ${lateral}${source} ON ${onClause}`; } -function renderJoinOn(on: JoinOnExpr, contract?: PostgresContract, pim?: ParamIndexMap): string { +function renderJoinOn(on: JoinOnExpr, ctx: RenderCtx): string { if (on.kind === 'eq-col-join-on') { const left = renderColumn(on.left); const right = renderColumn(on.right); return `${left} = ${right}`; } - return renderWhere(on, contract, pim); + return renderWhere(on, ctx); } function getInsertColumnOrder( @@ -587,7 +581,11 @@ function renderInsertValue(value: InsertValue | undefined, pim?: ParamIndexMap): } } -function renderInsert(ast: InsertAst, contract: PostgresContract, pim?: ParamIndexMap): string { +function renderInsert(ast: InsertAst, ctx: RenderCtx): string { + const contract = ctx.contract; + if (!contract) { + throw new Error('INSERT requires a contract'); + } const table = quoteIdentifier(ast.table.name); const rows = ast.rows; if (rows.length === 0) { @@ -616,7 +614,7 @@ function renderInsert(ast: InsertAst, contract: PostgresContract, pim?: ParamInd const columns = columnOrder.map((column) => quoteIdentifier(column)); const values = rows .map((row) => { - const renderedRow = columnOrder.map((column) => renderInsertValue(row[column], pim)); + const renderedRow = columnOrder.map((column) => renderInsertValue(row[column], ctx.pim)); return `(${renderedRow.join(', ')})`; }) .join(', '); @@ -638,7 +636,7 @@ function renderInsert(ast: InsertAst, contract: PostgresContract, pim?: ParamInd const updates = Object.entries(action.set).map(([colName, value]) => { const target = quoteIdentifier(colName); if (value.kind === 'param-ref') { - return `${target} = ${renderParamRef(value, pim)}`; + return `${target} = ${renderParamRef(value, ctx.pim)}`; } return `${target} = ${renderColumn(value)}`; }); @@ -659,14 +657,14 @@ function renderInsert(ast: InsertAst, contract: PostgresContract, pim?: ParamInd return `${insertClause}${onConflictClause}${returningClause}`; } -function renderUpdate(ast: UpdateAst, contract: PostgresContract, pim?: ParamIndexMap): string { +function renderUpdate(ast: UpdateAst, ctx: RenderCtx): string { const table = quoteIdentifier(ast.table.name); const setClauses = Object.entries(ast.set).map(([col, val]) => { const column = quoteIdentifier(col); let value: string; switch (val.kind) { case 'param-ref': - value = renderParamRef(val, pim); + value = renderParamRef(val, ctx.pim); break; case 'column-ref': value = renderColumn(val); @@ -680,7 +678,7 @@ function renderUpdate(ast: UpdateAst, contract: PostgresContract, pim?: ParamInd return `${column} = ${value}`; }); - const whereClause = ast.where ? ` WHERE ${renderWhere(ast.where, contract, pim)}` : ''; + const whereClause = ast.where ? ` WHERE ${renderWhere(ast.where, ctx)}` : ''; const returningClause = ast.returning?.length ? ` RETURNING ${ast.returning.map((col) => `${quoteIdentifier(col.table)}.${quoteIdentifier(col.column)}`).join(', ')}` : ''; @@ -688,9 +686,9 @@ function renderUpdate(ast: UpdateAst, contract: PostgresContract, pim?: ParamInd return `UPDATE ${table} SET ${setClauses.join(', ')}${whereClause}${returningClause}`; } -function renderDelete(ast: DeleteAst, contract?: PostgresContract, pim?: ParamIndexMap): string { +function renderDelete(ast: DeleteAst, ctx: RenderCtx): string { const table = quoteIdentifier(ast.table.name); - const whereClause = ast.where ? ` WHERE ${renderWhere(ast.where, contract, pim)}` : ''; + const whereClause = ast.where ? ` WHERE ${renderWhere(ast.where, ctx)}` : ''; const returningClause = ast.returning?.length ? ` RETURNING ${ast.returning.map((col) => `${quoteIdentifier(col.table)}.${quoteIdentifier(col.column)}`).join(', ')}` : ''; diff --git a/packages/3-targets/6-adapters/postgres/test/adapter.test.ts b/packages/3-targets/6-adapters/postgres/test/adapter.test.ts index d263056514..874382ca2b 100644 --- a/packages/3-targets/6-adapters/postgres/test/adapter.test.ts +++ b/packages/3-targets/6-adapters/postgres/test/adapter.test.ts @@ -4,6 +4,7 @@ import { AndExpr, type AnyQueryAst, BinaryExpr, + CastExpr, ColumnRef, DefaultValueExpr, DeleteAst, @@ -274,4 +275,57 @@ describe('Postgres adapter', () => { expect(sql).toContain('WHERE FALSE'); }); + + it('renders arithmetic operations with parentheses', () => { + const ast = SelectAst.from(TableSource.named('user')).withProjection([ + ProjectionItem.of('sum', BinaryExpr.add(ColumnRef.of('user', 'id'), LiteralExpr.of(1))), + ProjectionItem.of('diff', BinaryExpr.sub(ColumnRef.of('user', 'id'), LiteralExpr.of(1))), + ProjectionItem.of('prod', BinaryExpr.mul(ColumnRef.of('user', 'id'), LiteralExpr.of(2))), + ProjectionItem.of('quot', BinaryExpr.div(ColumnRef.of('user', 'id'), LiteralExpr.of(2))), + ProjectionItem.of('rem', BinaryExpr.mod(ColumnRef.of('user', 'id'), LiteralExpr.of(2))), + ]); + + const sql = adapter.lower(ast, { contract, params: [] }).body.sql; + + expect(sql).toBe( + 'SELECT ("user"."id" + 1) AS "sum", ("user"."id" - 1) AS "diff", ("user"."id" * 2) AS "prod", ("user"."id" / 2) AS "quot", ("user"."id" % 2) AS "rem" FROM "user"', + ); + }); + + it('renders arithmetic in WHERE clause', () => { + const ast = SelectAst.from(TableSource.named('user')) + .withProjection([ProjectionItem.of('id', ColumnRef.of('user', 'id'))]) + .withWhere( + BinaryExpr.gt( + BinaryExpr.add(ColumnRef.of('user', 'id'), LiteralExpr.of(1)), + LiteralExpr.of(5), + ), + ); + + const sql = adapter.lower(ast, { contract, params: [] }).body.sql; + + expect(sql).toContain('WHERE ("user"."id" + 1) > 5'); + }); + + it('renders type cast with Postgres :: syntax', () => { + const ast = SelectAst.from(TableSource.named('user')).withProjection([ + ProjectionItem.of('idText', CastExpr.of(ColumnRef.of('user', 'id'), 'pg/text@1')), + ]); + + const sql = adapter.lower(ast, { contract, params: [] }).body.sql; + + expect(sql).toBe('SELECT ("user"."id")::text AS "idText" FROM "user"'); + }); + + it('renders cast in WHERE clause', () => { + const ast = SelectAst.from(TableSource.named('user')) + .withProjection([ProjectionItem.of('id', ColumnRef.of('user', 'id'))]) + .withWhere( + BinaryExpr.eq(CastExpr.of(ColumnRef.of('user', 'id'), 'pg/text@1'), LiteralExpr.of('1')), + ); + + const sql = adapter.lower(ast, { contract, params: [] }).body.sql; + + expect(sql).toContain(`WHERE ("user"."id")::text = '1'`); + }); });