Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions cel2sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ func (con *converter) visit(expr *exprpb.Expr) error {
switch expr.ExprKind.(type) {
case *exprpb.Expr_CallExpr:
return con.visitCall(expr)
// TODO: Comprehensions are currently not supported.
case *exprpb.Expr_ComprehensionExpr:
return con.visitComprehension(expr)
case *exprpb.Expr_ConstExpr:
Expand Down Expand Up @@ -279,9 +278,74 @@ func (con *converter) visitCallUnary(expr *exprpb.Expr) error {
}

func (con *converter) visitComprehension(expr *exprpb.Expr) error {
// TODO: introduce a macro expansion map between the top-level comprehension id and the
// function call that the macro replaces.
return fmt.Errorf("unimplemented : %v", expr)
c := expr.GetComprehensionExpr()

loopStep := c.GetLoopStep().GetCallExpr()
if loopStep == nil {
return fmt.Errorf("unsupported macro")
}
loopFunction := loopStep.GetFunction()
loopArgs := loopStep.GetArgs()

con.str.WriteString("SELECT ")
switch loopFunction {
case operators.LogicalAnd:
con.str.WriteString("COUNT(*) = 0")
case operators.LogicalOr:
con.str.WriteString("COUNT(*) > 0")
case operators.Add:
// map
con.str.WriteString("ARRAY_AGG(")
if err := con.visit(loopArgs[1].GetListExpr().GetElements()[0]); err != nil {
return err
}
con.str.WriteString(")")
case operators.Conditional:
switch loopArgs[1].GetCallExpr().GetArgs()[1].ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
// exists_one
con.str.WriteString("COUNT(*) = 1")
case *exprpb.Expr_ListExpr:
// filter
con.str.WriteString("ARRAY_AGG(")
if err := con.visit(loopArgs[1].GetCallExpr().GetArgs()[1].GetListExpr().GetElements()[0]); err != nil {
return err
}
con.str.WriteString(")")
default:
return fmt.Errorf("unsupported macro")
}
default:
return fmt.Errorf("unsupported macro")
}
con.str.WriteString(" FROM UNNEST(")
if err := con.visit(c.GetIterRange()); err != nil {
return err
}
con.str.WriteString(") AS `")
con.str.WriteString(c.GetIterVar())
con.str.WriteString("`")
switch loopFunction {
case operators.LogicalAnd, operators.LogicalOr, operators.Conditional:
con.str.WriteString(" WHERE ")
}
switch loopFunction {
case operators.LogicalAnd:
con.str.WriteString("NOT (")
if err := con.visit(loopArgs[1]); err != nil {
return err
}
con.str.WriteString(")")
case operators.LogicalOr:
if err := con.visit(loopArgs[1]); err != nil {
return err
}
case operators.Conditional:
if err := con.visit(loopArgs[0]); err != nil {
return err
}
}
return nil
}

func (con *converter) visitConst(expr *exprpb.Expr) error {
Expand Down
30 changes: 30 additions & 0 deletions cel2sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,36 @@ func TestConvert(t *testing.T) {
want: "\"test\" IN UNNEST(`trigram`.`cell`[OFFSET(0)].`value`)",
wantErr: false,
},
{
name: "macro_all",
args: args{source: `trigram.cell.all(cell, cell.page_count > 0)`},
want: "SELECT COUNT(*) = 0 FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE NOT (`cell`.`page_count` > 0)",
wantErr: false,
},
{
name: "macro_exists",
args: args{source: `trigram.cell.exists(cell, cell.page_count > 0)`},
want: "SELECT COUNT(*) > 0 FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE `cell`.`page_count` > 0",
wantErr: false,
},
{
name: "macro_exists_one",
args: args{source: `trigram.cell.exists_one(cell, cell.page_count > 0)`},
want: "SELECT COUNT(*) = 1 FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE `cell`.`page_count` > 0",
wantErr: false,
},
{
name: "macro_filter",
args: args{source: `trigram.cell.filter(cell, cell.page_count > 0)`},
want: "SELECT ARRAY_AGG(`cell`) FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE `cell`.`page_count` > 0",
wantErr: false,
},
{
name: "macro_map",
args: args{source: `trigram.cell.map(cell, cell.page_count + 1)`},
want: "SELECT ARRAY_AGG(`cell`.`page_count` + 1) FROM UNNEST(`trigram`.`cell`) AS `cell`",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down