Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 98 additions & 2 deletions src/oss_dev/providers/github/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import json
import re
import subprocess
import sys
from datetime import datetime, timezone
from typing import Any

from oss_dev.core.contracts.provider import (
Expand All @@ -21,6 +23,8 @@
from oss_dev.core.errors import ProviderError
from oss_dev.config.models import Config

LOW_RATE_LIMIT_THRESHOLD = 10


class GitHubCLIProvider(GitHubProvider):
"""GitHub provider using the gh CLI tool."""
Expand All @@ -43,8 +47,80 @@ def _require_gh(self) -> None:
details={"hint": "sudo apt install gh && gh auth login"},
)

def _format_rate_limit_reset(self, reset: Any) -> str | None:
try:
reset_timestamp = int(reset)
except (TypeError, ValueError):
return None

return datetime.fromtimestamp(
reset_timestamp,
tz=timezone.utc,
).isoformat()

def _get_rate_limit(self) -> dict[str, Any] | None:
try:
result = subprocess.run(
["gh", "api", "rate_limit"],
capture_output=True,
text=True,
check=True,
)
except (subprocess.CalledProcessError, FileNotFoundError):
return None

try:
payload = json.loads(result.stdout)
except json.JSONDecodeError:
return None

resources = payload.get("resources")
if not isinstance(resources, dict):
return None

core = resources.get("core")
if not isinstance(core, dict):
return None

return core

def _check_rate_limit(self) -> None:
rate_limit = self._get_rate_limit()
if not rate_limit:
return

try:
remaining = int(rate_limit["remaining"])
except (KeyError, TypeError, ValueError):
return

reset_time = self._format_rate_limit_reset(rate_limit.get("reset"))

if remaining <= 0:
details: dict[str, Any] = {"remaining": remaining}
if reset_time:
details["reset_time"] = reset_time

message = "GitHub API rate limit exceeded."
if reset_time:
message = f"{message} Resets at {reset_time}."

raise ProviderError(message, details=details)

if remaining <= LOW_RATE_LIMIT_THRESHOLD:
warning = f"Warning: GitHub API rate limit is low ({remaining} requests remaining)."
if reset_time:
warning = f"{warning} Resets at {reset_time}."
print(warning, file=sys.stderr)

def _run_gh(self, args: list[str]) -> str:
self._require_gh()
is_api_call = len(args) >= 2 and args[0] == "api"
endpoint = args[1] if len(args) >= 2 else None

if is_api_call and endpoint != "rate_limit":
self._check_rate_limit()

try:
result = subprocess.run(
["gh", *args],
Expand All @@ -54,9 +130,29 @@ def _run_gh(self, args: list[str]) -> str:
)
return result.stdout
except subprocess.CalledProcessError as e:
error_text = e.stderr.strip() or e.stdout.strip()
details: dict[str, Any] = {"args": args}

if "rate limit" in error_text.lower():
rate_limit = self._get_rate_limit() if endpoint != "rate_limit" else None
reset_time = self._format_rate_limit_reset(
rate_limit.get("reset") if rate_limit else None
)
if reset_time:
details["reset_time"] = reset_time
if rate_limit and "remaining" in rate_limit:
details["remaining"] = rate_limit["remaining"]

message = "GitHub API rate limit exceeded."
if reset_time:
message = f"{message} Resets at {reset_time}."
if error_text:
details["error"] = error_text
raise ProviderError(message, details=details) from e

raise ProviderError(
f"GitHub CLI error: {e.stderr.strip() or e.stdout.strip()}",
details={"args": args},
f"GitHub CLI error: {error_text}",
details=details,
) from e

def parse_issue_url(self, url: str) -> dict[str, Any]:
Expand Down
102 changes: 102 additions & 0 deletions tests/oss/test_github_cli_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import subprocess
from unittest.mock import patch

import pytest

from oss_dev.config.models import Config
from oss_dev.core.errors import ProviderError
from oss_dev.providers.github.client import GitHubCLIProvider


def _completed_process(stdout: str, stderr: str = "") -> subprocess.CompletedProcess[str]:
return subprocess.CompletedProcess(
args=["gh"],
returncode=0,
stdout=stdout,
stderr=stderr,
)


def test_healthy_quota():
with patch.object(GitHubCLIProvider, "_check_gh", return_value=True):
provider = GitHubCLIProvider(Config())

with patch("subprocess.run") as mock_run:
mock_run.side_effect = [
_completed_process(
'{"resources":{"core":{"limit":5000,"remaining":100,"reset":1710000000,"used":1}}}'
),
_completed_process("{}"),
]

output = provider._run_gh(["api", "repos/owner/repo"])

assert output == "{}"
assert mock_run.call_count == 2


def test_zero_quota():
with patch.object(GitHubCLIProvider, "_check_gh", return_value=True):
provider = GitHubCLIProvider(Config())

with patch("subprocess.run") as mock_run:
mock_run.side_effect = [
_completed_process(
'{"resources":{"core":{"limit":5000,"remaining":0,"reset":1710000000,"used":5000}}}'
)
]

with pytest.raises(ProviderError) as exc_info:
provider._run_gh(["api", "repos/owner/repo"])

assert "rate limit exceeded" in str(exc_info.value).lower()
assert mock_run.call_count == 1


def test_low_quota_warning(capsys: pytest.CaptureFixture[str]):
with patch.object(GitHubCLIProvider, "_check_gh", return_value=True):
provider = GitHubCLIProvider(Config())

with patch("subprocess.run") as mock_run:
mock_run.side_effect = [
_completed_process(
'{"resources":{"core":{"limit":5000,"remaining":5,"reset":1710000000,"used":4995}}}'
),
_completed_process("{}"),
]

output = provider._run_gh(["api", "repos/owner/repo"])

captured = capsys.readouterr()
assert "rate limit is low" in captured.err.lower()
assert output == "{}"


def test_malformed_rate_limit_response():
with patch.object(GitHubCLIProvider, "_check_gh", return_value=True):
provider = GitHubCLIProvider(Config())

with patch("subprocess.run") as mock_run:
mock_run.side_effect = [
_completed_process("not-json"),
_completed_process("{}"),
]

output = provider._run_gh(["api", "repos/owner/repo"])

assert output == "{}"


def test_rate_limit_endpoint_does_not_recurse():
with patch.object(GitHubCLIProvider, "_check_gh", return_value=True):
provider = GitHubCLIProvider(Config())

with patch("subprocess.run") as mock_run:
mock_run.return_value = _completed_process(
'{"resources":{"core":{"limit":5000,"remaining":100,"reset":1710000000,"used":1}}}'
)

output = provider._run_gh(["api", "rate_limit"])

assert output == '{"resources":{"core":{"limit":5000,"remaining":100,"reset":1710000000,"used":1}}}'
assert mock_run.call_count == 1
Loading