-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpatch_parser.py
More file actions
131 lines (108 loc) · 4.95 KB
/
patch_parser.py
File metadata and controls
131 lines (108 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import re
import logging
from typing import Dict, List, TypedDict
# Configure the root logger at import time so that the INFO-level messages
# emitted by this module are shown when the application has not set up
# logging itself (basicConfig does nothing if the root logger already has
# handlers).
logging.basicConfig(level=logging.INFO)
class Patch(TypedDict):
    """A single-file unified diff extracted from an LLM response."""

    # Path of the patched file, as captured from the "a/" side of the
    # "diff --git" header.
    file_path: str
    # Text of the unified diff for that file.
    diff: str
class UnitTest(TypedDict):
file_path: str
code: str
def parse_llm_output(text: str) -> Dict[str, List[Patch] | List[UnitTest]]:
    """
    Split an LLM-generated response into its code patches and unit tests.

    Args:
        text: The LLM-generated text.

    Returns:
        A dictionary with two keys:
        - "patches": the list produced by ``_extract_patches``; each entry
          holds the "file_path" and the "diff" text of one unified diff.
        - "unit_tests": the list produced by ``_extract_unit_tests``; each
          entry holds the "file_path" and the "code" of one unit test.
        An INFO message is logged for each list that comes back empty.
    """
    # Patches are extracted before unit tests, so log output keeps that order.
    result: Dict[str, List[Patch] | List[UnitTest]] = {
        "patches": _extract_patches(text),
        "unit_tests": _extract_unit_tests(text),
    }

    empty_notices = (
        ("patches", "No patches found in the response."),
        ("unit_tests", "No unit tests found in the response."),
    )
    for key, notice in empty_notices:
        if not result[key]:
            logging.info(notice)

    return result
def _validate_patch(diff_text: str) -> bool:
    """
    Heuristically reject patches whose added lines may raise a TypeError.

    Only lines the patch adds (lines starting with "+", excluding the
    "+++ " file header) are inspected. Removed lines, context lines and
    headers never cause a rejection, so a patch that deletes a faulty
    expression is not itself flagged.

    An added line is considered suspicious when it contains ``name + "text"``
    or ``"text" + name``, unless the same line also contains a
    literal-to-literal concatenation (``"a" + "b"``).

    This is a simplified check. A more robust solution might involve a static
    analysis tool.

    Args:
        diff_text: The unified diff text to inspect.

    Returns:
        False if a suspicious added line is found (a warning is logged),
        True otherwise, including for empty input.
    """
    # A word next to a quoted literal across "+", in either order.
    mixed_operands = re.compile(
        r"\w+\s*\+\s*['\"].*['\"]|['\"].*['\"]\s*\+\s*\w+"
    )
    # Two quoted literals joined by "+": plain string concatenation.
    literal_concat = re.compile(r"['\"].*['\"]\s*\+\s*['\"].*['\"]")

    for line in diff_text.splitlines():
        # Skip everything that is not new content introduced by the patch.
        if not line.startswith("+") or line.startswith("+++ "):
            continue
        added = line[1:]
        if mixed_operands.search(added) and not literal_concat.search(added):
            logging.warning(f"Potential TypeError detected in patch:\n{diff_text}")
            return False
    return True
def _extract_patches(text: str) -> List[Patch]:
    """
    Extract unified diff patches from the text.

    A patch is recognised as a "diff --git a/<path> b/<path>" line, an
    optional "index ..." line, the "--- a/..." and "+++ b/..." file headers,
    and at least one "@@ ... @@" hunk header. The patch runs until the next
    "diff --git" line or the end of the text.

    Args:
        text: The LLM-generated text.

    Returns:
        One entry per patch that passes ``_validate_patch``, holding the
        "file_path" (the "a/" path of the "diff --git" line) and the rebuilt
        "diff" text. Patches that fail validation are logged and skipped.
    """
    diff_pattern = re.compile(
        r"diff --git a/(.+) b/(.+)\n"
        # Git emits an "index <old>..<new> <mode>" line here; allow it.
        r"(?:index .*\n)?"
        r"--- a/.*\n"
        r"\+\+\+ b/.*\n"
        # Capture the hunk header together with the hunk body. Without the
        # header's line ranges the rebuilt diff is not a valid unified diff.
        # Text after the closing "@@" (a section heading) is allowed.
        r"(@@ .* @@.*\n[\s\S]*?)(?=\ndiff --git|\Z)"
    )

    patches: List[Patch] = []
    for match in diff_pattern.finditer(text):
        file_path, new_path, hunks = match.groups()
        # Rebuild the diff with normalised file headers and the hunks as found.
        diff_text = (
            f"diff --git a/{file_path} b/{new_path}\n"
            f"--- a/{file_path}\n"
            f"+++ b/{new_path}\n"
            f"{hunks}"
        )
        if _validate_patch(diff_text):
            patches.append({"file_path": file_path, "diff": diff_text})
            logging.info(f"Found and validated patch for file: {file_path}")
        else:
            logging.warning(f"Invalid patch detected for file {file_path}. Skipping.")
    return patches
def _extract_file_path(text: str, match_start: int) -> str | None:
"""Extracts the file path for a given match."""
preceding_text = text[:match_start]
# 1. Look for diff headers
diff_header_match = re.findall(r"diff --git a/(\S+) b/\S+", preceding_text)
if diff_header_match:
return diff_header_match[-1]
# 2. Look for --- a/ or +++ b/
plus_minus_header_match = re.findall(r"--- a/(\S+)", preceding_text)
if plus_minus_header_match:
return plus_minus_header_match[-1]
plus_header_match = re.findall(r"\+\+\+ b/(\S+)", preceding_text)
if plus_header_match:
return plus_header_match[-1]
# 3. Fallback to comment lines or markers
marker_match = re.findall(r"(?:#|File:|Path:)\s*([\w/\\-]+\.py)", preceding_text)
if marker_match:
path = marker_match[-1]
return path.replace("\\", "/").strip()
logging.warning("Could not determine file path for a test block.")
return None
def _extract_unit_tests(text: str) -> List[UnitTest]:
    """
    Extract Python unit tests from the text.

    Two passes are made:
    1. Fenced ```python blocks that define at least one ``test_*`` function
       not seen before are recorded once, with the whole block as "code".
    2. Top-level ``def test_*`` definitions anywhere in the text that were
       not already covered are recorded individually; each one runs up to
       the next blank line or the end of the text.

    Every test name found in a recorded fenced block is marked as processed.
    This keeps the second pass from recording the later tests of that block
    again as inline tests.

    Args:
        text: The LLM-generated text.

    Returns:
        A list of entries holding the "file_path" (None if it cannot be
        determined) and the "code" of each unit test.
    """
    unit_tests: List[UnitTest] = []
    processed_tests: set[str] = set()

    # Pass 1: Python code fences.
    fenced_code_pattern = re.compile(r"```python\n(.*?)\n```", re.DOTALL)
    for match in fenced_code_pattern.finditer(text):
        code_block = match.group(1)
        test_funcs = re.findall(r"def (test_[A-Za-z0-9_]+)", code_block)
        # Skip blocks with no tests, or whose tests were all seen already.
        if all(func in processed_tests for func in test_funcs):
            continue
        file_path = _extract_file_path(text, match.start())
        unit_tests.append({"file_path": file_path, "code": code_block})
        processed_tests.update(test_funcs)
        logging.info(f"Found fenced unit test for file: {file_path}")

    # Pass 2: inline test definitions that start at the beginning of a line.
    inline_test_pattern = re.compile(
        r"^(def (test_[A-Za-z0-9_]+)\(.*?\):.*?)(?=\n\n|\Z)",
        re.DOTALL | re.MULTILINE,
    )
    for match in inline_test_pattern.finditer(text):
        test_code, test_func = match.group(1), match.group(2)
        if test_func in processed_tests:
            continue
        file_path = _extract_file_path(text, match.start())
        unit_tests.append({"file_path": file_path, "code": test_code})
        processed_tests.add(test_func)
        logging.info(f"Found inline unit test for file: {file_path}")

    return unit_tests