Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: pre-commit

on:
push:
branches: ["main", "master"]
pull_request:
branches: ["main", "master"]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install pre-commit
run: pip install pre-commit

- name: Run pre-commit
run: pre-commit run --all-files
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
language_version: python3.10

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
6 changes: 4 additions & 2 deletions astroml/benchmarking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
format_time,
format_memory,
set_random_seed,
get_device_info
get_device_info,
get_environment_info
)

__all__ = [
Expand Down Expand Up @@ -58,5 +59,6 @@
"format_time",
"format_memory",
"set_random_seed",
"get_device_info"
"get_device_info",
"get_environment_info"
]
29 changes: 29 additions & 0 deletions astroml/benchmarking/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ def run_benchmark(self) -> BenchmarkResult:
# Save results
self._save_results(result)

# Save configuration with environment info for reproducibility
self._save_config()

if self.config.save_model:
self._save_model()

Expand Down Expand Up @@ -424,6 +427,32 @@ def _save_results(self, result: BenchmarkResult):
json.dump(metadata, f, indent=2)
print(f"Metadata saved to {metadata_path}")

def _save_config(self):
"""Save benchmark configuration with environment info for reproducibility."""
from .utils import get_environment_info

config_path = Path(self.config.output_dir) / "benchmark-config.yaml"

# Create output directory if it doesn't exist
config_path.parent.mkdir(parents=True, exist_ok=True)

# Collect environment information
env_info = get_environment_info()

# Build config dict with environment info
config_dict = {
'benchmark_config': self.config.to_dict(),
'environment': env_info,
'timestamp': time.time()
}

# Save as YAML
import yaml
with open(config_path, 'w') as f:
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)

print(f"Configuration saved to {config_path}")

def _save_model(self):
"""Save trained model."""
if self.model is not None:
Expand Down
33 changes: 33 additions & 0 deletions astroml/benchmarking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,36 @@ def callback(epoch: int, loss: float, metrics: Dict[str, float]):
print(f" {metric}: {value:.4f}")

return callback


def get_environment_info() -> Dict[str, Any]:
"""Collect environment information for reproducibility."""
import sys
import platform
from importlib.metadata import version

env_info = {
'python_version': sys.version,
'platform': platform.platform(),
'platform_system': platform.system(),
'platform_release': platform.release(),
'platform_version': platform.version(),
'platform_machine': platform.machine(),
'processor': platform.processor(),
}

# Get library versions
libraries = ['torch', 'numpy', 'scikit-learn', 'pandas', 'torch-geometric']
for lib in libraries:
try:
env_info[f'{lib}_version'] = version(lib)
except Exception:
try:
# Fallback for packages with different import names
import importlib
module = importlib.import_module(lib.replace('-', '_'))
env_info[f'{lib}_version'] = getattr(module, '__version__', 'unknown')
except Exception:
env_info[f'{lib}_version'] = 'not_installed'

return env_info
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@ mypy>=1.7.0
jupyter>=1.0.0
notebook>=7.0.0
ipykernel>=6.26.0
pre-commit>=3.7.0
isort>=5.13.0
ruff>=0.4.0
```

Loading