Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 72 additions & 8 deletions github/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"io"
"log/slog"
"net/http"
"reflect"
"strconv"
"strings"
"sync"
"time"
)
Expand Down Expand Up @@ -281,78 +283,140 @@ func (c *Client) parseRateLimitHeaders(resp *http.Response) {
}
}

// getPaginated fetches all pages of results for a given path.
// It handles GitHub's pagination by requesting 100 items per page until
// no more results are returned.
func (c *Client) getPaginated(ctx context.Context, basePath string, result any) error {
// Use reflection to work with any slice type
resultVal := reflect.ValueOf(result)
if resultVal.Kind() != reflect.Ptr || resultVal.Elem().Kind() != reflect.Slice {
return fmt.Errorf("result must be a pointer to a slice")
}

sliceVal := resultVal.Elem()

page := 1
perPage := 100 // GitHub's maximum per_page value

for {
// Build path with pagination parameters
separator := "?"
if strings.Contains(basePath, "?") {
separator = "&"
}
path := fmt.Sprintf("%s%spage=%d&per_page=%d", basePath, separator, page, perPage)

// Create a new slice to hold this page's results
pageResult := reflect.New(sliceVal.Type()).Interface()

if err := c.get(ctx, path, pageResult); err != nil {
return err
}

// Get the slice value from the pointer
pageSlice := reflect.ValueOf(pageResult).Elem()

// If we got no results, we're done
if pageSlice.Len() == 0 {
break
}

// Append this page's results to the total
sliceVal = reflect.AppendSlice(sliceVal, pageSlice)

// If we got fewer results than per_page, this is the last page
if pageSlice.Len() < perPage {
break
}

page++
}

// Set the final result
resultVal.Elem().Set(sliceVal)
return nil
}

