diff --git a/graph_net/test/backward_graph_extractor.sh b/graph_net/test/backward_graph_extractor.sh new file mode 100644 index 000000000..73819cf2c --- /dev/null +++ b/graph_net/test/backward_graph_extractor.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( +os.path.dirname(graph_net.__file__))") +GRAPHNET_ROOT="$GRAPH_NET_ROOT/../" +OUTPUT_DIR="/tmp/backward_graph_samples" +mkdir -p "$OUTPUT_DIR" + +python3 -m graph_net.apply_sample_pass \ + --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ + --sample-pass-file-path "graph_net/torch/sample_pass/backward_graph_extractor.py" \ + --sample-pass-class-name "BackwardGraphExtractorPass" \ + --sample-pass-config $(base64 -w 0 < bool: + return self.naive_sample_handled(rel_model_path, search_file_name="model.py") + + def resume(self, rel_model_path: str): + model_path_prefix = Path(self.config["model_path_prefix"]) + model_name = f"{os.path.basename(rel_model_path)}_backward" + model_path = model_path_prefix / rel_model_path + output_dir = Path(self.config["output_dir"]) / os.path.dirname(rel_model_path) + device = self._choose_device(self.config["device"]) + extractor = BackwardGraphExtractor(model_name, model_path, output_dir, device) + extractor() + + def _choose_device(self, device) -> str: + if device in ["cpu", "cuda"]: + return device + return "cuda" if torch.cuda.is_available() else "cpu"