logger = logging.getLogger(__name__)

# Default cache settings
DEFAULT_TTL_SECONDS = 300  # 5 minutes
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "oss_dev" / "github"


class CacheEntry:
    """A single cached value with its expiry timestamp.

    ``expires_at`` is a wall-clock Unix timestamp (``time.time()``), so it
    stays meaningful when the entry is written by one process and read back
    by another.
    """

    def __init__(self, value: Any, ttl: int) -> None:
        self.value = value
        self.expires_at: float = time.time() + ttl

    def is_expired(self) -> bool:
        """Return ``True`` once the current time is past ``expires_at``."""
        return time.time() > self.expires_at

    def to_dict(self) -> dict[str, Any]:
        """Return the mapping that is serialised to disk."""
        return {"value": self.value, "expires_at": self.expires_at}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CacheEntry":
        """Rebuild an entry from :meth:`to_dict` output, keeping its expiry.

        Raises:
            KeyError: ``value`` or ``expires_at`` is missing.
            TypeError: *data* is not a mapping, or ``expires_at`` cannot be
                converted to a float.
            ValueError: ``expires_at`` is a non-numeric string.
        """
        # Bypass __init__: it would stamp a fresh expiry instead of the
        # stored one.
        entry = cls.__new__(cls)
        entry.value = data["value"]
        # Coerce here so a malformed timestamp fails at load time (where the
        # cache treats it as a corrupt file) rather than in is_expired().
        entry.expires_at = float(data["expires_at"])
        return entry


