Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,49 @@ impl Engine {
v.to_json_str()
}

/// Registers a custom Python function as a Rego extension.
///
/// This allows you to define functions in Python that can be called directly
/// from your Rego policies. The Python function will be called synchronously
/// during policy evaluation.
///
/// Arguments passed from Rego are automatically converted to their corresponding
/// Python types. The return value is converted back to a Rego value.
///
/// * `path`: Full path to the function as it will be used in Rego.
/// * `nargs`: The number of arguments the function expects.
/// * `extension`: The Python function to execute. Must accept exactly `nargs` arguments.
///
/// Note: When the engine is cloned, extensions share the same Python callable reference
/// rather than being deep-copied. Stateful callables will share state across clones.
pub fn add_extension(&mut self, path: String, nargs: u8, extension: Py<PyAny>) -> Result<()> {
Python::with_gil(|py| {
if !extension.bind(py).is_callable() {
return Err(anyhow!("extension '{}' must be callable", path));
}
Ok(())
})?;

let func_ref = Arc::new(extension);
let path_clone = path.clone();

let extension_impl = move |args: Vec<Value>| -> Result<Value, anyhow::Error> {
Python::with_gil(|py| {
let py_args_vec: Result<Vec<PyObject>> =
args.into_iter().map(|arg| to(arg, py)).collect();
let py_args = PyTuple::new(py, py_args_vec?)?;
let py_result = func_ref.call1(py, py_args).map_err(|e| {
anyhow!("extension '{}' raises Python error: {}", path_clone, e)
})?;
let rego_result = from(&py_result.into_bound(py))?;
Ok(rego_result)
})
};

self.engine
.add_extension(path, nargs, Box::new(extension_impl))
}

/// Enable code coverage
///
/// * `enable`: Whether to enable coverage or not.
Expand Down
226 changes: 225 additions & 1 deletion bindings/python/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import regorus
import sys

Expand Down Expand Up @@ -163,3 +162,228 @@ def run_host_await_example():
print(vm.resume('{"tier":"gold"}'))

run_host_await_example()

def test_extension_execution():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting(a, b) if {
a := data.a
b := data.b
}
""")

def custom_function(arg1, arg2):
return f"{arg1}, {arg2}!"
rego.add_extension("greeting", 2, custom_function)

rego.add_data({"a": "Hello", "b": "World"})
result = rego.eval_rule("data.demo.result")
assert result == "Hello, World!", f"Unexpected result: {result}"

test_extension_execution()

def test_extension_wrong_arity():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting(a, b) if {
a := data.a
b := data.b
}
""")

def custom_function(arg1, arg2):
return f"{arg1}, {arg2}!"

rego.add_extension("greeting", 3, custom_function)
rego.add_data({"a": "Hello", "b": "World"})

try:
rego.eval_rule("data.demo.result")
except RuntimeError as ex:
assert "error: incorrect number of parameters supplied to extension" in str(ex)
else:
assert False, "exception not thrown"

test_extension_wrong_arity()

def test_extension_raises_exception():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting(a, b) if {
a := data.a
b := data.b
}
""")

def custom_function(arg1, arg2):
raise RuntimeError("unknown error")

rego.add_extension("greeting", 2, custom_function)
rego.add_data({"a": "Hello", "b": "World"})

try:
rego.eval_rule("data.demo.result")
except RuntimeError as ex:
assert "error: extension 'greeting' raises Python error: RuntimeError: unknown error" in str(ex)
else:
assert False, "exception not thrown"

test_extension_raises_exception()


def test_extension_zero_arg():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting()
""")

def custom_function():
return "Hello, World!"

rego.add_extension("greeting", 0, custom_function)
rego.add_data({"a": "Hello", "b": "World"})

result = rego.eval_rule("data.demo.result")
assert result == "Hello, World!", f"Unexpected result: {result}"

test_extension_zero_arg()

def test_extension_non_callable():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting()
""")

try:
rego.add_extension("greeting", 0, 123)
except RuntimeError as ex:
assert "extension 'greeting' must be callable" in str(ex)
else:
assert False, "exception not thrown"

test_extension_non_callable()


def test_extension_duplicate():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

result := greeting()
""")

def custom_function1(arg1, arg2):
return f"{arg1}, {arg2}!"
def custom_function2(arg1, arg2):
return f"{arg1}, {arg2}!"

rego.add_extension("greeting", 0, custom_function1)

try:
rego.add_extension("greeting", 0, custom_function2)
except RuntimeError as ex:
assert "extension already added" in str(ex)
else:
assert False, "exception not thrown"

test_extension_duplicate()


def test_extension_types():
rego = regorus.Engine()
rego.add_policy("demo",
"""
package demo

i := custom.triple(10)
f := custom.triple(2.5)
b1 := custom.negate(true)
b2 := custom.negate(false)

a := custom.first([true, null, 1])
b := custom.first([null, null, 1])
c := custom.first([null, null, null])

object := custom.modify_object({"a": 1, "b": 2})
list := custom.modify_list([3, 4])
set := custom.modify_set({5, 6})
""")

def triple(n):
return n*3

def negate(b):
return not b

def first(lst):
for i in lst:
if i is not None:
return i
return None

def modify_object(object):
assert isinstance(object, dict)
return {k: v*2 for k, v in object.items()}

def modify_list(lst):
assert isinstance(lst, list)
return [x*2 for x in lst]

def modify_set(st):
assert isinstance(st, set)
return {x*2 for x in st}

rego.add_extension("custom.triple", 1, triple)
rego.add_extension("custom.negate", 1, negate)
rego.add_extension("custom.first", 1, first)
rego.add_extension("custom.modify_object", 1, modify_object)
rego.add_extension("custom.modify_list", 1, modify_list)
rego.add_extension("custom.modify_set", 1, modify_set)

i = rego.eval_rule("data.demo.i")
assert i == 30, f"Unexpected result for 'i': {i}"

f = rego.eval_rule("data.demo.f")
assert f == 7.5, f"Unexpected result for 'f': {f}"

b1 = rego.eval_rule("data.demo.b1")
assert b1 == False, f"Unexpected result for 'b1': {b1}"

b2 = rego.eval_rule("data.demo.b2")
assert b2 == True, f"Unexpected result for 'b2': {b2}"

a = rego.eval_rule("data.demo.a")
assert a == True, f"Unexpected result for 'a': {a}"

b = rego.eval_rule("data.demo.b")
assert b == 1, f"Unexpected result for 'b': {b}"

c = rego.eval_rule("data.demo.c")
assert c is None, f"Unexpected result for 'c': {c}"

obj = rego.eval_rule("data.demo.object")
assert obj == {"a": 2, "b": 4}, f"Unexpected object: {obj}"

lst = rego.eval_rule("data.demo.list")
assert lst == [6, 8], f"Unexpected list: {lst}"

st = rego.eval_rule("data.demo.set")
assert st == {10, 12}, f"Unexpected set: {st}"

test_extension_types()
Loading