Skip to content

Commit 48da40c

Browse files
committed
feat(sql-builder): add arithmetic operations and type casts
Add arithmetic operators (add, sub, mul, div, mod) and type cast expressions to the SQL builder DSL, needed for drizzle-benchmark parity. - Extend BinaryOp with arithmetic ops, reusing BinaryExpr - Add CastExpr AST node rendering as Postgres (expr)::type syntax - Resolve cast native types from CodecRegistry via RenderCtx - Refactor adapter render functions to use RenderCtx object
1 parent 50d26af commit 48da40c

8 files changed

Lines changed: 486 additions & 119 deletions

File tree

packages/2-sql/4-lanes/relational-core/src/ast/types.ts

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ export type BinaryOp =
1414
| 'like'
1515
| 'ilike'
1616
| 'in'
17-
| 'notIn';
17+
| 'notIn'
18+
| 'add'
19+
| 'sub'
20+
| 'mul'
21+
| 'div'
22+
| 'mod';
1823

1924
export type AggregateCountFn = 'count';
2025
export type AggregateOpFn = 'sum' | 'avg' | 'min' | 'max';
@@ -52,6 +57,7 @@ export interface ExprVisitor<R> {
5257
exists(expr: ExistsExpr): R;
5358
nullCheck(expr: NullCheckExpr): R;
5459
not(expr: NotExpr): R;
60+
cast(expr: CastExpr): R;
5561
literal(expr: LiteralExpr): R;
5662
param(expr: ParamRef): R;
5763
list(expr: ListExpression): R;
@@ -839,6 +845,26 @@ export class BinaryExpr extends Expression {
839845
return new BinaryExpr('notIn', left, right);
840846
}
841847

848+
static add(left: AnyExpression, right: AnyExpression): BinaryExpr {
849+
return new BinaryExpr('add', left, right);
850+
}
851+
852+
static sub(left: AnyExpression, right: AnyExpression): BinaryExpr {
853+
return new BinaryExpr('sub', left, right);
854+
}
855+
856+
static mul(left: AnyExpression, right: AnyExpression): BinaryExpr {
857+
return new BinaryExpr('mul', left, right);
858+
}
859+
860+
static div(left: AnyExpression, right: AnyExpression): BinaryExpr {
861+
return new BinaryExpr('div', left, right);
862+
}
863+
864+
static mod(left: AnyExpression, right: AnyExpression): BinaryExpr {
865+
return new BinaryExpr('mod', left, right);
866+
}
867+
842868
override accept<R>(visitor: ExprVisitor<R>): R {
843869
return visitor.binary(this);
844870
}
@@ -1020,6 +1046,35 @@ export class NotExpr extends Expression {
10201046
}
10211047
}
10221048

