diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
new file mode 100644
index 0000000..7005969
--- /dev/null
+++ b/.github/workflows/pre-commit.yml
@@ -0,0 +1,24 @@
+name: pre-commit
+
+on:
+  push:
+    branches: ["main", "master"]
+  pull_request:
+    branches: ["main", "master"]
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install pre-commit
+        run: pip install pre-commit
+
+      - name: Run pre-commit
+        run: pre-commit run --all-files
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..745623a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,18 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+        language_version: python3.10
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.4.7
+    hooks:
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
diff --git a/astroml/benchmarking/__init__.py b/astroml/benchmarking/__init__.py
index 4d7b055..63bf2a8 100644
--- a/astroml/benchmarking/__init__.py
+++ b/astroml/benchmarking/__init__.py
@@ -25,7 +25,8 @@
     format_time,
     format_memory,
     set_random_seed,
-    get_device_info
+    get_device_info,
+    get_environment_info
 )
 
 __all__ = [
@@ -58,5 +59,6 @@
     "format_time",
     "format_memory",
     "set_random_seed",
-    "get_device_info"
+    "get_device_info",
+    "get_environment_info"
 ]
diff --git a/astroml/benchmarking/core.py b/astroml/benchmarking/core.py
index 12090bb..9b4ff94 100644
--- a/astroml/benchmarking/core.py
+++ b/astroml/benchmarking/core.py
@@ -366,6 +366,9 @@ def run_benchmark(self) -> BenchmarkResult:
         # Save results
         self._save_results(result)
 
+        # Save configuration with environment info for reproducibility
+        self._save_config()
+
         if self.config.save_model:
             self._save_model()
 
@@ -424,6 +427,35 @@ def _save_results(self, result: BenchmarkResult):
             json.dump(metadata, f, indent=2)
         print(f"Metadata saved to {metadata_path}")
 
+    def _save_config(self):
+        """Save the benchmark configuration and environment info as YAML.
+
+        Writes ``benchmark-config.yaml`` into ``self.config.output_dir``,
+        creating the directory first if it does not exist, so that a run
+        can be reproduced later.
+        """
+        import yaml
+
+        from .utils import get_environment_info
+
+        config_path = Path(self.config.output_dir) / "benchmark-config.yaml"
+
+        # Create output directory if it doesn't exist
+        config_path.parent.mkdir(parents=True, exist_ok=True)
+
+        config_dict = {
+            "benchmark_config": self.config.to_dict(),
+            "environment": get_environment_info(),
+            "timestamp": time.time(),
+        }
+
+        # Write with an explicit encoding so the file does not depend on
+        # the platform's locale default.
+        with open(config_path, "w", encoding="utf-8") as f:
+            yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
+
+        print(f"Configuration saved to {config_path}")
+
     def _save_model(self):
         """Save trained model."""
         if self.model is not None:
diff --git a/astroml/benchmarking/utils.py b/astroml/benchmarking/utils.py
index a53e8f6..faf30ec 100644
--- a/astroml/benchmarking/utils.py
+++ b/astroml/benchmarking/utils.py
@@ -227,3 +227,52 @@ def callback(epoch: int, loss: float, metrics: Dict[str, float]):
                 print(f" {metric}: {value:.4f}")
 
     return callback
+
+
+def get_environment_info() -> Dict[str, Any]:
+    """Collect interpreter, platform and library versions for reproducibility.
+
+    Returns:
+        A dict of Python/platform details plus one ``"<distribution>_version"``
+        entry per tracked library: the installed version, ``"unknown"`` if the
+        module imports but has no ``__version__``, or ``"not_installed"``.
+    """
+    import importlib
+    import platform
+    import sys
+    from importlib.metadata import version
+
+    env_info = {
+        "python_version": sys.version,
+        "platform": platform.platform(),
+        "platform_system": platform.system(),
+        "platform_release": platform.release(),
+        "platform_version": platform.version(),
+        "platform_machine": platform.machine(),
+        "processor": platform.processor(),
+    }
+
+    # Distribution name -> importable module name. They differ for some
+    # packages (scikit-learn is imported as ``sklearn``), which matters for
+    # the ``__version__`` fallback below.
+    libraries = {
+        "torch": "torch",
+        "numpy": "numpy",
+        "scikit-learn": "sklearn",
+        "pandas": "pandas",
+        "torch-geometric": "torch_geometric",
+    }
+    for dist_name, module_name in libraries.items():
+        key = f"{dist_name}_version"
+        try:
+            env_info[key] = version(dist_name)
+        except Exception:
+            # Best effort: metadata lookup failed, so fall back to the
+            # module's own ``__version__`` attribute.
+            try:
+                module = importlib.import_module(module_name)
+                env_info[key] = getattr(module, "__version__", "unknown")
+            except Exception:
+                env_info[key] = "not_installed"
+
+    return env_info
diff --git a/requirements.txt b/requirements.txt
index 6b7754c..04c326a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -65,3 +65,6 @@ mypy>=1.7.0
 jupyter>=1.0.0
 notebook>=7.0.0
 ipykernel>=6.26.0
+pre-commit>=3.7.0
+isort>=5.13.0
+ruff>=0.4.0