Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions src/oss_dev/providers/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Cache abstraction for provider API responses.

Provides a TTL-based, file-backed cache to avoid redundant
gh CLI calls within and across sessions.
"""

from __future__ import annotations

import hashlib
import json
import logging
import os
import time
from pathlib import Path
from typing import Any

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

# Default cache settings
# Lifetime, in seconds, given to a cache entry when no explicit TTL is passed.
DEFAULT_TTL_SECONDS = 300 # 5 minutes
# Directory used for on-disk cache files when no explicit cache_dir is passed;
# resolved under the current user's home directory at import time.
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "oss_dev" / "github"


class CacheEntry:
    """Pair a cached payload with the wall-clock time at which it goes stale."""

    def __init__(self, value: Any, ttl: int) -> None:
        """Store *value* and stamp it to expire *ttl* seconds from now."""
        self.value = value
        created_at = time.time()
        self.expires_at: float = created_at + ttl

    def is_expired(self) -> bool:
        """Return ``True`` once the current time has passed ``expires_at``."""
        now = time.time()
        return self.expires_at < now

    def to_dict(self) -> dict[str, Any]:
        """Return the entry as a plain mapping suitable for serialisation."""
        return dict(value=self.value, expires_at=self.expires_at)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CacheEntry":
        """Rebuild an entry from the mapping produced by :meth:`to_dict`.

        ``__init__`` is bypassed on purpose so the stored expiry timestamp
        is kept as-is instead of being recomputed from the current time.
        Raises ``KeyError`` when either field is missing.
        """
        restored = cls.__new__(cls)
        restored.value, restored.expires_at = data["value"], data["expires_at"]
        return restored


class ResponseCache:
    """TTL-based, file-backed cache for GitHub CLI API responses.

    Each cache key maps to a JSON file on disk so that cache
    entries survive process restarts (up to their TTL). An in-memory
    dict sits in front of the disk layer so repeated look-ups within
    one process do not hit the filesystem.

    The cache is best-effort: failures to read or write cache files are
    logged and reported as a miss / skipped write rather than raised.

    Usage::

        cache = ResponseCache(ttl=60)
        cache.set("my_key", {"some": "data"})
        value = cache.get("my_key")  # None if missing / expired
    """

    # Errors that mean "this cache file is unreadable or malformed".
    # ValueError covers json.JSONDecodeError, undecodable bytes and a
    # non-numeric expiry; TypeError covers JSON payloads that are not an
    # object with the expected fields.
    _READ_ERRORS = (ValueError, KeyError, TypeError, OSError)

    def __init__(
        self,
        ttl: int = DEFAULT_TTL_SECONDS,
        cache_dir: Path | str | None = None,
        enabled: bool = True,
    ) -> None:
        """Create a cache.

        Args:
            ttl: Lifetime in seconds applied to every entry stored via :meth:`set`.
            cache_dir: Directory for the JSON files; a leading ``~`` is expanded.
                Falls back to ``DEFAULT_CACHE_DIR`` when empty or ``None``.
            enabled: When ``False``, :meth:`get` always misses and :meth:`set`
                is a no-op.
        """
        self._ttl = ttl
        self._enabled = enabled
        # expanduser() so a configured "~/.cache/..." resolves to the home
        # directory instead of creating a literal "~" folder in the cwd.
        self._cache_dir = (
            Path(cache_dir).expanduser() if cache_dir else DEFAULT_CACHE_DIR
        )

        # In-memory layer (avoids repeated disk reads in the same process)
        self._memory: dict[str, CacheEntry] = {}

        if self._enabled:
            try:
                self._cache_dir.mkdir(parents=True, exist_ok=True)
            except OSError as exc:
                # An unusable cache directory must not prevent the owner of
                # this cache from being constructed; run uncached instead.
                logger.warning(
                    "Cache disabled: cannot create %s: %s", self._cache_dir, exc
                )
                self._enabled = False
            else:
                logger.debug(
                    "Cache initialised at %s (TTL=%ds)", self._cache_dir, ttl
                )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, key: str) -> Any | None:
        """Return the cached value for *key*, or ``None`` if absent/expired.

        A stored value of ``None`` is indistinguishable from a miss.
        """
        if not self._enabled:
            return None

        # 1. Check in-memory layer first
        entry = self._memory.get(key)
        if entry is not None:
            if not entry.is_expired():
                logger.debug("Cache hit (memory): %s", key)
                return entry.value
            # Stale – remove from memory and fall through to disk
            del self._memory[key]

        # 2. Check disk layer
        entry = self._load_from_disk(key)
        if entry is not None and not entry.is_expired():
            logger.debug("Cache hit (disk): %s", key)
            self._memory[key] = entry  # promote to memory layer
            return entry.value

        logger.debug("Cache miss: %s", key)
        return None

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key* with the configured TTL.

        The value is always kept in memory; if it cannot be written to
        disk (I/O error or not JSON-serialisable) a warning is logged.
        """
        if not self._enabled:
            return

        entry = CacheEntry(value, self._ttl)
        self._memory[key] = entry
        self._save_to_disk(key, entry)
        logger.debug("Cache set: %s (expires in %ds)", key, self._ttl)

    def invalidate(self, key: str) -> None:
        """Remove a single entry from both memory and disk."""
        self._memory.pop(key, None)
        # missing_ok avoids the exists()/unlink() race with other processes.
        self._key_path(key).unlink(missing_ok=True)
        logger.debug("Cache invalidated: %s", key)

    def clear(self) -> None:
        """Wipe the entire cache (memory + disk)."""
        self._memory.clear()
        if self._cache_dir.exists():
            for f in self._cache_dir.glob("*.json"):
                f.unlink(missing_ok=True)
        logger.debug("Cache cleared")

    def purge_expired(self) -> int:
        """Delete all expired or corrupt disk entries.

        Expired entries are also dropped from the in-memory layer.

        Returns:
            The number of files removed from disk.
        """
        stale_keys = [k for k, e in self._memory.items() if e.is_expired()]
        for key in stale_keys:
            del self._memory[key]

        removed = 0
        if not self._cache_dir.exists():
            return removed
        for path in self._cache_dir.glob("*.json"):
            try:
                stale = self._read_entry(path).is_expired()
            except FileNotFoundError:
                continue  # removed by someone else since the glob
            except self._READ_ERRORS:
                stale = True  # corrupt file – remove it
            if not stale:
                continue
            try:
                path.unlink(missing_ok=True)
            except OSError as exc:
                logger.warning("Could not remove cache file %s: %s", path, exc)
                continue
            removed += 1
        logger.debug("Purged %d expired cache entries", removed)
        return removed

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def make_key(*parts: str) -> str:
        """Build a safe cache key by hashing the given string parts."""
        raw = ":".join(str(p) for p in parts)
        return hashlib.sha256(raw.encode()).hexdigest()

    @staticmethod
    def _safe_name(key: str) -> str:
        """Return a file-name-safe stem for *key*.

        Keys made only of ASCII letters, digits, ``.``, ``_`` and ``-``
        (and not starting with a dot) are used verbatim, so keys from
        :meth:`make_key` keep their existing file names. Anything else is
        hashed so that separators or ``..`` in a key cannot point outside
        the cache directory.
        """
        is_safe = (
            bool(key)
            and key.isascii()
            and not key.startswith(".")
            and all(ch.isalnum() or ch in "._-" for ch in key)
        )
        return key if is_safe else hashlib.sha256(key.encode()).hexdigest()

    def _key_path(self, key: str) -> Path:
        """Return the JSON file path that backs *key*."""
        return self._cache_dir / f"{self._safe_name(key)}.json"

    @staticmethod
    def _read_entry(path: Path) -> CacheEntry:
        """Parse *path* into a :class:`CacheEntry`; raises on any malformed content."""
        data = json.loads(path.read_text(encoding="utf-8"))
        entry = CacheEntry.from_dict(data)
        # Reject a non-numeric expiry here so is_expired() cannot fail later.
        entry.expires_at = float(entry.expires_at)
        return entry

    def _save_to_disk(self, key: str, entry: CacheEntry) -> None:
        """Write *entry* to its JSON file, logging (not raising) on failure."""
        path = self._key_path(key)
        # Write to a sibling temp file and rename it into place so readers
        # never observe a half-written entry. The ".tmp" suffix keeps it
        # out of the "*.json" globs used elsewhere.
        tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        try:
            payload = json.dumps(entry.to_dict())
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, path)
        except (OSError, TypeError, ValueError) as exc:
            logger.warning("Cache write failed for key %s: %s", key, exc)
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # cleanup is best-effort; the failure is already logged

    def _load_from_disk(self, key: str) -> CacheEntry | None:
        """Return the on-disk entry for *key*, or ``None`` if missing/unreadable."""
        path = self._key_path(key)
        try:
            return self._read_entry(path)
        except FileNotFoundError:
            return None
        except self._READ_ERRORS as exc:
            logger.warning("Cache read failed for key %s: %s", key, exc)
            return None
130 changes: 106 additions & 24 deletions src/oss_dev/providers/github/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,45 @@
)
from oss_dev.core.errors import ProviderError
from oss_dev.config.models import Config
from oss_dev.providers.cache import ResponseCache


class GitHubCLIProvider(GitHubProvider):
"""GitHub provider using the gh CLI tool."""
"""GitHub provider using the gh CLI tool.

Optionally wraps every read-only API call in a :class:`ResponseCache`
so that repeated requests for the same resource within (and across)
sessions skip the ``gh`` subprocess entirely.

Cache behaviour is controlled via ``Config.cache``:

.. code-block:: toml

[cache]
enabled = true # flip to false to disable entirely
ttl = 300 # seconds before an entry is considered stale
dir = "~/.cache/oss_dev/github" # optional custom path
"""

def __init__(self, config: Config) -> None:
    """Record *config*, probe for the ``gh`` binary and build the response cache.

    Cache settings are read from ``config.cache`` when that attribute
    exists; any setting that is absent falls back to its default
    (enabled, 300 second TTL, default directory).
    """
    self._config = config
    self._gh_available = self._check_gh()

    # ── Cache setup ────────────────────────────────────────────────
    settings = getattr(config, "cache", None)
    # When there is no cache section, ``settings`` is ``None``; ``None`` has
    # none of these attributes, so every lookup yields its default.
    enabled, ttl, cache_dir = (
        getattr(settings, "enabled", True),
        getattr(settings, "ttl", 300),
        getattr(settings, "dir", None),
    )

    self._cache = ResponseCache(ttl=ttl, cache_dir=cache_dir, enabled=enabled)

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

