diff --git a/benchmarking/static_threshold_analyzer/BUILD.bazel b/benchmarking/static_threshold_analyzer/BUILD.bazel index d3459cc..d3aafb6 100644 --- a/benchmarking/static_threshold_analyzer/BUILD.bazel +++ b/benchmarking/static_threshold_analyzer/BUILD.bazel @@ -30,6 +30,7 @@ py_binary( main = "static_threshold_analyzer.py", srcs = ["static_threshold_analyzer.py"], deps = [ + "//benchmarking/utils:metric_parser", ":static_threshold_analyzer_lib", ], ) diff --git a/benchmarking/static_threshold_analyzer/static_threshold_analyzer.py b/benchmarking/static_threshold_analyzer/static_threshold_analyzer.py index 175f265..b469265 100644 --- a/benchmarking/static_threshold_analyzer/static_threshold_analyzer.py +++ b/benchmarking/static_threshold_analyzer/static_threshold_analyzer.py @@ -19,35 +19,14 @@ import argparse import json import sys -from typing import List from google.protobuf import json_format +from benchmarking.utils import metric_parser from benchmarking.proto import benchmark_result_pb2 -from benchmarking.proto.common import metric_pb2 from benchmarking.static_threshold_analyzer.static_threshold_analyzer_lib import ( StaticAnalyzer, ) -def _parse_metric_specs( - metric_specs_json: str, -) -> List[metric_pb2.MetricSpec]: - """Parses the JSON metric specifications list into a list of MetricSpec protos.""" - try: - metric_specs_list = json.loads(metric_specs_json) - except json.JSONDecodeError as e: - print(f"Error: Failed to parse --metric_specs_json: {e}", file=sys.stderr) - sys.exit(1) - - # Convert list of dicts to a list of MetricSpec protos - metric_specs = [] - for metric_dict in metric_specs_list: - metric_spec = metric_pb2.MetricSpec() - json_format.ParseDict(metric_dict, metric_spec) - metric_specs.append(metric_spec) - - return metric_specs - - def _load_benchmark_result( benchmark_result_file: str, ) -> benchmark_result_pb2.BenchmarkResult: @@ -76,7 +55,12 @@ def main(): parser.add_argument("--benchmark_result_file", required=True) args = 
parser.parse_args() - metric_specs = _parse_metric_specs(args.metric_specs_json) + try: + metric_specs = metric_parser.parse_metric_specs_from_json(args.metric_specs_json) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + benchmark_result = _load_benchmark_result(args.benchmark_result_file) analyzer = StaticAnalyzer(metric_specs) analyzer.run_analysis(benchmark_result) diff --git a/benchmarking/tb_parser/BUILD.bazel b/benchmarking/tb_parser/BUILD.bazel index 97064ea..9ea4a6a 100644 --- a/benchmarking/tb_parser/BUILD.bazel +++ b/benchmarking/tb_parser/BUILD.bazel @@ -34,6 +34,7 @@ py_binary( srcs = ["tb_parser.py"], deps = [ ":tb_parser_lib", + "//benchmarking/utils:metric_parser", "@pypi//protovalidate", ], ) diff --git a/benchmarking/tb_parser/tb_parser.py b/benchmarking/tb_parser/tb_parser.py index 3280ee0..f71f998 100644 --- a/benchmarking/tb_parser/tb_parser.py +++ b/benchmarking/tb_parser/tb_parser.py @@ -15,41 +15,14 @@ """Script to extract statistics from TensorFlow event files and produce a BenchmarkResult JSON artifact.""" import argparse -import json import sys -from typing import List from google.protobuf import json_format, timestamp_pb2 +from benchmarking.utils import metric_parser from benchmarking.tb_parser import tb_parser_lib -from benchmarking.proto.common import metric_pb2 from benchmarking.proto import benchmark_result_pb2 from protovalidate import validate, ValidationError -def _parse_metric_specs( - metric_specs_json: str, -) -> List[metric_pb2.MetricSpec]: - """Parses the JSON metric specifications list into a list of MetricSpec protos.""" - - try: - metric_specs_list = json.loads(metric_specs_json) - except json.JSONDecodeError as e: - print(f"Error: Failed to parse --metric_specs_json: {e}", file=sys.stderr) - sys.exit(1) - - # Convert list of metric spec dicts to a list of MetricSpec protos - metric_specs = [] - - if not metric_specs_list: - return metric_specs - - for metric_dict in metric_specs_list: - 
metric_spec = metric_pb2.MetricSpec() - json_format.ParseDict(metric_dict, metric_spec) - metric_specs.append(metric_spec) - - return metric_specs - - def _format_validation_error(violation) -> str: """Formats a single protovalidate violation into a human-readable string.""" field_path_str = ".".join( @@ -80,7 +53,12 @@ def main(): args = parser.parse_args() - metric_specs = _parse_metric_specs(args.metric_specs_json) + try: + metric_specs = metric_parser.parse_metric_specs_from_json(args.metric_specs_json) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + tb_parser = tb_parser_lib.TensorBoardParser(metric_specs) computed_stats = tb_parser.parse_and_compute(args.tblog_dir) diff --git a/benchmarking/utils/BUILD.bazel b/benchmarking/utils/BUILD.bazel new file mode 100644 index 0000000..bc3f0b1 --- /dev/null +++ b/benchmarking/utils/BUILD.bazel @@ -0,0 +1,24 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+load("@rules_python//python:defs.bzl", "py_library")
+
+py_library(
+    name = "metric_parser",
+    srcs = ["metric_parser.py"],
+    visibility = ["//benchmarking:__subpackages__"],
+    deps = [
+        "//benchmarking/proto/common:metric_py_proto",
+    ],
+)
diff --git a/benchmarking/utils/metric_parser.py b/benchmarking/utils/metric_parser.py
new file mode 100644
index 0000000..2e8ba12
--- /dev/null
+++ b/benchmarking/utils/metric_parser.py
@@ -0,0 +1,59 @@
+"""Utility library for parsing metric specifications from JSON."""
+
+import collections.abc
+import json
+from google.protobuf import json_format
+from benchmarking.proto.common import metric_pb2
+
+
+def parse_metric_specs_from_json(
+    metric_specs_json: str,
+) -> collections.abc.Sequence[metric_pb2.MetricSpec]:
+    """Parses a JSON string into a list of MetricSpec protos.
+
+    Args:
+        metric_specs_json: JSON text expected to hold a list of objects,
+            each in the JSON form of a MetricSpec message.
+
+    Returns:
+        One MetricSpec per list entry, in input order. JSON "null" and any
+        other empty decoded value (e.g. "[]") yield an empty list.
+
+    Raises:
+        ValueError: If the string is not valid JSON (an empty string is not),
+            if the decoded value is non-empty but not a list, or if an entry
+            is not an object that converts to a MetricSpec.
+    """
+    try:
+        decoded = json.loads(metric_specs_json)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse metric_specs_json: {e}") from e
+
+    # Every falsy decoded value (null, [], {}, "", 0, false) means "no specs".
+    if not decoded:
+        return []
+    if not isinstance(decoded, list):
+        raise ValueError(
+            "metric_specs_json must be a JSON list, got "
+            f"{type(decoded).__name__}"
+        )
+
+    metric_specs = []
+    for index, metric_dict in enumerate(decoded):
+        if not isinstance(metric_dict, dict):
+            raise ValueError(
+                f"Metric spec at index {index} must be a JSON object, got "
+                f"{type(metric_dict).__name__}"
+            )
+        metric_spec = metric_pb2.MetricSpec()
+        try:
+            json_format.ParseDict(metric_dict, metric_spec)
+        except json_format.ParseError as e:
+            # ParseError is not a ValueError; convert it so callers that
+            # catch ValueError report bad specs instead of crashing.
+            raise ValueError(
+                f"Invalid metric spec at index {index}: {e}"
+            ) from e
+        metric_specs.append(metric_spec)
+
+    return metric_specs