diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index e7225831..b463ef63 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -157,6 +157,7 @@ spec: - name: buildah-temp-cache workspace: buildah-temp-cache - name: lint-and-test + timeout: 2h30m0s # Timeout for the task runAfter: - fetch-repository workspaces: diff --git a/src/vuln_analysis/functions/code_agent_graph_defs.py b/src/vuln_analysis/functions/code_agent_graph_defs.py index c1278f49..4f5777ec 100644 --- a/src/vuln_analysis/functions/code_agent_graph_defs.py +++ b/src/vuln_analysis/functions/code_agent_graph_defs.py @@ -48,6 +48,7 @@ from exploit_iq_commons.data_models.checker_status import L2BuildResult, VulnerabilityIntel from exploit_iq_commons.data_models.common import TargetPackage from vuln_analysis.functions.react_internals import CheckerThought, Observation, L1VerdictExtraction +from vuln_analysis.utils.token_utils import count_tokens from vuln_analysis.utils.rpm_checker_prompts import ( L1_VERDICT_EXTRACTION_PROMPT, VULNERABILITY_INTEL_EXTRACTION_PROMPT, @@ -729,6 +730,45 @@ def is_main_source(path: str) -> bool: return "\n".join(lines) +MAX_HUNK_LINES_FOR_INTEL = 10 +MAX_PATCH_TOKENS = 3000 +MAX_PATCH_CHUNKS = 2 + +VULNERABILITY_INTEL_MERGE_LIST_FIELDS = ( + "affected_files", + "vulnerable_functions", + "vulnerable_variables", + "vulnerable_patterns", + "fix_patterns", + "search_keywords", + "component_names", +) + + +def _format_patch_file_lines_for_intel(pf: PatchFile) -> list[str]: + """Format one patch file block for VULNERABILITY_INTEL_EXTRACTION_PROMPT.""" + lines = [f"File: {pf.target_path}"] + for hunk in pf.hunks: + if hunk.removed_lines: + lines.append(" Removed (vulnerable):") + for line in hunk.removed_lines[:MAX_HUNK_LINES_FOR_INTEL]: + lines.append(f" - {line}") + if len(hunk.removed_lines) > MAX_HUNK_LINES_FOR_INTEL: + lines.append( + f" ... (+{len(hunk.removed_lines) - MAX_HUNK_LINES_FOR_INTEL} more lines)" + ) + if hunk.added_lines: + lines.append(" Added (fix):") + for line in hunk.added_lines[:MAX_HUNK_LINES_FOR_INTEL]: + lines.append(f" + {line}") + if len(hunk.added_lines) > MAX_HUNK_LINES_FOR_INTEL: + lines.append( + f" ... (+{len(hunk.added_lines) - MAX_HUNK_LINES_FOR_INTEL} more lines)" + ) + lines.append("") + return lines + + def format_patch_data_for_intel( parsed_patch: ParsedPatch | None ) -> str: @@ -748,29 +788,142 @@ def format_patch_data_for_intel( """ if not parsed_patch: return "" - + lines = [f"Patch: {parsed_patch.patch_filename}", ""] for pf in parsed_patch.files: - lines.append(f"File: {pf.target_path}") - for hunk in pf.hunks: - if hunk.removed_lines: - lines.append(" Removed (vulnerable):") - for line in hunk.removed_lines[:10]: - lines.append(f" - {line}") - if len(hunk.removed_lines) > 10: - lines.append(f" ... (+{len(hunk.removed_lines) - 10} more lines)") - if hunk.added_lines: - lines.append(" Added (fix):") - for line in hunk.added_lines[:10]: - lines.append(f" + {line}") - if len(hunk.added_lines) > 10: - lines.append(f" ... (+{len(hunk.added_lines) - 10} more lines)") - lines.append("") - + lines.extend(_format_patch_file_lines_for_intel(pf)) + return "\n".join(lines) -def get_relevant_hunks(parsed_patch: ParsedPatch | None, grep_query: str) -> str: +def format_patch_data_chunks_for_intel( + parsed_patch: ParsedPatch | None, + max_tokens: int = MAX_PATCH_TOKENS, + max_chunks: int = MAX_PATCH_CHUNKS, +) -> list[str]: + """Split patch intel text into token-bounded chunks for LLM extraction. + + Returns a single-element list with the full formatted patch when it already + fits within max_tokens, preserving identical output to format_patch_data_for_intel(). + """ + if not parsed_patch: + return [""] + + full_text = format_patch_data_for_intel(parsed_patch) + if count_tokens(full_text) <= max_tokens: + return [full_text] + + patch_header = f"Patch: {parsed_patch.patch_filename}\n\n" + header_tokens = count_tokens(patch_header) + body_token_budget = max(max_tokens - header_tokens, 1) + + file_blocks: list[tuple[str, int]] = [] + for pf in parsed_patch.files: + block = "\n".join(_format_patch_file_lines_for_intel(pf)) + file_blocks.append((block, count_tokens(block))) + + if not file_blocks: + return [full_text] + + chunks: list[str] = [] + current_parts: list[str] = [] + current_tokens = 0 + + for file_block, block_tokens in file_blocks: + if block_tokens > body_token_budget: + if current_parts: + chunks.append(patch_header + "\n".join(current_parts)) + if len(chunks) >= max_chunks: + return chunks + current_parts = [] + current_tokens = 0 + truncated = _truncate_diff_by_tokens(file_block, body_token_budget) + chunks.append(patch_header + truncated) + if len(chunks) >= max_chunks: + return chunks + continue + + if current_tokens + block_tokens > body_token_budget and current_parts: + chunks.append(patch_header + "\n".join(current_parts)) + if len(chunks) >= max_chunks: + return chunks + current_parts = [] + current_tokens = 0 + + current_parts.append(file_block) + current_tokens += block_tokens + + if current_parts and len(chunks) < max_chunks: + chunks.append(patch_header + "\n".join(current_parts)) + + return chunks if chunks else [full_text] + + +def merge_vulnerability_intel_chunks( + chunk_intel: list[VulnerabilityIntel], +) -> VulnerabilityIntel: + """Merge structured intel extracted from multiple patch chunks.""" + if not chunk_intel: + return VulnerabilityIntel() + + if len(chunk_intel) == 1: + return chunk_intel[0] + + merged = VulnerabilityIntel() + for intel in chunk_intel: + for field_name in VULNERABILITY_INTEL_MERGE_LIST_FIELDS: + existing = getattr(merged, field_name) + new_values = [value for value in getattr(intel, field_name) if value not in existing] + setattr(merged, field_name, existing + new_values) + + if not merged.root_cause and intel.root_cause: + merged.root_cause = intel.root_cause + if not merged.vulnerability_type and intel.vulnerability_type: + merged.vulnerability_type = intel.vulnerability_type + if not merged.known_mitigations and intel.known_mitigations: + merged.known_mitigations = intel.known_mitigations + if merged.affected_bitness == "both" and intel.affected_bitness != "both": + merged.affected_bitness = intel.affected_bitness + if merged.affected_architectures is None and intel.affected_architectures is not None: + merged.affected_architectures = intel.affected_architectures + + logger.debug( + "merge_vulnerability_intel_chunks: merged %d chunks into %d affected_files, " + "%d search_keywords", + len(chunk_intel), + len(merged.affected_files), + len(merged.search_keywords), + ) + return merged + + +def _truncate_diff_by_tokens(diff_text: str, max_tokens: int) -> str: + """Truncate a diff to fit within max_tokens, preserving complete lines.""" + lines = diff_text.split('\n') + kept_lines: list[str] = [] + kept_tokens = 0 + + for line in lines: + line_tokens = count_tokens(line) + if kept_tokens + line_tokens > max_tokens: + break + kept_lines.append(line) + kept_tokens += line_tokens + + if kept_lines: + truncated_tokens = count_tokens(diff_text) - kept_tokens + if truncated_tokens > 0: + kept_lines.append(f"[... truncated {truncated_tokens} tokens ...]") + return '\n'.join(kept_lines) + return diff_text[:max_tokens * 4] + "\n[... truncated ...]" + + +def get_relevant_hunks( + parsed_patch: ParsedPatch | None, + grep_query: str, + max_tokens: int = MAX_PATCH_TOKENS, + max_chunks: int = MAX_PATCH_CHUNKS, +) -> list[str]: """Extract unified diff hunks for files matching the grep target. Parameters @@ -779,32 +932,71 @@ def get_relevant_hunks(parsed_patch: ParsedPatch | None, grep_query: str) -> str Parsed patch file structure (may be None if no patch available). grep_query: The grep query string, which may include a file filter (e.g., "pattern,filename.c"). + max_tokens: + Maximum tokens per chunk. + max_chunks: + Maximum number of chunks to return. Returns ------- - str - Unified diff format string with relevant hunks, or empty string if no patch/match. + list[str] + List of unified diff chunks, each within max_tokens. Returns [""] if no patch/match. """ if not parsed_patch: - return "" + return [""] file_pattern = None if "," in grep_query: file_pattern = grep_query.split(",")[-1].strip() - hunks = [] + file_diffs: list[tuple[str, int]] = [] for pf in parsed_patch.files: if file_pattern and file_pattern not in pf.target_path: continue - hunks.append(f"--- a/{pf.target_path}") - hunks.append(f"+++ b/{pf.target_path}") + lines = [f"--- a/{pf.target_path}", f"+++ b/{pf.target_path}"] for hunk in pf.hunks: for line in hunk.removed_lines: - hunks.append(f"-\t{line}") + lines.append(f"-\t{line}") for line in hunk.added_lines: - hunks.append(f"+\t{line}") + lines.append(f"+\t{line}") + file_diff = "\n".join(lines) + file_diffs.append((file_diff, count_tokens(file_diff))) + + if not file_diffs: + return [""] + + chunks: list[str] = [] + current_parts: list[str] = [] + current_tokens = 0 + + for file_diff, tokens in file_diffs: + if tokens > max_tokens: + if current_parts: + chunks.append("\n".join(current_parts)) + if len(chunks) >= max_chunks: + return chunks + current_parts = [] + current_tokens = 0 + truncated = _truncate_diff_by_tokens(file_diff, max_tokens) + chunks.append(truncated) + if len(chunks) >= max_chunks: + return chunks + continue + + if current_tokens + tokens > max_tokens and current_parts: + chunks.append("\n".join(current_parts)) + if len(chunks) >= max_chunks: + return chunks + current_parts = [] + current_tokens = 0 + + current_parts.append(file_diff) + current_tokens += tokens + + if current_parts and len(chunks) < max_chunks: + chunks.append("\n".join(current_parts)) - return "\n".join(hunks) if hunks else "" + return chunks if chunks else [""] # --------------------------------------------------------------------------- diff --git a/src/vuln_analysis/functions/cve_package_code_agent.py b/src/vuln_analysis/functions/cve_package_code_agent.py index 035396c6..dc4d7a72 100644 --- a/src/vuln_analysis/functions/cve_package_code_agent.py +++ b/src/vuln_analysis/functions/cve_package_code_agent.py @@ -42,7 +42,8 @@ upstream_search_preprocess, extract_l1_verdict, VulnerabilityIntel, - format_patch_data_for_intel, + format_patch_data_chunks_for_intel, + merge_vulnerability_intel_chunks, get_relevant_hunks, ReferenceHints, AdvisoryContent, @@ -72,7 +73,7 @@ from vuln_analysis.utils.vulnerability_intel_sanitizer import VulnerabilityIntelSanitizer from vuln_analysis.utils.reference_fetcher import ReferenceFetcher from vuln_analysis.utils.reference_parser import ReflectiveReferenceParser, ParserConfig -from vuln_analysis.utils.token_utils import truncate_tool_output, truncate_tool_output_list +from vuln_analysis.utils.token_utils import truncate_tool_output, truncate_tool_output_list, count_tokens from vuln_analysis.runtime_context import ctx_state logger = LoggingFactory.get_agent_logger(__name__) @@ -513,13 +514,10 @@ async def L1_agent(state: CodeAgentState) -> dict: if downstream_report and downstream_report.is_patch_file_available: parsed_patch = downstream_report.parsed_patch - patch_data = format_patch_data_for_intel(parsed_patch) elif upstream_report and upstream_report.is_fixed_srpm_is_needed: parsed_patch = upstream_report.fixed_parsed_patch - patch_data = format_patch_data_for_intel(parsed_patch) elif git_search_report and git_search_report.parsed_patch: parsed_patch = git_search_report.parsed_patch - patch_data = format_patch_data_for_intel(parsed_patch) logger.info( "L1_agent: Using discovered patch from git search (commit=%s, confidence=%.2f)", git_search_report.best_result.commit_hash_short if git_search_report.best_result else "unknown", @@ -527,7 +525,21 @@ async def L1_agent(state: CodeAgentState) -> dict: ) else: parsed_patch = None - patch_data = "" + + max_patch_chunk_tokens = ( + config.context_window_token_limit - config.reference_mining_prompt_overhead + ) + patch_chunks = ( + format_patch_data_chunks_for_intel(parsed_patch, max_tokens=max_patch_chunk_tokens) + if parsed_patch + else [""] + ) + if len(patch_chunks) > 1: + logger.info( + "L1_agent: chunking patch intel extraction into %d chunks (max_tokens=%d)", + len(patch_chunks), + max_patch_chunk_tokens, + ) # Extract vendor mitigations from intel (OSIDB preferred, fallback to RHSA) vendor_mitigations = "" @@ -540,16 +552,19 @@ async def L1_agent(state: CodeAgentState) -> dict: if mit_text: vendor_mitigations = mit_text - vul_prompt = VULNERABILITY_INTEL_EXTRACTION_PROMPT.format( - vuln_id=vuln_id, - target_package=target_package.name, - cve_description=cve_description, - vendor_mitigations=vendor_mitigations or "No vendor mitigations available.", - patch_data=patch_data, - ) - vulnerability_intel: VulnerabilityIntel = await vulnerability_intel_llm.ainvoke( - [SystemMessage(content=vul_prompt)], - ) + chunk_intel_results: list[VulnerabilityIntel] = [] + for patch_chunk in patch_chunks: + vul_prompt = VULNERABILITY_INTEL_EXTRACTION_PROMPT.format( + vuln_id=vuln_id, + target_package=target_package.name, + cve_description=cve_description, + vendor_mitigations=vendor_mitigations or "No vendor mitigations available.", + patch_data=patch_chunk, + ) + chunk_intel_results.append( + await vulnerability_intel_llm.ainvoke([SystemMessage(content=vul_prompt)]) + ) + vulnerability_intel = merge_vulnerability_intel_chunks(chunk_intel_results) vulnerability_intel = VulnerabilityIntelSanitizer(parsed_patch).apply( vulnerability_intel ) @@ -1366,57 +1381,81 @@ async def observation_node(state: CodeAgentState) -> dict: downstream_report is not None, upstream_report is not None) - # Extract relevant hunks based on grep target file - raw_patch_diff = "" + # Extract relevant hunks based on grep target file (chunked by token limit) + patch_diff_chunks = [""] if tool_used == "Source Grep" and parsed_patch: - raw_patch_diff = get_relevant_hunks(parsed_patch, tool_input_detail) - + patch_diff_chunks = get_relevant_hunks(parsed_patch, tool_input_detail) + if empty_findings: code_findings = empty_findings elif needs_llm_classification: # Empty source grep - use classification prompt to determine meaning - classification_prompt = L1_EMPTY_RESULT_CLASSIFICATION_PROMPT.format( - tool_used=tool_used, - last_thought=last_thought_text, - tool_input=tool_input_detail, - raw_patch_diff=raw_patch_diff if raw_patch_diff else "No patch diff available", - ) - code_findings = await structured_comprehension_llm.ainvoke( - [SystemMessage(content=classification_prompt)] + # Loop over patch chunks to stay within context window + all_classification_findings: list[str] = [] + last_tool_outcome = "No matches found" + for patch_chunk in patch_diff_chunks: + classification_prompt = L1_EMPTY_RESULT_CLASSIFICATION_PROMPT.format( + tool_used=tool_used, + last_thought=last_thought_text, + tool_input=tool_input_detail, + raw_patch_diff=patch_chunk if patch_chunk else "No patch diff available", + ) + chunk_result = await structured_comprehension_llm.ainvoke( + [SystemMessage(content=classification_prompt)] + ) + all_classification_findings.extend(chunk_result.findings) + last_tool_outcome = chunk_result.tool_outcome + code_findings = CodeFindings( + findings=all_classification_findings, + tool_outcome=last_tool_outcome ) logger.debug("Empty source grep classified: %s", code_findings.findings) else: - # Has actual content - split into chunks and process each - chunks = truncate_tool_output_list(tool_output_for_llm, tool_used, max_tokens=1000) - all_findings = [] + # Has actual content - double loop over patch chunks and tool output chunks + tool_chunks = truncate_tool_output_list(tool_output_for_llm, tool_used, max_tokens=1000) + all_findings: list[str] = [] best_tool_outcome = "" - for chunk in chunks: - comp_prompt = L1_COMPREHENSION_PROMPT.format( - vuln_id=vuln_id, - target_package=target_package_name, - vulnerability_intel=intel_formatted, - raw_patch_diff=raw_patch_diff, - tool_used=tool_used, - tool_input=tool_input_detail, - last_thought=last_thought_text, - tool_output=chunk, - ) - chunk_findings = await invoke_comprehension( - structured_comprehension_llm, - comp_prompt, - tool_used, - tool_input_detail, - chunk, - agent_label="L1", - ) - all_findings.extend(chunk_findings.findings) - # Keep tool_outcome from chunk with actual findings (not FAILED) - if not best_tool_outcome or ( - chunk_findings.findings and - not any("FAILED" in f for f in chunk_findings.findings) - ): - best_tool_outcome = chunk_findings.tool_outcome + for patch_chunk in patch_diff_chunks: + for tool_chunk in tool_chunks: + logger.debug( + "Comprehension token breakdown: " + "intel=%d, patch_chunk=%d, tool_chunk=%d, last_thought=%d, " + "tool_input=%d, total_parts=%d", + count_tokens(intel_formatted), + count_tokens(patch_chunk), + count_tokens(tool_chunk), + count_tokens(last_thought_text), + count_tokens(tool_input_detail), + count_tokens(intel_formatted) + count_tokens(patch_chunk) + + count_tokens(tool_chunk) + count_tokens(last_thought_text) + + count_tokens(tool_input_detail), + ) + comp_prompt = L1_COMPREHENSION_PROMPT.format( + vuln_id=vuln_id, + target_package=target_package_name, + vulnerability_intel=intel_formatted, + raw_patch_diff=patch_chunk, + tool_used=tool_used, + tool_input=tool_input_detail, + last_thought=last_thought_text, + tool_output=tool_chunk, + ) + logger.debug("Comprehension total prompt tokens: %d", count_tokens(comp_prompt)) + chunk_findings = await invoke_comprehension( + structured_comprehension_llm, + comp_prompt, + tool_used, + tool_input_detail, + tool_chunk, + agent_label="L1", + ) + all_findings.extend(chunk_findings.findings) + if not best_tool_outcome or ( + chunk_findings.findings and + not any("FAILED" in f for f in chunk_findings.findings) + ): + best_tool_outcome = chunk_findings.tool_outcome code_findings = CodeFindings( findings=all_findings, diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py index 80d3dd8f..ddae2480 100644 --- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py +++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py @@ -367,7 +367,7 @@ async def test_c_transitive_search_2(): #create sample sbom packages sbom_list = [ - SBOMPackage(name='openssl', version='3.5.1-5.el9', path=None, system='rpm'), + SBOMPackage(name='openssl', version='3.5.1-7.el9_7', path=None, system='rpm'), SBOMPackage(name='libxml2', version='2.9.13-9.el9_6', path=None, system='rpm'), SBOMPackage(name='libxslt', version='1.1.34-13.el9_6', path=None, system='rpm') ]