// GetFollowedUsers returns the users that the authenticated user follows.
// This method automatically handles pagination to fetch all followed users.
func (c *Client) GetFollowedUsers(ctx context.Context) ([]User, error) {
var users []User
if err := c.get(ctx, "/user/following", &users); err != nil {
if err := c.getPaginated(ctx, "/user/following", &users); err != nil {
return nil, fmt.Errorf("fetching followed users: %w", err)
}
return users, nil
}

// GetFollowedUsersByUsername returns the users that a specific user follows.
// This method automatically handles pagination to fetch all followed users.
func (c *Client) GetFollowedUsersByUsername(ctx context.Context, username string) ([]User, error) {
var users []User
path := fmt.Sprintf("/users/%s/following", username)
if err := c.get(ctx, path, &users); err != nil {
if err := c.getPaginated(ctx, path, &users); err != nil {
return nil, fmt.Errorf("fetching users followed by %s: %w", username, err)
}
return users, nil
}

// GetStarredRepos returns repositories starred by the authenticated user.
// This method automatically handles pagination to fetch all starred repos.
func (c *Client) GetStarredRepos(ctx context.Context) ([]Repository, error) {
var repos []Repository
if err := c.get(ctx, "/user/starred", &repos); err != nil {
if err := c.getPaginated(ctx, "/user/starred", &repos); err != nil {
return nil, fmt.Errorf("fetching starred repos: %w", err)
}
return repos, nil
}

// GetStarredReposByUsername returns repositories starred by a specific user.
// This method automatically handles pagination to fetch all starred repos.
func (c *Client) GetStarredReposByUsername(ctx context.Context, username string) ([]Repository, error) {
var repos []Repository
path := fmt.Sprintf("/users/%s/starred", username)
if err := c.get(ctx, path, &repos); err != nil {
if err := c.getPaginated(ctx, path, &repos); err != nil {
return nil, fmt.Errorf("fetching repos starred by %s: %w", username, err)
}
return repos, nil
}

// GetOwnedRepos returns repositories owned by the authenticated user.
// This method automatically handles pagination to fetch all owned repos.
func (c *Client) GetOwnedRepos(ctx context.Context) ([]Repository, error) {
var repos []Repository
if err := c.get(ctx, "/user/repos?type=owner", &repos); err != nil {
if err := c.getPaginated(ctx, "/user/repos?type=owner", &repos); err != nil {
return nil, fmt.Errorf("fetching owned repos: %w", err)
}
return repos, nil
}

// GetOwnedReposByUsername returns repositories owned by a specific user.
// This method automatically handles pagination to fetch all owned repos.
func (c *Client) GetOwnedReposByUsername(ctx context.Context, username string) ([]Repository, error) {
var repos []Repository
path := fmt.Sprintf("/users/%s/repos?type=owner", username)
if err := c.get(ctx, path, &repos); err != nil {
if err := c.getPaginated(ctx, path, &repos); err != nil {
return nil, fmt.Errorf("fetching repos owned by %s: %w", username, err)
}
return repos, nil
}

// GetRecentEvents returns recent events for the authenticated user.
// This method automatically handles pagination to fetch all recent events.
func (c *Client) GetRecentEvents(ctx context.Context, username string) ([]Event, error) {
var events []Event
path := fmt.Sprintf("/users/%s/events", username)
if err := c.get(ctx, path, &events); err != nil {
if err := c.getPaginated(ctx, path, &events); err != nil {
return nil, fmt.Errorf("fetching events for %s: %w", username, err)
}
return events, nil
}

// GetReceivedEvents returns events received by a user (their feed).
// This method automatically handles pagination to fetch all received events.
func (c *Client) GetReceivedEvents(ctx context.Context, username string) ([]Event, error) {
var events []Event
path := fmt.Sprintf("/users/%s/received_events", username)
if err := c.get(ctx, path, &events); err != nil {
if err := c.getPaginated(ctx, path, &events); err != nil {
return nil, fmt.Errorf("fetching received events for %s: %w", username, err)
}
return events, nil
Expand Down
139 changes: 139 additions & 0 deletions github/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -549,3 +550,141 @@ func TestWithLogger(t *testing.T) {
t.Error("expected custom logger to be set")
}
}

func TestGetFollowedUsersPagination(t *testing.T) {
// Create test users across multiple pages
// Page 1: 100 users, Page 2: 100 users, Page 3: 2 users (total: 202)
page1Users := make([]User, 100)
for i := 0; i < 100; i++ {
page1Users[i] = User{Login: fmt.Sprintf("user%d", i), ID: int64(i)}
}

page2Users := make([]User, 100)
for i := 0; i < 100; i++ {
page2Users[i] = User{Login: fmt.Sprintf("user%d", i+100), ID: int64(i + 100)}
}

page3Users := []User{
{Login: "user200", ID: 200},
{Login: "user201", ID: 201},
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check path
if r.URL.Path != "/user/following" {
t.Errorf("unexpected path: %s", r.URL.Path)
}

// Parse query params
query := r.URL.Query()
page := query.Get("page")
perPage := query.Get("per_page")

// Verify per_page is set to 100
if perPage != "100" {
t.Errorf("expected per_page=100, got %s", perPage)
}

w.Header().Set("Content-Type", "application/json")

// Return appropriate page
switch page {
case "1":
if err := json.NewEncoder(w).Encode(page1Users); err != nil {
t.Fatalf("encoding page 1: %v", err)
}
case "2":
if err := json.NewEncoder(w).Encode(page2Users); err != nil {
t.Fatalf("encoding page 2: %v", err)
}
case "3":
if err := json.NewEncoder(w).Encode(page3Users); err != nil {
t.Fatalf("encoding page 3: %v", err)
}
default:
t.Errorf("unexpected page number: %s", page)
}
}))
defer server.Close()

c := NewClient("test-token", WithBaseURL(server.URL))
result, err := c.GetFollowedUsers(context.Background())
if err != nil {
t.Fatalf("GetFollowedUsers() error: %v", err)
}

// Should have fetched all 202 users
if len(result) != 202 {
t.Errorf("expected 202 users, got %d", len(result))
}

// Verify first user from page 1
if result[0].Login != "user0" {
t.Errorf("expected first user 'user0', got %q", result[0].Login)
}

// Verify last user from page 3
if result[201].Login != "user201" {
t.Errorf("expected last user 'user201', got %q", result[201].Login)
}

// Verify a user from page 2
if result[150].Login != "user150" {
t.Errorf("expected middle user 'user150', got %q", result[150].Login)
}
}

func TestGetStarredReposPagination(t *testing.T) {
// Test with exactly 100 repos (single page, should not request page 2)
repos := make([]Repository, 100)
for i := 0; i < 100; i++ {
repos[i] = Repository{
ID: int64(i),
Name: fmt.Sprintf("repo%d", i),
FullName: fmt.Sprintf("owner/repo%d", i),
}
}

requestCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++

query := r.URL.Query()
page := query.Get("page")

if page == "1" {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(repos); err != nil {
t.Fatalf("encoding repos: %v", err)
}
} else if page == "2" {
// Should not request page 2 if page 1 had exactly 100 items
// Return empty to stop pagination
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode([]Repository{}); err != nil {
t.Fatalf("encoding empty repos: %v", err)
}
} else {
t.Errorf("unexpected page: %s", page)
}
}))
defer server.Close()

c := NewClient("test-token", WithBaseURL(server.URL))
result, err := c.GetStarredRepos(context.Background())
if err != nil {
t.Fatalf("GetStarredRepos() error: %v", err)
}

if len(result) != 100 {
t.Errorf("expected 100 repos, got %d", len(result))
}

// Should have made exactly 2 requests (page 1 and page 2 to check if more data exists)
// Actually, with the < perPage check, it should only make 1 request
// Let me fix the logic - if we get exactly perPage items, we need to check the next page
// So it should make 2 requests
if requestCount != 2 {
t.Errorf("expected 2 requests, got %d", requestCount)
}
}
Loading