From da0bd5d7c9447bf15fb69853a9c3e2c0b269b5a9 Mon Sep 17 00:00:00 2001 From: Yifeng Lu Date: Tue, 2 Sep 2025 11:46:53 -0700 Subject: [PATCH] Improve autofix by adding schema as input for LLM. PiperOrigin-RevId: 802232318 --- langfun/core/coding/python/correction.py | 23 +++++++++++- langfun/core/coding/python/correction_test.py | 37 +++++++++++++++++++ langfun/core/structured/schema.py | 6 ++- langfun/core/structured/schema_test.py | 22 +++++++++++ 4 files changed, 86 insertions(+), 2 deletions(-) diff --git a/langfun/core/coding/python/correction.py b/langfun/core/coding/python/correction.py index c2480a2a..78f3df20 100644 --- a/langfun/core/coding/python/correction.py +++ b/langfun/core/coding/python/correction.py @@ -15,12 +15,14 @@ from typing import Any import langfun.core as lf from langfun.core.coding.python import execution +from langfun.core.structured import schema as schema_lib import pyglove as pg class CodeWithError(pg.Object): """Python code with error.""" + schema_definition: str | None code: str error: str @@ -42,6 +44,7 @@ def run_with_correction( returns_code: bool = False, returns_stdout: bool = False, outputs_intermediate: bool = False, + schema: schema_lib.Schema | None = None, ) -> Any | tuple[Any, str]: """Correct code with a language model via self-play. @@ -68,6 +71,7 @@ def run_with_correction( outputs_intermediate: If True, intermediate output will be outputted as a dict, with the last line's value accessible by key '__result__'. Otherwise the value of the last line will be returned. + schema: Optional schema for the expected output. Returns: Run result if `returns_code` is set to False (default), otherwise a tuple @@ -83,6 +87,13 @@ def run_with_correction( # pytype: enable=import-error # pylint: enable=g-import-not-at-top + if schema is not None: + if isinstance(schema, type): + schema = schema_lib.Schema.from_value(schema) + schema_definition = schema.schema_str(protocol="python") + else: + schema_definition = None + if max_attempts == 0: result = _maybe_custom_validate( execution.run( @@ -126,7 +137,14 @@ def result_and_error(code: str) -> tuple[Any, str | None]: try: # Disable autofix for code correction to avoid recursion. correction = querying.query( - CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0 + CodeWithError( + schema_definition=schema_definition, + code=code, + error=error, + ), + CorrectedCode, + lm=lm, + autofix=0, ) except pg.coding.CodeError: break @@ -148,6 +166,7 @@ def result_and_error(code: str) -> tuple[Any, str | None]: def correct( code: str, error: str | None = None, + schema: schema_lib.Schema | None = None, *, global_vars: dict[str, Any] | None = None, lm: lf.LanguageModel = lf.contextual(), @@ -162,6 +181,7 @@ def correct( error: An optional initial error for `code` when it's problematic, usually caught from elsewhere when it ran. If None, code will be executed once to verify if its good and obtain a feedback error message. + schema: Optional schema for the expected output. global_vars: A dict of str to value as the global variables that could be accessed within the corrected code. lm: Language model to be used. If not specified, it will try to use the `lm` @@ -183,6 +203,7 @@ def correct( return run_with_correction( code, error=error, + schema=schema, global_vars=global_vars, lm=lm, max_attempts=max_attempts, diff --git a/langfun/core/coding/python/correction_test.py b/langfun/core/coding/python/correction_test.py index 1912f2ec..44546b09 100644 --- a/langfun/core/coding/python/correction_test.py +++ b/langfun/core/coding/python/correction_test.py @@ -45,6 +45,43 @@ def test_run_with_correction(self): ) self.assertEqual(result, 4) + def test_run_with_correction_with_schema(self): + class Flight(pg.Object): + airline: str + flight_number: str + + class Result(pg.Object): + flights: list[Flight] + + result = correction.run_with_correction( + inspect.cleandoc(""" + Result( + flights=[ + Flight(airline='DELTA', flight_number='DL123'), + Flight(airline='UNITED', flight_number='UA456'), + ] + ) + """), + schema=Result, + global_vars=dict(Result=Result, Flight=Flight), + lm=fake.StaticSequence([ + inspect.cleandoc(""" + CorrectedCode( + corrected_code='Result(flights=[Flight(airline="DELTA", flight_number="DL123"), Flight(airline="UNITED", flight_number="UA456")])', + ) + """), + ]), + ) + self.assertEqual( + result, + Result( + flights=[ + Flight(airline='DELTA', flight_number='DL123'), + Flight(airline='UNITED', flight_number='UA456'), + ] + ), + ) + def test_run_with_correction_upon_custom_validation(self): class Foo(pg.Object): diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index 10eefb70..69bf075e 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -22,7 +22,6 @@ import typing from typing import Any, Literal, Sequence, Type, Union import langfun.core as lf -from langfun.core.coding.python import correction import pyglove as pg @@ -747,6 +746,7 @@ def parse( global_vars.update({d.__name__: d for d in dependencies}) return structure_from_python( text, + schema=schema, global_vars=global_vars, autofix=autofix, autofix_lm=autofix_lm, @@ -757,6 +757,7 @@ def parse( def structure_from_python( code: str, *, + schema: Schema | None = None, global_vars: dict[str, Any] | None = None, permission: pg.coding.CodePermission = ( pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL @@ -765,6 +766,8 @@ def structure_from_python( autofix_lm: lf.LanguageModel = lf.contextual(), ) -> Any: """Evaluates structure from Python code with access to symbols.""" + from langfun.core.coding.python import correction # pylint: disable=g-import-not-at-top # pytype: disable=import-error + global_vars = global_vars or {} global_vars.update({ 'pg': pg, @@ -787,6 +790,7 @@ def structure_from_python( max_attempts=autofix, lm=autofix_lm, permission=permission, + schema=schema, ) diff --git a/langfun/core/structured/schema_test.py b/langfun/core/structured/schema_test.py index e6d1d7bd..f8fd5b26 100644 --- a/langfun/core/structured/schema_test.py +++ b/langfun/core/structured/schema_test.py @@ -811,6 +811,28 @@ class A(pg.Object): A([Foo(1), Foo(2)], y='bar'), ) + def test_parse_with_correction_with_schema(self): + class Flight(pg.Object): + airline: str + flight_number: str + + class Result(pg.Object): + flights: list[Flight] + + self.assertEqual( + schema_lib.ValuePythonRepr().parse( + "Result(flights=[Flight(airline='DELTA', flight_number='DL123')])", + schema_lib.Schema(Result), + autofix=1, + autofix_lm=fake.StaticResponse(inspect.cleandoc(""" + CorrectedCode( + corrected_code="Result(flights=[Flight(airline='DELTA', flight_number='DL123')])", + ) + """)), + ), + Result(flights=[Flight(airline='DELTA', flight_number='DL123')]), + ) + def test_parse_class_def(self): self.assertTrue( inspect.isclass(