diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 237998c5..858c9026 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -344,6 +344,49 @@ impl Engine { v.to_json_str() } + /// Registers a custom Python function as a Rego extension. + /// + /// This allows you to define functions in Python that can be called directly + /// from your Rego policies. The Python function will be called synchronously + /// during policy evaluation. + /// + /// Arguments passed from Rego are automatically converted to their corresponding + /// Python types. The return value is converted back to a Rego value. + /// + /// * `path`: Full path to the function as it will be used in Rego. + /// * `nargs`: The number of arguments the function expects. + /// * `extension`: The Python function to execute. Must accept exactly `nargs` arguments. + /// + /// Note: When the engine is cloned, extensions share the same Python callable reference + /// rather than being deep-copied. Stateful callables will share state across clones. + pub fn add_extension(&mut self, path: String, nargs: u8, extension: Py) -> Result<()> { + Python::with_gil(|py| { + if !extension.bind(py).is_callable() { + return Err(anyhow!("extension '{}' must be callable", path)); + } + Ok(()) + })?; + + let func_ref = Arc::new(extension); + let path_clone = path.clone(); + + let extension_impl = move |args: Vec| -> Result { + Python::with_gil(|py| { + let py_args_vec: Result> = + args.into_iter().map(|arg| to(arg, py)).collect(); + let py_args = PyTuple::new(py, py_args_vec?)?; + let py_result = func_ref.call1(py, py_args).map_err(|e| { + anyhow!("extension '{}' raises Python error: {}", path_clone, e) + })?; + let rego_result = from(&py_result.into_bound(py))?; + Ok(rego_result) + }) + }; + + self.engine + .add_extension(path, nargs, Box::new(extension_impl)) + } + /// Enable code coverage /// /// * `enable`: Whether to enable coverage or not. diff --git a/bindings/python/test.py b/bindings/python/test.py index a1e20259..f1b56a72 100644 --- a/bindings/python/test.py +++ b/bindings/python/test.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - import regorus import sys @@ -163,3 +162,228 @@ def run_host_await_example(): print(vm.resume('{"tier":"gold"}')) run_host_await_example() + +def test_extension_execution(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + result := greeting(a, b) if { + a := data.a + b := data.b + } + """) + + def custom_function(arg1, arg2): + return f"{arg1}, {arg2}!" + rego.add_extension("greeting", 2, custom_function) + + rego.add_data({"a": "Hello", "b": "World"}) + result = rego.eval_rule("data.demo.result") + assert result == "Hello, World!", f"Unexpected result: {result}" + +test_extension_execution() + +def test_extension_wrong_arity(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + result := greeting(a, b) if { + a := data.a + b := data.b + } + """) + + def custom_function(arg1, arg2): + return f"{arg1}, {arg2}!" + + rego.add_extension("greeting", 3, custom_function) + rego.add_data({"a": "Hello", "b": "World"}) + + try: + rego.eval_rule("data.demo.result") + except RuntimeError as ex: + assert "error: incorrect number of parameters supplied to extension" in str(ex) + else: + assert False, "exception not thrown" + +test_extension_wrong_arity() + +def test_extension_raises_exception(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + result := greeting(a, b) if { + a := data.a + b := data.b + } + """) + + def custom_function(arg1, arg2): + raise RuntimeError("unknown error") + + rego.add_extension("greeting", 2, custom_function) + rego.add_data({"a": "Hello", "b": "World"}) + + try: + rego.eval_rule("data.demo.result") + except RuntimeError as ex: + assert "error: extension 'greeting' raises Python error: RuntimeError: unknown error" in str(ex) + else: + assert False, "exception not thrown" + +test_extension_raises_exception() + + +def test_extension_zero_arg(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + result := greeting() + """) + + def custom_function(): + return "Hello, World!" + + rego.add_extension("greeting", 0, custom_function) + rego.add_data({"a": "Hello", "b": "World"}) + + result = rego.eval_rule("data.demo.result") + assert result == "Hello, World!", f"Unexpected result: {result}" + +test_extension_zero_arg() + +def test_extension_non_callable(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + result := greeting() + """) + + try: + rego.add_extension("greeting", 0, 123) + except RuntimeError as ex: + assert "extension 'greeting' must be callable" in str(ex) + else: + assert False, "exception not thrown" + +test_extension_non_callable() + + +def test_extension_duplicate(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + result := greeting() + """) + + def custom_function1(arg1, arg2): + return f"{arg1}, {arg2}!" + def custom_function2(arg1, arg2): + return f"{arg1}, {arg2}!" + + rego.add_extension("greeting", 0, custom_function1) + + try: + rego.add_extension("greeting", 0, custom_function2) + except RuntimeError as ex: + assert "extension already added" in str(ex) + else: + assert False, "exception not thrown" + +test_extension_duplicate() + + +def test_extension_types(): + rego = regorus.Engine() + rego.add_policy("demo", + """ + package demo + + i := custom.triple(10) + f := custom.triple(2.5) + b1 := custom.negate(true) + b2 := custom.negate(false) + + a := custom.first([true, null, 1]) + b := custom.first([null, null, 1]) + c := custom.first([null, null, null]) + + object := custom.modify_object({"a": 1, "b": 2}) + list := custom.modify_list([3, 4]) + set := custom.modify_set({5, 6}) + """) + + def triple(n): + return n*3 + + def negate(b): + return not b + + def first(lst): + for i in lst: + if i is not None: + return i + return None + + def modify_object(object): + assert isinstance(object, dict) + return {k: v*2 for k, v in object.items()} + + def modify_list(lst): + assert isinstance(lst, list) + return [x*2 for x in lst] + + def modify_set(st): + assert isinstance(st, set) + return {x*2 for x in st} + + rego.add_extension("custom.triple", 1, triple) + rego.add_extension("custom.negate", 1, negate) + rego.add_extension("custom.first", 1, first) + rego.add_extension("custom.modify_object", 1, modify_object) + rego.add_extension("custom.modify_list", 1, modify_list) + rego.add_extension("custom.modify_set", 1, modify_set) + + i = rego.eval_rule("data.demo.i") + assert i == 30, f"Unexpected result for 'i': {i}" + + f = rego.eval_rule("data.demo.f") + assert f == 7.5, f"Unexpected result for 'f': {f}" + + b1 = rego.eval_rule("data.demo.b1") + assert b1 == False, f"Unexpected result for 'b1': {b1}" + + b2 = rego.eval_rule("data.demo.b2") + assert b2 == True, f"Unexpected result for 'b2': {b2}" + + a = rego.eval_rule("data.demo.a") + assert a == True, f"Unexpected result for 'a': {a}" + + b = rego.eval_rule("data.demo.b") + assert b == 1, f"Unexpected result for 'b': {b}" + + c = rego.eval_rule("data.demo.c") + assert c is None, f"Unexpected result for 'c': {c}" + + obj = rego.eval_rule("data.demo.object") + assert obj == {"a": 2, "b": 4}, f"Unexpected object: {obj}" + + lst = rego.eval_rule("data.demo.list") + assert lst == [6, 8], f"Unexpected list: {lst}" + + st = rego.eval_rule("data.demo.set") + assert st == {10, 12}, f"Unexpected set: {st}" + +test_extension_types()