diff --git a/cel2sql.go b/cel2sql.go index fe40528..4f90bcd 100644 --- a/cel2sql.go +++ b/cel2sql.go @@ -37,7 +37,6 @@ func (con *converter) visit(expr *exprpb.Expr) error { switch expr.ExprKind.(type) { case *exprpb.Expr_CallExpr: return con.visitCall(expr) - // TODO: Comprehensions are currently not supported. case *exprpb.Expr_ComprehensionExpr: return con.visitComprehension(expr) case *exprpb.Expr_ConstExpr: @@ -279,9 +278,74 @@ func (con *converter) visitCallUnary(expr *exprpb.Expr) error { } func (con *converter) visitComprehension(expr *exprpb.Expr) error { - // TODO: introduce a macro expansion map between the top-level comprehension id and the - // function call that the macro replaces. - return fmt.Errorf("unimplemented : %v", expr) + c := expr.GetComprehensionExpr() + + loopStep := c.GetLoopStep().GetCallExpr() + if loopStep == nil { + return fmt.Errorf("unsupported macro") + } + loopFunction := loopStep.GetFunction() + loopArgs := loopStep.GetArgs() + + con.str.WriteString("SELECT ") + switch loopFunction { + case operators.LogicalAnd: + con.str.WriteString("COUNT(*) = 0") + case operators.LogicalOr: + con.str.WriteString("COUNT(*) > 0") + case operators.Add: + // map + con.str.WriteString("ARRAY_AGG(") + if err := con.visit(loopArgs[1].GetListExpr().GetElements()[0]); err != nil { + return err + } + con.str.WriteString(")") + case operators.Conditional: + switch loopArgs[1].GetCallExpr().GetArgs()[1].ExprKind.(type) { + case *exprpb.Expr_ConstExpr: + // exists_one + con.str.WriteString("COUNT(*) = 1") + case *exprpb.Expr_ListExpr: + // filter + con.str.WriteString("ARRAY_AGG(") + if err := con.visit(loopArgs[1].GetCallExpr().GetArgs()[1].GetListExpr().GetElements()[0]); err != nil { + return err + } + con.str.WriteString(")") + default: + return fmt.Errorf("unsupported macro") + } + default: + return fmt.Errorf("unsupported macro") + } + con.str.WriteString(" FROM UNNEST(") + if err := con.visit(c.GetIterRange()); err != nil { + return err + } + con.str.WriteString(") AS `") + con.str.WriteString(c.GetIterVar()) + con.str.WriteString("`") + switch loopFunction { + case operators.LogicalAnd, operators.LogicalOr, operators.Conditional: + con.str.WriteString(" WHERE ") + } + switch loopFunction { + case operators.LogicalAnd: + con.str.WriteString("NOT (") + if err := con.visit(loopArgs[1]); err != nil { + return err + } + con.str.WriteString(")") + case operators.LogicalOr: + if err := con.visit(loopArgs[1]); err != nil { + return err + } + case operators.Conditional: + if err := con.visit(loopArgs[0]); err != nil { + return err + } + } + return nil } func (con *converter) visitConst(expr *exprpb.Expr) error { diff --git a/cel2sql_test.go b/cel2sql_test.go index 66bc310..7197535 100644 --- a/cel2sql_test.go +++ b/cel2sql_test.go @@ -252,6 +252,36 @@ func TestConvert(t *testing.T) { want: "\"test\" IN UNNEST(`trigram`.`cell`[OFFSET(0)].`value`)", wantErr: false, }, + { + name: "macro_all", + args: args{source: `trigram.cell.all(cell, cell.page_count > 0)`}, + want: "SELECT COUNT(*) = 0 FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE NOT (`cell`.`page_count` > 0)", + wantErr: false, + }, + { + name: "macro_exists", + args: args{source: `trigram.cell.exists(cell, cell.page_count > 0)`}, + want: "SELECT COUNT(*) > 0 FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE `cell`.`page_count` > 0", + wantErr: false, + }, + { + name: "macro_exists_one", + args: args{source: `trigram.cell.exists_one(cell, cell.page_count > 0)`}, + want: "SELECT COUNT(*) = 1 FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE `cell`.`page_count` > 0", + wantErr: false, + }, + { + name: "macro_filter", + args: args{source: `trigram.cell.filter(cell, cell.page_count > 0)`}, + want: "SELECT ARRAY_AGG(`cell`) FROM UNNEST(`trigram`.`cell`) AS `cell` WHERE `cell`.`page_count` > 0", + wantErr: false, + }, + { + name: "macro_map", + args: args{source: `trigram.cell.map(cell, cell.page_count + 1)`}, + want: "SELECT ARRAY_AGG(`cell`.`page_count` + 1) FROM UNNEST(`trigram`.`cell`) AS `cell`", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {