diff --git a/src/oss_dev/providers/github/client.py b/src/oss_dev/providers/github/client.py index 84a62d6..f483808 100644 --- a/src/oss_dev/providers/github/client.py +++ b/src/oss_dev/providers/github/client.py @@ -9,6 +9,8 @@ import json import re import subprocess +import sys +from datetime import datetime, timezone from typing import Any from oss_dev.core.contracts.provider import ( @@ -21,6 +23,8 @@ from oss_dev.core.errors import ProviderError from oss_dev.config.models import Config +LOW_RATE_LIMIT_THRESHOLD = 10 + class GitHubCLIProvider(GitHubProvider): """GitHub provider using the gh CLI tool.""" @@ -43,8 +47,80 @@ def _require_gh(self) -> None: details={"hint": "sudo apt install gh && gh auth login"}, ) + def _format_rate_limit_reset(self, reset: Any) -> str | None: + try: + reset_timestamp = int(reset) + except (TypeError, ValueError): + return None + + return datetime.fromtimestamp( + reset_timestamp, + tz=timezone.utc, + ).isoformat() + + def _get_rate_limit(self) -> dict[str, Any] | None: + try: + result = subprocess.run( + ["gh", "api", "rate_limit"], + capture_output=True, + text=True, + check=True, + ) + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + try: + payload = json.loads(result.stdout) + except json.JSONDecodeError: + return None + + resources = payload.get("resources") + if not isinstance(resources, dict): + return None + + core = resources.get("core") + if not isinstance(core, dict): + return None + + return core + + def _check_rate_limit(self) -> None: + rate_limit = self._get_rate_limit() + if not rate_limit: + return + + try: + remaining = int(rate_limit["remaining"]) + except (KeyError, TypeError, ValueError): + return + + reset_time = self._format_rate_limit_reset(rate_limit.get("reset")) + + if remaining <= 0: + details: dict[str, Any] = {"remaining": remaining} + if reset_time: + details["reset_time"] = reset_time + + message = "GitHub API rate limit exceeded." + if reset_time: + message = f"{message} Resets at {reset_time}." + + raise ProviderError(message, details=details) + + if remaining <= LOW_RATE_LIMIT_THRESHOLD: + warning = f"Warning: GitHub API rate limit is low ({remaining} requests remaining)." + if reset_time: + warning = f"{warning} Resets at {reset_time}." + print(warning, file=sys.stderr) + def _run_gh(self, args: list[str]) -> str: self._require_gh() + is_api_call = len(args) >= 2 and args[0] == "api" + endpoint = args[1] if len(args) >= 2 else None + + if is_api_call and endpoint != "rate_limit": + self._check_rate_limit() + try: result = subprocess.run( ["gh", *args], @@ -54,9 +130,29 @@ def _run_gh(self, args: list[str]) -> str: ) return result.stdout except subprocess.CalledProcessError as e: + error_text = e.stderr.strip() or e.stdout.strip() + details: dict[str, Any] = {"args": args} + + if "rate limit" in error_text.lower(): + rate_limit = self._get_rate_limit() if endpoint != "rate_limit" else None + reset_time = self._format_rate_limit_reset( + rate_limit.get("reset") if rate_limit else None + ) + if reset_time: + details["reset_time"] = reset_time + if rate_limit and "remaining" in rate_limit: + details["remaining"] = rate_limit["remaining"] + + message = "GitHub API rate limit exceeded." + if reset_time: + message = f"{message} Resets at {reset_time}." + if error_text: + details["error"] = error_text + raise ProviderError(message, details=details) from e + raise ProviderError( - f"GitHub CLI error: {e.stderr.strip() or e.stdout.strip()}", - details={"args": args}, + f"GitHub CLI error: {error_text}", + details=details, ) from e def parse_issue_url(self, url: str) -> dict[str, Any]: diff --git a/tests/oss/test_github_cli_provider.py b/tests/oss/test_github_cli_provider.py new file mode 100644 index 0000000..28630d5 --- /dev/null +++ b/tests/oss/test_github_cli_provider.py @@ -0,0 +1,102 @@ +import subprocess +from unittest.mock import patch + +import pytest + +from oss_dev.config.models import Config +from oss_dev.core.errors import ProviderError +from oss_dev.providers.github.client import GitHubCLIProvider + + +def _completed_process(stdout: str, stderr: str = "") -> subprocess.CompletedProcess[str]: + return subprocess.CompletedProcess( + args=["gh"], + returncode=0, + stdout=stdout, + stderr=stderr, + ) + + +def test_healthy_quota(): + with patch.object(GitHubCLIProvider, "_check_gh", return_value=True): + provider = GitHubCLIProvider(Config()) + + with patch("subprocess.run") as mock_run: + mock_run.side_effect = [ + _completed_process( + '{"resources":{"core":{"limit":5000,"remaining":100,"reset":1710000000,"used":1}}}' + ), + _completed_process("{}"), + ] + + output = provider._run_gh(["api", "repos/owner/repo"]) + + assert output == "{}" + assert mock_run.call_count == 2 + + +def test_zero_quota(): + with patch.object(GitHubCLIProvider, "_check_gh", return_value=True): + provider = GitHubCLIProvider(Config()) + + with patch("subprocess.run") as mock_run: + mock_run.side_effect = [ + _completed_process( + '{"resources":{"core":{"limit":5000,"remaining":0,"reset":1710000000,"used":5000}}}' + ) + ] + + with pytest.raises(ProviderError) as exc_info: + provider._run_gh(["api", "repos/owner/repo"]) + + assert "rate limit exceeded" in str(exc_info.value).lower() + assert mock_run.call_count == 1 + + +def test_low_quota_warning(capsys: pytest.CaptureFixture[str]): + with patch.object(GitHubCLIProvider, "_check_gh", return_value=True): + provider = GitHubCLIProvider(Config()) + + with patch("subprocess.run") as mock_run: + mock_run.side_effect = [ + _completed_process( + '{"resources":{"core":{"limit":5000,"remaining":5,"reset":1710000000,"used":4995}}}' + ), + _completed_process("{}"), + ] + + output = provider._run_gh(["api", "repos/owner/repo"]) + + captured = capsys.readouterr() + assert "rate limit is low" in captured.err.lower() + assert output == "{}" + + +def test_malformed_rate_limit_response(): + with patch.object(GitHubCLIProvider, "_check_gh", return_value=True): + provider = GitHubCLIProvider(Config()) + + with patch("subprocess.run") as mock_run: + mock_run.side_effect = [ + _completed_process("not-json"), + _completed_process("{}"), + ] + + output = provider._run_gh(["api", "repos/owner/repo"]) + + assert output == "{}" + + +def test_rate_limit_endpoint_does_not_recurse(): + with patch.object(GitHubCLIProvider, "_check_gh", return_value=True): + provider = GitHubCLIProvider(Config()) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = _completed_process( + '{"resources":{"core":{"limit":5000,"remaining":100,"reset":1710000000,"used":1}}}' + ) + + output = provider._run_gh(["api", "rate_limit"]) + + assert output == '{"resources":{"core":{"limit":5000,"remaining":100,"reset":1710000000,"used":1}}}' + assert mock_run.call_count == 1