class ResponseCache:
    """TTL-based, file-backed cache for GitHub CLI API responses.

    Values live in two layers: an in-process dict and one JSON file per key
    under the cache directory, so entries survive process restarts (up to
    their TTL). Keys are used verbatim as file names; build them with
    :meth:`make_key` unless they are already filesystem-safe.

    Disk problems are never fatal: they are logged and the affected
    operation degrades to a cache miss.

    Usage::

        cache = ResponseCache(ttl=60)
        cache.set("my_key", {"some": "data"})
        value = cache.get("my_key")  # None if missing / expired
    """

    # Everything a damaged or foreign cache file can raise while being read
    # and decoded. ValueError covers json.JSONDecodeError and
    # UnicodeDecodeError; TypeError covers non-dict JSON payloads.
    _READ_ERRORS = (OSError, ValueError, KeyError, TypeError)

    def __init__(
        self,
        ttl: int = DEFAULT_TTL_SECONDS,
        cache_dir: str | Path | None = None,
        enabled: bool = True,
    ) -> None:
        """Create a cache.

        Args:
            ttl: Lifetime in seconds given to every entry stored via
                :meth:`set`.
            cache_dir: Directory for the JSON files; a leading ``~`` is
                expanded. Falls back to ``DEFAULT_CACHE_DIR`` when falsy.
            enabled: When ``False``, :meth:`get` always misses and
                :meth:`set` does nothing.
        """
        self._ttl = ttl
        self._enabled = enabled
        # expanduser(): config files commonly spell the path with "~", which
        # Path() alone would treat as a literal directory name.
        self._cache_dir = Path(cache_dir).expanduser() if cache_dir else DEFAULT_CACHE_DIR

        # In-memory layer (avoids repeated disk reads in the same process)
        self._memory: dict[str, CacheEntry] = {}

        if self._enabled:
            try:
                self._cache_dir.mkdir(parents=True, exist_ok=True)
            except OSError as exc:
                # The cache is an optimisation; an unusable directory must
                # not stop the caller from working. The memory layer stays.
                logger.warning("Cache directory %s is unusable: %s", self._cache_dir, exc)
            logger.debug("Cache initialised at %s (TTL=%ds)", self._cache_dir, ttl)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, key: str) -> Any | None:
        """Return the cached value for *key*, or ``None`` if absent/expired."""
        if not self._enabled:
            return None

        # 1. Check in-memory layer first
        entry = self._memory.get(key)
        if entry is not None:
            if not entry.is_expired():
                logger.debug("Cache hit (memory): %s", key)
                return entry.value
            # Stale - remove from memory and fall through to disk
            del self._memory[key]

        # 2. Check disk layer
        entry = self._load_from_disk(key)
        if entry is not None and not entry.is_expired():
            logger.debug("Cache hit (disk): %s", key)
            self._memory[key] = entry  # promote to memory layer
            return entry.value

        logger.debug("Cache miss: %s", key)
        return None

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key* with the configured TTL.

        A value that cannot be serialised to JSON is kept in the memory
        layer only; the failed disk write is logged as a warning.
        """
        if not self._enabled:
            return

        entry = CacheEntry(value, self._ttl)
        self._memory[key] = entry
        self._save_to_disk(key, entry)
        logger.debug("Cache set: %s (expires in %ds)", key, self._ttl)

    def invalidate(self, key: str) -> None:
        """Remove a single entry from both memory and disk."""
        self._memory.pop(key, None)
        # missing_ok avoids the exists()/unlink() race with other processes.
        self._key_path(key).unlink(missing_ok=True)
        logger.debug("Cache invalidated: %s", key)

    def clear(self) -> None:
        """Wipe the entire cache (memory + disk)."""
        self._memory.clear()
        if self._cache_dir.exists():
            for f in self._cache_dir.glob("*.json"):
                f.unlink(missing_ok=True)
        logger.debug("Cache cleared")

    def purge_expired(self) -> int:
        """Delete expired and unreadable disk entries.

        Returns:
            The number of files removed.
        """
        removed = 0
        if not self._cache_dir.exists():
            return removed
        for path in self._cache_dir.glob("*.json"):
            try:
                stale = self._read_entry(path).is_expired()
            except self._READ_ERRORS:
                stale = True  # corrupt file - remove it
            if stale:
                path.unlink(missing_ok=True)
                removed += 1
        logger.debug("Purged %d expired cache entries", removed)
        return removed

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def make_key(*parts: str) -> str:
        """Build a safe cache key by hashing the given string parts.

        The parts are encoded as a JSON list before hashing, so part
        boundaries are unambiguous: ``("a:b", "c")`` and ``("a", "b:c")``
        produce different keys.

        Returns:
            A 64-character lowercase hexadecimal SHA-256 digest.
        """
        raw = json.dumps([str(p) for p in parts])
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _key_path(self, key: str) -> Path:
        """Return the path of the JSON file that backs *key*."""
        return self._cache_dir / f"{key}.json"

    @staticmethod
    def _read_entry(path: Path) -> "CacheEntry":
        """Read and decode one cache file; raises on any read/format error."""
        return CacheEntry.from_dict(json.loads(path.read_text(encoding="utf-8")))

    def _save_to_disk(self, key: str, entry: CacheEntry) -> None:
        """Persist *entry* for *key*; failures are logged, never raised."""
        path = self._key_path(key)
        # Write to a sibling temp file and rename it into place so a
        # concurrent reader never observes a half-written JSON document.
        # The name does not end in ".json", so the glob-based helpers skip it.
        tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
        try:
            payload = json.dumps(entry.to_dict())
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, path)
        except (OSError, TypeError, ValueError) as exc:
            # TypeError / ValueError: the value is not JSON-serialisable.
            logger.warning("Cache write failed for key %s: %s", key, exc)
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; the write failure is already logged

    def _load_from_disk(self, key: str) -> "CacheEntry | None":
        """Return the on-disk entry for *key*, or ``None`` if absent/unreadable."""
        path = self._key_path(key)
        try:
            return self._read_entry(path)
        except FileNotFoundError:
            return None  # plain miss, not worth a warning
        except self._READ_ERRORS as exc:
            logger.warning("Cache read failed for key %s: %s", key, exc)
            return None
provider using the gh CLI tool.""" + """GitHub provider using the gh CLI tool. + + Optionally wraps every read-only API call in a :class:`ResponseCache` + so that repeated requests for the same resource within (and across) + sessions skip the ``gh`` subprocess entirely. + + Cache behaviour is controlled via ``Config.cache``: + + .. code-block:: toml + + [cache] + enabled = true # flip to false to disable entirely + ttl = 300 # seconds before an entry is considered stale + dir = "~/.cache/oss_dev/github" # optional custom path + """ def __init__(self, config: Config) -> None: self._config = config self._gh_available = self._check_gh() + # ── Cache setup ──────────────────────────────────────────────── + cache_cfg = getattr(config, "cache", None) + if cache_cfg is not None: + enabled = getattr(cache_cfg, "enabled", True) + ttl = getattr(cache_cfg, "ttl", 300) + cache_dir = getattr(cache_cfg, "dir", None) + else: + enabled, ttl, cache_dir = True, 300, None + + self._cache = ResponseCache(ttl=ttl, cache_dir=cache_dir, enabled=enabled) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _check_gh(self) -> bool: try: subprocess.run(["gh", "--version"], capture_output=True, check=True) @@ -44,6 +74,7 @@ def _require_gh(self) -> None: ) def _run_gh(self, args: list[str]) -> str: + """Execute a ``gh`` command and return its stdout.""" self._require_gh() try: result = subprocess.run( @@ -59,6 +90,25 @@ def _run_gh(self, args: list[str]) -> str: details={"args": args}, ) from e + def _run_gh_cached(self, cache_key: str, args: list[str]) -> str: + """Return a cached response when available; otherwise call ``gh``. + + Only read-only (GET-equivalent) calls should use this helper. + Mutating operations (PR creation, etc.) must call :meth:`_run_gh` + directly so they are never served from cache. 
+ """ + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + output = self._run_gh(args) + self._cache.set(cache_key, output) + return output + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def parse_issue_url(self, url: str) -> dict[str, Any]: pattern = r"github\.com/([^/]+)/([^/]+)/issues/(\d+)" match = re.search(pattern, url) @@ -71,12 +121,17 @@ def parse_issue_url(self, url: str) -> dict[str, Any]: } async def fetch_issue(self, owner: str, repo: str, issue_number: int) -> Issue: - output = self._run_gh([ - "api", - f"repos/{owner}/{repo}/issues/{issue_number}", - "--jq", - "{title: .title, body: .body, state: .state, labels: [.labels[].name], number: .number}", - ]) + cache_key = ResponseCache.make_key("fetch_issue", owner, repo, str(issue_number)) + output = self._run_gh_cached( + cache_key, + [ + "api", + f"repos/{owner}/{repo}/issues/{issue_number}", + "--jq", + "{title: .title, body: .body, state: .state," + " labels: [.labels[].name], number: .number}", + ], + ) data = json.loads(output) return Issue( number=data.get("number", issue_number), @@ -92,12 +147,17 @@ async def fetch_issue(self, owner: str, repo: str, issue_number: int) -> Issue: async def list_issues( self, owner: str, repo: str, state: str = "open", limit: int = 10 ) -> list[Issue]: - output = self._run_gh([ - "api", - f"repos/{owner}/{repo}/issues", - "--jq", - f".[:{limit}] | .[] | {{title: .title, number: .number, state: .state, labels: [.labels[].name], url: .html_url}}", - ]) + cache_key = ResponseCache.make_key("list_issues", owner, repo, state, str(limit)) + output = self._run_gh_cached( + cache_key, + [ + "api", + f"repos/{owner}/{repo}/issues", + "--jq", + f".[:{limit}] | .[] | {{title: .title, number: .number," + f" state: .state, labels: [.labels[].name], url: .html_url}}", + ], + ) issues = [] for line in output.strip().split("\n"): if 
line.strip(): @@ -125,6 +185,7 @@ async def create_pr( head: str, base: str = "main", ) -> PullRequest: + # Mutating operation – never cached output = self._run_gh([ "pr", "create", @@ -145,11 +206,15 @@ async def create_pr( async def get_pr_status( self, owner: str, repo: str, pr_number: int ) -> PRStatus: - output = self._run_gh([ - "pr", "view", str(pr_number), - "--repo", f"{owner}/{repo}", - "--json", "state,isDraft,reviewDecision,url", - ]) + cache_key = ResponseCache.make_key("get_pr_status", owner, repo, str(pr_number)) + output = self._run_gh_cached( + cache_key, + [ + "pr", "view", str(pr_number), + "--repo", f"{owner}/{repo}", + "--json", "state,isDraft,reviewDecision,url", + ], + ) data = json.loads(output) return PRStatus( state=data.get("state", "unknown"), @@ -161,12 +226,16 @@ async def get_pr_status( async def get_pr_comments( self, owner: str, repo: str, pr_number: int ) -> list[Comment]: - output = self._run_gh([ - "api", - f"repos/{owner}/{repo}/pulls/{pr_number}/comments", - "--jq", - ".[] | {body: .body, user: .user.login, created_at: .created_at}", - ]) + cache_key = ResponseCache.make_key("get_pr_comments", owner, repo, str(pr_number)) + output = self._run_gh_cached( + cache_key, + [ + "api", + f"repos/{owner}/{repo}/pulls/{pr_number}/comments", + "--jq", + ".[] | {body: .body, user: .user.login, created_at: .created_at}", + ], + ) comments = [] for line in output.strip().split("\n"): if line.strip(): @@ -179,3 +248,16 @@ async def get_pr_comments( ) ) return comments + + # ------------------------------------------------------------------ + # Cache management helpers (convenience, e.g. 
for CLI commands) + # ------------------------------------------------------------------ + + def invalidate_issue_cache(self, owner: str, repo: str, issue_number: int) -> None: + """Force-expire the cache for a specific issue.""" + key = ResponseCache.make_key("fetch_issue", owner, repo, str(issue_number)) + self._cache.invalidate(key) + + def clear_cache(self) -> None: + """Wipe the entire provider cache.""" + self._cache.clear() \ No newline at end of file diff --git a/tests/providers/__init__.py b/tests/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/providers/test_cache.py b/tests/providers/test_cache.py new file mode 100644 index 0000000..a4fe382 --- /dev/null +++ b/tests/providers/test_cache.py @@ -0,0 +1,285 @@ +"""Unit tests for the caching layer and its integration with GitHubCLIProvider. + +Run with: + pytest tests/providers/test_cache.py -v +""" + +from __future__ import annotations + +import json +import time +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from oss_dev.providers.cache import CacheEntry, ResponseCache + + +# ====================================================================== +# Helpers / fixtures +# ====================================================================== + +@pytest.fixture() +def tmp_cache(tmp_path: Path) -> ResponseCache: + """A ResponseCache that writes to a temp directory.""" + return ResponseCache(ttl=60, cache_dir=tmp_path, enabled=True) + + +def _make_config(enabled: bool = True, ttl: int = 300, cache_dir=None): + """Build a minimal Config-like object with a cache sub-namespace.""" + cache_ns = SimpleNamespace(enabled=enabled, ttl=ttl, dir=cache_dir) + return SimpleNamespace(cache=cache_ns) + + +# ====================================================================== +# CacheEntry +# ====================================================================== + +class TestCacheEntry: + def 
class TestResponseCacheGetSet:
    """Basic ``get`` / ``set`` behaviour across the memory and disk layers."""

    def test_get_returns_none_on_miss(self, tmp_cache):
        assert tmp_cache.get("nonexistent") is None

    def test_set_then_get_returns_value(self, tmp_cache):
        tmp_cache.set("k1", "hello")
        assert tmp_cache.get("k1") == "hello"

    def test_set_persists_to_disk(self, tmp_cache, tmp_path):
        tmp_cache.set("k2", {"a": 1})
        disk_file = tmp_path / "k2.json"
        assert disk_file.exists()
        data = json.loads(disk_file.read_text())
        assert data["value"] == {"a": 1}

    def test_get_reads_from_disk_after_memory_eviction(self, tmp_path):
        # Write via one instance, read via another: the two share the
        # directory but not the in-memory layer.
        writer = ResponseCache(ttl=60, cache_dir=tmp_path, enabled=True)
        writer.set("disk_key", "disk_value")

        reader = ResponseCache(ttl=60, cache_dir=tmp_path, enabled=True)
        assert reader.get("disk_key") == "disk_value"

    def test_expired_entry_returns_none(self, tmp_path):
        cache = ResponseCache(ttl=60, cache_dir=tmp_path, enabled=True)
        cache.set("exp_key", "will_expire")

        # Rewrite the disk entry with an expiry in the past.
        entry_path = tmp_path / "exp_key.json"
        data = json.loads(entry_path.read_text())
        data["expires_at"] = time.time() - 1
        entry_path.write_text(json.dumps(data))

        # Drop the still-fresh memory copy to force a disk read.
        cache._memory.clear()
        assert cache.get("exp_key") is None
class TestMakeKey:
    """``ResponseCache.make_key`` yields stable, distinct, hex-only keys."""

    def test_same_parts_produce_same_key(self):
        assert ResponseCache.make_key("a", "b", "c") == ResponseCache.make_key("a", "b", "c")

    def test_different_parts_produce_different_keys(self):
        assert ResponseCache.make_key("a", "1") != ResponseCache.make_key("a", "2")

    def test_key_is_hex_string(self):
        key = ResponseCache.make_key("owner", "repo", "42")
        # Pin the SHA-256 hex length as well: the character check alone
        # would pass vacuously for an empty key.
        assert len(key) == 64
        assert all(c in "0123456789abcdef" for c in key)
class TestGitHubCLIProviderCaching:
    """Validate that the provider uses the cache correctly.

    ``_run_gh`` (the raw subprocess wrapper) is patched in every test, so
    no real ``gh`` process is spawned. The cache is an in-process
    ResponseCache backed by a temp directory.
    """

    def _make_provider(self, tmp_path, enabled=True, ttl=300):
        """Build a provider whose ``gh`` probe is stubbed and whose cache lives in *tmp_path*."""
        # Import here so the test file doesn't fail at collection time if
        # the module path changes during development.
        from oss_dev.providers.github.client import GitHubCLIProvider

        config = _make_config(enabled=enabled, ttl=ttl, cache_dir=str(tmp_path))
        with patch.object(GitHubCLIProvider, "_check_gh", return_value=True):
            provider = GitHubCLIProvider(config)
        # Replace the cache so it is guaranteed to use tmp_path
        provider._cache = ResponseCache(ttl=ttl, cache_dir=tmp_path, enabled=enabled)
        return provider

    @pytest.mark.asyncio
    async def test_fetch_issue_cached_on_second_call(self, tmp_path):
        provider = self._make_provider(tmp_path)
        issue_json = json.dumps({
            "number": 1, "title": "Test", "body": "Body",
            "state": "open", "labels": [],
        })

        with patch.object(provider, "_run_gh", return_value=issue_json) as mock_gh:
            await provider.fetch_issue("owner", "repo", 1)
            await provider.fetch_issue("owner", "repo", 1)  # should hit cache

        # gh should only have been called once
        assert mock_gh.call_count == 1

    @pytest.mark.asyncio
    async def test_list_issues_cached(self, tmp_path):
        provider = self._make_provider(tmp_path)
        line = json.dumps({
            "number": 7, "title": "Issue 7",
            "state": "open", "labels": [], "url": "https://github.com/o/r/issues/7",
        })

        with patch.object(provider, "_run_gh", return_value=line) as mock_gh:
            await provider.list_issues("owner", "repo", limit=1)
            await provider.list_issues("owner", "repo", limit=1)

        assert mock_gh.call_count == 1

    @pytest.mark.asyncio
    async def test_create_pr_never_cached(self, tmp_path):
        """Mutating calls must always hit gh, never cache."""
        provider = self._make_provider(tmp_path)
        pr_json = json.dumps({"url": "https://github.com/o/r/pull/1", "number": 1, "title": "PR"})

        with patch.object(provider, "_run_gh", return_value=pr_json) as mock_gh:
            await provider.create_pr("owner", "repo", "Title", "Body", "feat/branch")
            await provider.create_pr("owner", "repo", "Title", "Body", "feat/branch")

        assert mock_gh.call_count == 2  # called every time

    @pytest.mark.asyncio
    async def test_cache_disabled_always_calls_gh(self, tmp_path):
        provider = self._make_provider(tmp_path, enabled=False)
        issue_json = json.dumps({
            "number": 1, "title": "T", "body": "B",
            "state": "open", "labels": [],
        })

        with patch.object(provider, "_run_gh", return_value=issue_json) as mock_gh:
            await provider.fetch_issue("owner", "repo", 1)
            await provider.fetch_issue("owner", "repo", 1)

        assert mock_gh.call_count == 2

    @pytest.mark.asyncio
    async def test_get_pr_status_cached(self, tmp_path):
        provider = self._make_provider(tmp_path)
        pr_status = json.dumps({
            "state": "OPEN", "isDraft": False,
            "reviewDecision": None, "url": "https://github.com/o/r/pull/5",
        })

        with patch.object(provider, "_run_gh", return_value=pr_status) as mock_gh:
            await provider.get_pr_status("owner", "repo", 5)
            await provider.get_pr_status("owner", "repo", 5)

        assert mock_gh.call_count == 1

    @pytest.mark.asyncio
    async def test_invalidate_forces_fresh_call(self, tmp_path):
        provider = self._make_provider(tmp_path)
        issue_json = json.dumps({
            "number": 2, "title": "T", "body": "B",
            "state": "open", "labels": [],
        })

        with patch.object(provider, "_run_gh", return_value=issue_json) as mock_gh:
            await provider.fetch_issue("owner", "repo", 2)
            provider.invalidate_issue_cache("owner", "repo", 2)
            await provider.fetch_issue("owner", "repo", 2)

        assert mock_gh.call_count == 2