Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions .github/workflows/stubs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,17 @@ jobs:
with:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install mypy
run: |
python -m pip install mypy
- name: Update pip
run: python -m pip install --upgrade pip

- name: Install PySCIPOpt
run: |
export CFLAGS="-O0 -ggdb -Wall -Wextra -Werror -Wno-error=deprecated-declarations" # Debug mode. More warnings. Warnings as errors, but allow deprecated declarations.
python -m pip install . -v 2>&1 | tee build.log

- name: Install typing dependencies
run: |
python -m pip install -r requirements/pylock.types.toml

- name: Run MyPy
run: python -m mypy --package pyscipopt
Expand All @@ -52,6 +55,15 @@ jobs:
id: stubtest
run: stubs/test.sh

- name: Check baseline test files are up to date
run: |
./stubs/baseline.sh
# we need to ignore the .deb file we download above
if [[ -n $(git status --porcelain --untracked-files=no) ]]; then
echo "Baseline test files are out of date, run: ./stubs/baseline.sh on Python ${{ env.PYTHON_VERSION }} to update"
exit 1
fi

- name: Stub regeneration hint
if: failure() && steps.stubtest.outcome == 'failure'
run: |
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,6 @@ model.lp
# VSCode
.vscode/
.devcontainer/

# generated test files
tests/@types/expr.py
596 changes: 596 additions & 0 deletions requirements/pylock.types.toml

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions requirements/types.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
mypy
320 changes: 320 additions & 0 deletions scripts/generate_expr_type_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
"""
This script generates test cases for expression arithmetic type annotations.

It evaluates the runtime output of all combinations of arithmetic operations between different types and generates test cases
that check whether the static type annotations match the actual runtime types of the results.
"""

import argparse
import logging

import itertools
import operator
from pathlib import Path

import pyscipopt


logger = logging.getLogger(__name__)
INDENT = " " * 4


# Initial lines at the start of the generated test file.
GLOBAL_STATEMENTS = [
"# @generated by scripts/generate_expr_type_tests.py - do not edit manually",
"",
"import decimal",
"import random",
"",
"import numpy",
"from typing_extensions import assert_type",
"",
"import pyscipopt.scip",
"",
"",
"model = pyscipopt.scip.Model()",
]

# Expressions to test, mapped from a name to the expression that will be evaluated at runtime to get the value for that name.
# These should cover all interesting types for arithmetic operations.
# Order matters since later expressions can refer to previously defined names.
EXPRESSIONS = {
# Variables
"var": "model.addVar()",
"mvar1d": "model.addMatrixVar(3)",
"mvar2d": "model.addMatrixVar((3, 3))",
"term": "pyscipopt.scip.Term(var)",
# Expressions
"constant": "pyscipopt.scip.Constant(-2.0)",
"expr": "var + 1",
"matrix_expr": "mvar2d * 2",
"sum_expr": "var + constant",
"prod_expr": "var * constant",
"pow_expr": "prod_expr**2",
"unary_expr": "abs(var)",
"var_expr": "pyscipopt.scip.VarExpr(var)",
# Constraints
"exprcons": "var <= 3",
"matrixexprcons": "mvar1d <= 3",
# Builtin numbers
"integer": "random.randint(1, 10)",
"floating_point": "random.random()",
"dec": 'decimal.Decimal("1.0")',
# NumPy arrays
"np_float": "numpy.float64(3.0)",
"array0d": "numpy.array(1)",
"array1d": "numpy.array([1, 2, 3])",
"array2d": "numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
}

# Mappings from operator symbols to their corresponding operator functions.
# No spaces are added, so spacing must be added to the operator symbols.
BINARY_OPERATORS = {
" + ": operator.add,
" - ": operator.sub,
" * ": operator.mul,
" / ": operator.truediv,
"**": operator.pow,
" < ": operator.lt,
" <= ": operator.le,
" > ": operator.gt,
" >= ": operator.ge,
" == ": operator.eq,
" != ": operator.ne,
" @ ": operator.matmul,
}