def _check_gh(self) -> bool:
try:
subprocess.run(["gh", "--version"], capture_output=True, check=True)
Expand All @@ -44,6 +74,7 @@ def _require_gh(self) -> None:
)

def _run_gh(self, args: list[str]) -> str:
"""Execute a ``gh`` command and return its stdout."""
self._require_gh()
try:
result = subprocess.run(
Expand All @@ -59,6 +90,25 @@ def _run_gh(self, args: list[str]) -> str:
details={"args": args},
) from e

def _run_gh_cached(self, cache_key: str, args: list[str]) -> str:
    """Serve *args* from the cache under *cache_key*, running ``gh`` on a miss.

    Intended for read-only (GET-equivalent) calls only. Anything that
    mutates state (PR creation, etc.) must go through :meth:`_run_gh`
    so it is never answered from the cache.
    """
    result = self._cache.get(cache_key)
    if result is None:
        # Miss: run the real command and remember its output.
        result = self._run_gh(args)
        self._cache.set(cache_key, result)
    return result

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

def parse_issue_url(self, url: str) -> dict[str, Any]:
pattern = r"github\.com/([^/]+)/([^/]+)/issues/(\d+)"
match = re.search(pattern, url)
Expand All @@ -71,12 +121,17 @@ def parse_issue_url(self, url: str) -> dict[str, Any]:
}

