diff --git a/src/vllm/parsers.rs b/src/vllm/parsers.rs index 61fd497..4dce493 100644 --- a/src/vllm/parsers.rs +++ b/src/vllm/parsers.rs @@ -62,14 +62,24 @@ impl VllmState { for subgraph in subgraphs.iter() { let size_or_range = subgraph.size_or_range(); - let artifact_count = subgraph.artifacts.len(); + let (pass_artifacts, artifacts): (Vec<_>, Vec<_>) = subgraph + .artifacts + .iter() + .cloned() + .partition(|a| a.name.contains("vllm_post_grad.")); + let artifact_count = artifacts.len(); + let pass_artifact_count = pass_artifacts.len(); + let has_pass_artifacts = pass_artifact_count > 0; groups .entry(size_or_range) .or_default() .push(VllmSubgraphWithArtifacts { submod_name: subgraph.display_submod_name(), - artifacts: subgraph.artifacts.clone(), + artifacts, artifact_count, + pass_artifacts, + pass_artifact_count, + has_pass_artifacts, }); } @@ -83,6 +93,16 @@ impl VllmState { .collect() } + // Get pattern artifacts from pre_subgraph_artifacts + pub fn build_pattern_artifacts(&self) -> Vec { + self.pre_subgraph_artifacts + .borrow() + .iter() + .filter(|a| a.name.starts_with("vllm_patterns.")) + .cloned() + .collect() + } + // Get dynamo artifacts from pre_subgraph_artifacts pub fn build_dynamo_artifacts(&self) -> Vec { let dynamo_names = [ @@ -256,7 +276,10 @@ impl StructuredLogParser for VllmPiecewiseSplitGraphParser { // 1. "before_post_grad_graph" artifact — the graph before any passes run. // Stored as the diff baseline; no file output (ArtifactParser handles that). // -// 2. "vllm_post_grad.." graph dump — the graph after a pass. +// 2. "vllm_patterns." graph dump — pattern matcher patterns. +// Output as a standalone .py file (no diffing). +// +// 3. "vllm_post_grad.." graph dump — the graph after a pass. // Diffed against `previous_payload` to produce a side-by-side HTML diff, // then becomes the new baseline for the next pass. pub struct VllmPostGradPassDiffParser { @@ -403,7 +426,9 @@ impl StructuredLogParser for VllmPostGradPassDiffParser { fn get_metadata<'e>(&self, e: &'e Envelope) -> Option> { if let Some(graph_dump) = &e.graph_dump { - if graph_dump.name.starts_with("vllm_post_grad.") { + if graph_dump.name.starts_with("vllm_post_grad.") + || graph_dump.name.starts_with("vllm_patterns.") + { return Some(Metadata::GraphDump(graph_dump)); } } @@ -437,6 +462,13 @@ impl StructuredLogParser for VllmPostGradPassDiffParser { *self.state.has_vllm_artifacts.borrow_mut() = true; + // Handle vllm_patterns.* graph dumps: output as standalone .py file + if graph_dump.name.starts_with("vllm_patterns.") { + let filename = format!("{}.py", graph_dump.name); + let f = build_file_path(&filename, lineno, compile_id); + return Ok(vec![ParserOutput::PayloadFile(f)]); + } + // e.g. "vllm_post_grad.0.FusionPass" -> pass_name = "0.FusionPass" let pass_name = graph_dump .name @@ -510,6 +542,8 @@ pub fn generate_vllm_summary( .unwrap_or_default(); let dynamo_artifacts = state.build_dynamo_artifacts(); let has_dynamo_artifacts = !dynamo_artifacts.is_empty(); + let pattern_artifacts = state.build_pattern_artifacts(); + let has_pattern_artifacts = !pattern_artifacts.is_empty(); let piecewise_graph_file = state.piecewise_graph_file.borrow().clone(); let has_piecewise = piecewise_graph_file.is_some(); let compile_range_groups = state.build_compile_range_groups(); @@ -522,6 +556,8 @@ pub fn generate_vllm_summary( config, dynamo_artifacts, has_dynamo_artifacts, + pattern_artifacts, + has_pattern_artifacts, piecewise_graph_file, has_piecewise, compile_range_groups, diff --git a/src/vllm/templates.rs b/src/vllm/templates.rs index 3a41a17..95bd720 100644 --- a/src/vllm/templates.rs +++ b/src/vllm/templates.rs @@ -276,6 +276,17 @@ pub const VLLM_SUMMARY_TEMPLATE: &str = r#" {{ endif }} + {{ if has_pattern_artifacts }} +

Inductor Pass Patterns

+
+
    + {{ for artifact in pattern_artifacts }} +
  • {artifact.name} {artifact.suffix}
  • + {{ endfor }} +
+
+ {{ endif }} +

Inductor Compilation

{{ for group in compile_range_groups }} @@ -299,6 +310,18 @@ pub const VLLM_SUMMARY_TEMPLATE: &str = r#" {{ endif }} + {{ if subgraph.has_pass_artifacts }} +
+
+ Inductor Pass Graphs & Diffs ({subgraph.pass_artifact_count} files) +
    + {{ for artifact in subgraph.pass_artifacts }} +
  • {artifact.name} {artifact.suffix}
  • + {{ endfor }} +
+
+
+ {{ endif }} {{ endfor }} diff --git a/src/vllm/types.rs b/src/vllm/types.rs index 3dbdfc9..30c9872 100644 --- a/src/vllm/types.rs +++ b/src/vllm/types.rs @@ -63,6 +63,8 @@ pub struct VllmSummaryContext { pub has_config: bool, pub dynamo_artifacts: Vec, pub has_dynamo_artifacts: bool, + pub pattern_artifacts: Vec, + pub has_pattern_artifacts: bool, pub piecewise_graph_file: Option, pub has_piecewise: bool, pub compile_range_groups: Vec, @@ -81,6 +83,9 @@ pub struct VllmSubgraphWithArtifacts { pub submod_name: String, pub artifacts: Vec, pub artifact_count: usize, + pub pass_artifacts: Vec, + pub pass_artifact_count: usize, + pub has_pass_artifacts: bool, } #[derive(Debug, Clone, Serialize)] diff --git a/tests/inputs/vllm_post_grad_diff.log b/tests/inputs/vllm_post_grad_diff.log index e679eaf..43d2218 100644 --- a/tests/inputs/vllm_post_grad_diff.log +++ b/tests/inputs/vllm_post_grad_diff.log @@ -1,3 +1,10 @@ +V0127 17:17:45.075000 1175001 torch/foo.py:50] {"graph_dump": {"name": "vllm_patterns.FusionPass"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "b2984d935e100cb9f04245b5b6b51833"} + # Patterns for FusionPass + def pattern_0(): + x = KeywordArg("x") + sin = CallFunction(torch.sin, x) + cos = CallFunction(torch.cos, x) + return CallFunction(operator.add, sin, cos) V0127 17:17:45.175000 1175001 torch/foo.py:100] {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "e830dc536dd44eda7a0b9e5b2440b620"} def forward(self, x): a = torch.sin(x) diff --git a/tests/integration_test.rs b/tests/integration_test.rs index b07883f..162dc59 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -2743,6 +2743,9 @@ fn test_parse_vllm_post_grad_diff() { assert!(output.is_ok()); let map: HashMap = output.unwrap().into_iter().collect(); + // Check pattern file exists + assert!(prefix_exists(&map, "-_0_0_0/vllm_patterns.FusionPass")); + // Check post-pass graph txt files exist assert!(prefix_exists(&map, "-_0_0_0/vllm_post_grad.0.FusionPass")); assert!(prefix_exists(&map, "-_0_0_0/vllm_post_grad.1.ReshapePass"));