INPLACE_BINARY_OPERATORS = {
"+=": operator.iadd,
"-=": operator.isub,
"*=": operator.imul,
"/=": operator.itruediv,
"**=": operator.ipow,
"@=": operator.imatmul,
}

# Operator function and string with a formatting placeholder for the operation.
UNARY_OPERATORS = [
("+{}", operator.pos),
("-{}", operator.neg),
("abs({})", abs),
("pyscipopt.exp({})", pyscipopt.exp),
("pyscipopt.log({})", pyscipopt.log),
("pyscipopt.sqrt({})", pyscipopt.sqrt),
("pyscipopt.sin({})", pyscipopt.sin),
("pyscipopt.cos({})", pyscipopt.cos),
("{}.sum()", lambda x: x.sum()),
("{}.sum(axis=-1)", lambda x: x.sum(axis=-1)),
]


def build_runtime_values(expressions: dict[str, str]) -> dict[str, object]:
"""Evaluate the expressions and return a mapping from expression names to their runtime values.

Expressions are evaluated in order, so that later expressions can refer to previously defined members.
"""
eval_scope = {}

for statement in GLOBAL_STATEMENTS:
logger.debug(f"Executing statement: {statement}")
exec(statement, {}, eval_scope)

for name, expr in expressions.items():
logger.debug(f"Evaluating expression for {name}: {expr}")
eval_scope[name] = eval(expr, {}, eval_scope)

return eval_scope


def generate_erroring_line(
expr: str,
error: Exception,
indent: str = "",
inplace: bool = False,
) -> str:
"""Generate a line in the generated test file for an expression that produces a runtime error.

Expressions that error at runtime have a type ignore comment to indicate that they are expected to produce a type error.
"""
# Add a fake assignment to prevent "unused expression" errors
expr = f"{indent}{expr}" if inplace else f"{indent}_ = {expr}"
error_message = str(error).replace("\n", "").strip()
return f"{expr} # type: ignore # {error.__class__.__name__}: {error_message}"


def type_name(obj: object) -> str:
"""Get the fully qualified type name of an object."""
if obj.__class__.__module__ == "builtins":
return obj.__class__.__name__
return f"{obj.__class__.__module__}.{obj.__class__.__name__}"


def generate_result_expectation(expr: str, result: object, indent: str = "") -> str:
"""Generate a line in the generated test file for an expression that produces a result without error.

The result type at runtime is used in an `assert_type` call to check that the static type annotations match the actual runtime type of the result.
"""
runtime_type_name = type_name(result)
return f"{indent}assert_type({expr}, {runtime_type_name})"


def no_pyscipopt_objs(*objs: object) -> bool:
"""Check is there are no objects from the `pyscipopt` module in the given objects.
If so, we can skip generating test cases for the expression since it won't involve any `pyscipopt` types that we care about testing.
"""
return not any(obj.__class__.__module__.startswith("pyscipopt") for obj in objs)


def generate_test_cases():
"""Build the test file content by evaluating the expressions and generating test cases for their results.

There are 4 phases:
1. Evaluate all the expressions in `EXPRESSIONS` and store their runtime values.
2. Generate test cases for unary operators applied to each expression in `EXPRESSIONS`.
3. Generate test cases for binary operators applied to all pairs of expressions in `EXPRESSIONS`.
4. Generate test cases for inplace binary operators applied to all pairs of expressions in `EXPRESSIONS`.
"""
runtime_values = build_runtime_values(EXPRESSIONS)

lines = [*GLOBAL_STATEMENTS, "", ""]

for name, expr in EXPRESSIONS.items():
# Define the value from the expression
lines.append(f"{name} = {expr}")
# Check it has the expected type
lines.append(generate_result_expectation(name, runtime_values[name]))

lines.extend(
[
"",
"###################",
"# Unary operators #",
"###################",
"",
]
)