async def fetch_issue(self, owner: str, repo: str, issue_number: int) -> Issue:
output = self._run_gh([
"api",
f"repos/{owner}/{repo}/issues/{issue_number}",
"--jq",
"{title: .title, body: .body, state: .state, labels: [.labels[].name], number: .number}",
])
cache_key = ResponseCache.make_key("fetch_issue", owner, repo, str(issue_number))
output = self._run_gh_cached(
cache_key,
[
"api",
f"repos/{owner}/{repo}/issues/{issue_number}",
"--jq",
"{title: .title, body: .body, state: .state,"
" labels: [.labels[].name], number: .number}",
],
)
data = json.loads(output)
return Issue(
number=data.get("number", issue_number),
Expand All @@ -92,12 +147,17 @@ async def fetch_issue(self, owner: str, repo: str, issue_number: int) -> Issue:
async def list_issues(
self, owner: str, repo: str, state: str = "open", limit: int = 10
) -> list[Issue]:
output = self._run_gh([
"api",
f"repos/{owner}/{repo}/issues",
"--jq",
f".[:{limit}] | .[] | {{title: .title, number: .number, state: .state, labels: [.labels[].name], url: .html_url}}",
])
cache_key = ResponseCache.make_key("list_issues", owner, repo, state, str(limit))
output = self._run_gh_cached(
cache_key,
[
"api",
f"repos/{owner}/{repo}/issues",
"--jq",
f".[:{limit}] | .[] | {{title: .title, number: .number,"
f" state: .state, labels: [.labels[].name], url: .html_url}}",
],
)
issues = []
for line in output.strip().split("\n"):
if line.strip():
Expand Down Expand Up @@ -125,6 +185,7 @@ async def create_pr(
head: str,
base: str = "main",
) -> PullRequest:
# Mutating operation – never cached
output = self._run_gh([
"pr",
"create",
Expand All @@ -145,11 +206,15 @@ async def create_pr(
async def get_pr_status(
self, owner: str, repo: str, pr_number: int
) -> PRStatus:
output = self._run_gh([
"pr", "view", str(pr_number),
"--repo", f"{owner}/{repo}",
"--json", "state,isDraft,reviewDecision,url",
])
cache_key = ResponseCache.make_key("get_pr_status", owner, repo, str(pr_number))
output = self._run_gh_cached(
cache_key,
[
"pr", "view", str(pr_number),
"--repo", f"{owner}/{repo}",
"--json", "state,isDraft,reviewDecision,url",
],
)
data = json.loads(output)
return PRStatus(
state=data.get("state", "unknown"),
Expand All @@ -161,12 +226,16 @@ async def get_pr_status(
async def get_pr_comments(
self, owner: str, repo: str, pr_number: int
) -> list[Comment]:
output = self._run_gh([
"api",
f"repos/{owner}/{repo}/pulls/{pr_number}/comments",
"--jq",
".[] | {body: .body, user: .user.login, created_at: .created_at}",
])
cache_key = ResponseCache.make_key("get_pr_comments", owner, repo, str(pr_number))
output = self._run_gh_cached(
cache_key,
[
"api",
f"repos/{owner}/{repo}/pulls/{pr_number}/comments",
"--jq",
".[] | {body: .body, user: .user.login, created_at: .created_at}",
],
)
comments = []
for line in output.strip().split("\n"):
if line.strip():
Expand All @@ -179,3 +248,16 @@ async def get_pr_comments(
)
)
return comments

# ------------------------------------------------------------------
# Cache management helpers (convenience, e.g. for CLI commands)
# ------------------------------------------------------------------

def invalidate_issue_cache(self, owner: str, repo: str, issue_number: int) -> None:
    """Drop the cached ``fetch_issue`` response for one issue."""
    # The key parts must match the ones fetch_issue uses to build its key.
    self._cache.invalidate(
        ResponseCache.make_key("fetch_issue", owner, repo, str(issue_number))
    )

def clear_cache(self) -> None:
    """Remove every entry held by this provider's response cache."""
    provider_cache = self._cache
    provider_cache.clear()
Empty file added tests/providers/__init__.py
Empty file.
Loading