1049+
export class CastExpr extends Expression {
1050+
readonly kind = 'cast' as const;
1051+
readonly expr: AnyExpression;
1052+
readonly targetCodecId: string;
1053+
1054+
constructor(expr: AnyExpression, targetCodecId: string) {
1055+
super();
1056+
this.expr = expr;
1057+
this.targetCodecId = targetCodecId;
1058+
this.freeze();
1059+
}
1060+
1061+
static of(expr: AnyExpression, targetCodecId: string): CastExpr {
1062+
return new CastExpr(expr, targetCodecId);
1063+
}
1064+
1065+
override accept<R>(visitor: ExprVisitor<R>): R {
1066+
return visitor.cast(this);
1067+
}
1068+
1069+
override rewrite(rewriter: ExpressionRewriter): AnyExpression {
1070+
return new CastExpr(this.expr.rewrite(rewriter), this.targetCodecId);
1071+
}
1072+
1073+
override fold<T>(folder: ExpressionFolder<T>): T {
1074+
return this.expr.fold(folder);
1075+
}
1076+
}
1077+
10231078
export class EqColJoinOn extends AstNode {
10241079
readonly kind = 'eq-col-join-on' as const;
10251080
readonly left: ColumnRef;
@@ -1714,7 +1769,8 @@ export type AnyExpression =
17141769
| OrExpr
17151770
| ExistsExpr
17161771
| NullCheckExpr
1717-
| NotExpr;
1772+
| NotExpr
1773+
| CastExpr;
17181774
export type AnyInsertOnConflictAction = DoNothingConflictAction | DoUpdateSetConflictAction;
17191775
export type AnyInsertValue = ColumnRef | ParamRef | DefaultValueExpr;
17201776
export type AnyOperationArg = AnyExpression | ParamRef | LiteralExpr;

packages/2-sql/4-lanes/relational-core/test/ast/kind-discriminants.test.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import {
33
AggregateExpr,
44
AndExpr,
55
BinaryExpr,
6+
CastExpr,
67
DefaultValueExpr,
78
DeleteAst,
89
DerivedTableSource,
@@ -58,6 +59,7 @@ const allKindEntries: Array<[string, { kind: string }]> = [
5859
['InsertOnConflict', InsertOnConflict.on([col('t', 'id')])],
5960
['DoNothingConflictAction', new DoNothingConflictAction()],
6061
['DoUpdateSetConflictAction', new DoUpdateSetConflictAction({ id: col('t', 'id') })],
62+
['CastExpr', CastExpr.of(col('t', 'id'), 'pg/text@1')],
6163
];
6264

6365
describe('AST kind discriminants', () => {
@@ -90,6 +92,7 @@ describe('AST kind discriminants', () => {
9092
['InsertOnConflict', 'insert-on-conflict'],
9193
['DoNothingConflictAction', 'do-nothing'],
9294
['DoUpdateSetConflictAction', 'do-update-set'],
95+
['CastExpr', 'cast'],
9396
])('%s has kind "%s"', (className, expectedKind) => {
9497
const entry = allKindEntries.find(([name]) => name === className);
9598
expect(entry).toBeDefined();

packages/2-sql/4-lanes/sql-builder/src/expression.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,32 @@ export type BuiltinFunctions<CT extends Record<string, { readonly input: unknown
108108
and: (...ands: ExpressionOrValue<BooleanCodecType, CT>[]) => Expression<BooleanCodecType>;
109109
or: (...ors: ExpressionOrValue<BooleanCodecType, CT>[]) => Expression<BooleanCodecType>;
110110

111+
add: <T extends ScopeField>(
112+
a: ExpressionOrValue<T, CT>,
113+
b: ExpressionOrValue<T, CT>,
114+
) => Expression<T>;
115+
sub: <T extends ScopeField>(
116+
a: ExpressionOrValue<T, CT>,
117+
b: ExpressionOrValue<T, CT>,
118+
) => Expression<T>;
119+
mul: <T extends ScopeField>(
120+
a: ExpressionOrValue<T, CT>,
121+
b: ExpressionOrValue<T, CT>,
122+
) => Expression<T>;
123+
div: <T extends ScopeField>(
124+
a: ExpressionOrValue<T, CT>,
125+
b: ExpressionOrValue<T, CT>,
126+
) => Expression<T>;
127+
mod: <T extends ScopeField>(
128+
a: ExpressionOrValue<T, CT>,
129+
b: ExpressionOrValue<T, CT>,
130+
) => Expression<T>;
131+
132+
cast: <TargetCodecId extends string, Nullable extends boolean>(
133+
expr: Expression<ScopeField>,
134+
target: { codecId: TargetCodecId; nullable: Nullable },
135+
) => Expression<{ codecId: TargetCodecId; nullable: Nullable }>;
136+
111137
exists: (subquery: Subquery<Record<string, ScopeField>>) => Expression<BooleanCodecType>;
112138
notExists: (subquery: Subquery<Record<string, ScopeField>>) => Expression<BooleanCodecType>;
113139

packages/2-sql/4-lanes/sql-builder/src/runtime/functions.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
type AnyExpression as AstExpression,
55
BinaryExpr,
66
type BinaryOp,
7+
CastExpr,
78
ExistsExpr,
89
ListExpression,
910
LiteralExpr,
@@ -76,6 +77,24 @@ function inOrNotIn(
7677
return boolExpr(binaryFn(left, SubqueryExpr.of(valuesOrSubquery.buildAst())));
7778
}
7879

80+
function arithmetic<T extends ScopeField>(
81+
a: ExpressionOrValue<T, CodecTypes>,
82+
b: ExpressionOrValue<T, CodecTypes>,
83+
op: BinaryOp,
84+
): ExpressionImpl<T> {
85+
const field = (
86+
a instanceof ExpressionImpl
87+
? a.field
88+
: b instanceof ExpressionImpl
89+
? b.field
90+
: { codecId: 'unknown', nullable: false }
91+
) as T;
92+
return new ExpressionImpl(
93+
new BinaryExpr(op, resolve(a as ExprOrVal), resolve(b as ExprOrVal)),
94+
field,
95+
);
96+
}
97+
7998
function numericAgg(
8099
fn: 'sum' | 'avg' | 'min' | 'max',
81100
expr: Expression<ScopeField>,
@@ -96,6 +115,32 @@ function createBuiltinFunctions() {
96115
lte: (a: ExprOrVal, b: ExprOrVal) => comparison(a, b, 'lte'),
97116
and: (...exprs: ExprOrVal<BooleanCodecType>[]) => boolExpr(AndExpr.of(exprs.map(resolveToAst))),
98117
or: (...exprs: ExprOrVal<BooleanCodecType>[]) => boolExpr(OrExpr.of(exprs.map(resolveToAst))),
118+
add: <T extends ScopeField>(
119+
a: ExpressionOrValue<T, CodecTypes>,
120+
b: ExpressionOrValue<T, CodecTypes>,
121+
) => arithmetic(a, b, 'add'),
122+
sub: <T extends ScopeField>(
123+
a: ExpressionOrValue<T, CodecTypes>,
124+
b: ExpressionOrValue<T, CodecTypes>,
125+
) => arithmetic(a, b, 'sub'),
126+
mul: <T extends ScopeField>(
127+
a: ExpressionOrValue<T, CodecTypes>,
128+
b: ExpressionOrValue<T, CodecTypes>,
129+
) => arithmetic(a, b, 'mul'),
130+
div: <T extends ScopeField>(
131+
a: ExpressionOrValue<T, CodecTypes>,
132+
b: ExpressionOrValue<T, CodecTypes>,
133+
) => arithmetic(a, b, 'div'),
134+
mod: <T extends ScopeField>(
135+
a: ExpressionOrValue<T, CodecTypes>,
136+
b: ExpressionOrValue<T, CodecTypes>,
137+
) => arithmetic(a, b, 'mod'),
138+
cast: <TargetCodecId extends string, Nullable extends boolean>(
139+
expr: Expression<ScopeField>,
140+
target: { codecId: TargetCodecId; nullable: Nullable },
141+
) => {
142+
return new ExpressionImpl(CastExpr.of(expr.buildAst(), target.codecId), target);
143+
},
99144
exists: (subquery: Subquery<Record<string, ScopeField>>) =>
100145
boolExpr(ExistsExpr.exists(subquery.buildAst())),
101146
notExists: (subquery: Subquery<Record<string, ScopeField>>) =>
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { setupIntegrationTest } from './setup';
3+
4+
describe('integration: arithmetic operations', () => {
5+
const { db, runtime } = setupIntegrationTest();
6+
7+
it('add computes column + literal', async () => {
8+
const rows = await runtime().execute(
9+
db()
10+
.posts.select('id')
11+
.select('viewsPlus', (f, fns) => fns.add(f.views, 10))
12+
.where((f, fns) => fns.eq(f.id, 1))
13+
.build(),
14+
);
15+
expect(rows).toHaveLength(1);
16+
expect(rows[0]!.viewsPlus).toBe(110);
17+
});
18+
19+
it('sub computes column - literal', async () => {
20+
const rows = await runtime().execute(
21+
db()
22+
.posts.select('id')
23+
.select('viewsMinus', (f, fns) => fns.sub(f.views, 10))
24+
.where((f, fns) => fns.eq(f.id, 1))
25+
.build(),
26+
);
27+
expect(rows[0]!.viewsMinus).toBe(90);
28+
});
29+
30+
it('mul computes column * literal', async () => {
31+
const rows = await runtime().execute(
32+
db()
33+
.posts.select('id')
34+
.select('viewsDouble', (f, fns) => fns.mul(f.views, 2))
35+
.where((f, fns) => fns.eq(f.id, 1))
36+
.build(),
37+
);
38+
expect(rows[0]!.viewsDouble).toBe(200);
39+
});
40+
41+
it('div computes column / literal', async () => {
42+
const rows = await runtime().execute(
43+
db()
44+
.posts.select('id')
45+
.select('viewsHalf', (f, fns) => fns.div(f.views, 2))
46+
.where((f, fns) => fns.eq(f.id, 1))
47+
.build(),
48+
);
49+
expect(rows[0]!.viewsHalf).toBe(50);
50+
});
51+
52+
it('mod computes column % literal', async () => {
53+
const rows = await runtime().execute(
54+
db()
55+
.posts.select('id')
56+
.select('viewsMod', (f, fns) => fns.mod(f.views, 3))
57+
.where((f, fns) => fns.eq(f.id, 1))
58+
.build(),
59+
);
60+
expect(rows[0]!.viewsMod).toBe(1);
61+
});
62+
63+
it('arithmetic in WHERE clause filters correctly', async () => {
64+
const rows = await runtime().execute(
65+
db()
66+
.posts.select('id', 'views')
67+
.where((f, fns) => fns.gt(fns.add(f.views, 50), 150))
68+
.build(),
69+
);
70+
expect(rows.every((r) => r.views + 50 > 150)).toBe(true);
71+
});
72+
73+
it('nested arithmetic works', async () => {
74+
const rows = await runtime().execute(
75+
db()
76+
.posts.select('id')
77+
.select('computed', (f, fns) => fns.add(fns.mul(f.views, 2), 1))
78+
.where((f, fns) => fns.eq(f.id, 1))
79+
.build(),
80+
);
81+
expect(rows[0]!.computed).toBe(201);
82+
});
83+
});
84+
85+
describe('integration: type cast', () => {
86+
const { db, runtime } = setupIntegrationTest();
87+
88+
it('cast int to text', async () => {
89+
const rows = await runtime().execute(
90+
db()
91+
.posts.select('id')
92+
.select('idText', (f, fns) => fns.cast(f.id, { codecId: 'pg/text@1', nullable: false }))
93+
.where((f, fns) => fns.eq(f.id, 1))
94+
.build(),
95+
);
96+
expect(rows[0]!.idText).toBe('1');
97+
});
98+
99+
it('cast used in order by', async () => {
100+
const rows = await runtime().execute(
101+
db()
102+
.posts.select('id')
103+
.select('idText', (f, fns) => fns.cast(f.id, { codecId: 'pg/text@1', nullable: false }))
104+
.orderBy((f) => f.idText)
105+
.build(),
106+
);
107+
expect(rows).toHaveLength(4);
108+
expect(rows.map((r) => r.idText)).toEqual(['1', '2', '3', '4']);
109+
});
110+
});

0 commit comments

Comments
 (0)