for name in EXPRESSIONS:
if no_pyscipopt_objs(runtime_values[name]):
continue
lines.extend([f"# Unary operators for {name}", ""])
success_lines = []
failure_lines = []
for op_repr, op_func in UNARY_OPERATORS:
expr = op_repr.format(name)
logger.debug(f"Evaluating unary operator {expr}")
try:
result = op_func(runtime_values[name])
except Exception as e:
failure_lines.append(generate_erroring_line(expr, e))
else:
success_lines.append(generate_result_expectation(expr, result))
if success_lines:
lines.extend([*success_lines, ""])
if failure_lines:
lines.extend([*failure_lines, ""])

lines.extend(
[
"####################",
"# Binary operators #",
"####################",
"",
]
)

for left, right in itertools.product(EXPRESSIONS, repeat=2):
if no_pyscipopt_objs(runtime_values[left], runtime_values[right]):
continue
lines.extend([f"# Binary operators for {left} and {right}", ""])
success_lines = []
failure_lines = []
for op_symbol, op_func in BINARY_OPERATORS.items():
logger.debug(
f"Evaluating binary operator {op_symbol} for {left} and {right}"
)
expr = f"{left}{op_symbol}{right}"
try:
result = op_func(runtime_values[left], runtime_values[right])
except Exception as e:
failure_lines.append(generate_erroring_line(expr, e))
else:
success_lines.append(generate_result_expectation(expr, result))
if success_lines:
lines.extend([*success_lines, ""])
if failure_lines:
lines.extend([*failure_lines, ""])

lines.extend(
[
"#####################",
"# Inplace operators #",
"#####################",
]
)

for left, right in itertools.product(EXPRESSIONS, repeat=2):
if no_pyscipopt_objs(runtime_values[left], runtime_values[right]):
continue
lines.extend(["", f"# Inplace operators for {left} and {right}", ""])
for op_symbol, op_func in INPLACE_BINARY_OPERATORS.items():
logger.debug(
f"Evaluating inplace binary operator {op_symbol} for {left} and {right}"
)

# For inplace tests, the target gets modified and can change type.
# To avoid influencing other tests, we wrap each case in a function
# and create a fresh target for the test.
# The function simply calls the inplace operator on the target (left)
# and then checks what type it has after the operation.
target_name = f"{left}_{op_func.__name__}_{right}"
stmt = f"{target_name} {op_symbol} {right}"
lines.extend(
[
"",
f"def test_inplace_{target_name}() -> None:",
f"{INDENT}{target_name} = {EXPRESSIONS[left]}",
]
)

function_scope_locals = runtime_values.copy()
# 1. create the temporary target
exec(f"{target_name} = {EXPRESSIONS[left]}", {}, function_scope_locals)
try:
# 2. apply the inplace operator
exec(stmt, {}, function_scope_locals)
except Exception as e:
lines.append(
generate_erroring_line(stmt, e, indent=INDENT, inplace=True)
)
else:
# 3. fetch the resulting type
new_type = function_scope_locals[target_name]
lines.append(f"{INDENT}{stmt}")
lines.append(
generate_result_expectation(target_name, new_type, indent=INDENT)
)
lines.append("")

return "\n".join(lines)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output",
"-o",
type=Path,
default=Path(__file__).parent.parent / "tests" / "@types" / "expr.py",
)
parser.add_argument("-v", "--verbose", action="store_true", default=0)
args = parser.parse_args()

logging.basicConfig()
logger.setLevel(logging.DEBUG if args.verbose else logging.WARNING)

test_cases = generate_test_cases()
target = Path(args.output)
target.parent.mkdir(parents=True, exist_ok=True)
with target.open("w") as f:
f.write(test_cases)
13 changes: 13 additions & 0 deletions stubs/baseline.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash -e

# Update baseline test files

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

python "$REPO_ROOT"/scripts/generate_expr_type_tests.py

for test_file in "$REPO_ROOT"/tests/@types/*.py; do
echo "Updating mypy baseline for $test_file"
output_file="${test_file%.*}.mypy.out"
python -m mypy "$test_file" --warn-unused-ignores | grep "error:" > "$output_file"
done
Loading
Loading