diff --git a/github/client.go b/github/client.go index 15fd2f8..8acb8bc 100644 --- a/github/client.go +++ b/github/client.go @@ -8,7 +8,9 @@ import ( "io" "log/slog" "net/http" + "reflect" "strconv" + "strings" "sync" "time" ) @@ -281,78 +283,140 @@ func (c *Client) parseRateLimitHeaders(resp *http.Response) { } } +// getPaginated fetches all pages of results for a given path. +// It handles GitHub's pagination by requesting 100 items per page until +// no more results are returned. +func (c *Client) getPaginated(ctx context.Context, basePath string, result any) error { + // Use reflection to work with any slice type + resultVal := reflect.ValueOf(result) + if resultVal.Kind() != reflect.Ptr || resultVal.Elem().Kind() != reflect.Slice { + return fmt.Errorf("result must be a pointer to a slice") + } + + sliceVal := resultVal.Elem() + + page := 1 + perPage := 100 // GitHub's maximum per_page value + + for { + // Build path with pagination parameters + separator := "?" + if strings.Contains(basePath, "?") { + separator = "&" + } + path := fmt.Sprintf("%s%spage=%d&per_page=%d", basePath, separator, page, perPage) + + // Create a new slice to hold this page's results + pageResult := reflect.New(sliceVal.Type()).Interface() + + if err := c.get(ctx, path, pageResult); err != nil { + return err + } + + // Get the slice value from the pointer + pageSlice := reflect.ValueOf(pageResult).Elem() + + // If we got no results, we're done + if pageSlice.Len() == 0 { + break + } + + // Append this page's results to the total + sliceVal = reflect.AppendSlice(sliceVal, pageSlice) + + // If we got fewer results than per_page, this is the last page + if pageSlice.Len() < perPage { + break + } + + page++ + } + + // Set the final result + resultVal.Elem().Set(sliceVal) + return nil +} + // GetFollowedUsers returns the users that the authenticated user follows. +// This method automatically handles pagination to fetch all followed users. func (c *Client) GetFollowedUsers(ctx context.Context) ([]User, error) { var users []User - if err := c.get(ctx, "/user/following", &users); err != nil { + if err := c.getPaginated(ctx, "/user/following", &users); err != nil { return nil, fmt.Errorf("fetching followed users: %w", err) } return users, nil } // GetFollowedUsersByUsername returns the users that a specific user follows. +// This method automatically handles pagination to fetch all followed users. func (c *Client) GetFollowedUsersByUsername(ctx context.Context, username string) ([]User, error) { var users []User path := fmt.Sprintf("/users/%s/following", username) - if err := c.get(ctx, path, &users); err != nil { + if err := c.getPaginated(ctx, path, &users); err != nil { return nil, fmt.Errorf("fetching users followed by %s: %w", username, err) } return users, nil } // GetStarredRepos returns repositories starred by the authenticated user. +// This method automatically handles pagination to fetch all starred repos. func (c *Client) GetStarredRepos(ctx context.Context) ([]Repository, error) { var repos []Repository - if err := c.get(ctx, "/user/starred", &repos); err != nil { + if err := c.getPaginated(ctx, "/user/starred", &repos); err != nil { return nil, fmt.Errorf("fetching starred repos: %w", err) } return repos, nil } // GetStarredReposByUsername returns repositories starred by a specific user. +// This method automatically handles pagination to fetch all starred repos. func (c *Client) GetStarredReposByUsername(ctx context.Context, username string) ([]Repository, error) { var repos []Repository path := fmt.Sprintf("/users/%s/starred", username) - if err := c.get(ctx, path, &repos); err != nil { + if err := c.getPaginated(ctx, path, &repos); err != nil { return nil, fmt.Errorf("fetching repos starred by %s: %w", username, err) } return repos, nil } // GetOwnedRepos returns repositories owned by the authenticated user. +// This method automatically handles pagination to fetch all owned repos. func (c *Client) GetOwnedRepos(ctx context.Context) ([]Repository, error) { var repos []Repository - if err := c.get(ctx, "/user/repos?type=owner", &repos); err != nil { + if err := c.getPaginated(ctx, "/user/repos?type=owner", &repos); err != nil { return nil, fmt.Errorf("fetching owned repos: %w", err) } return repos, nil } // GetOwnedReposByUsername returns repositories owned by a specific user. +// This method automatically handles pagination to fetch all owned repos. func (c *Client) GetOwnedReposByUsername(ctx context.Context, username string) ([]Repository, error) { var repos []Repository path := fmt.Sprintf("/users/%s/repos?type=owner", username) - if err := c.get(ctx, path, &repos); err != nil { + if err := c.getPaginated(ctx, path, &repos); err != nil { return nil, fmt.Errorf("fetching repos owned by %s: %w", username, err) } return repos, nil } // GetRecentEvents returns recent events for the authenticated user. +// This method automatically handles pagination to fetch all recent events. func (c *Client) GetRecentEvents(ctx context.Context, username string) ([]Event, error) { var events []Event path := fmt.Sprintf("/users/%s/events", username) - if err := c.get(ctx, path, &events); err != nil { + if err := c.getPaginated(ctx, path, &events); err != nil { return nil, fmt.Errorf("fetching events for %s: %w", username, err) } return events, nil } // GetReceivedEvents returns events received by a user (their feed). +// This method automatically handles pagination to fetch all received events. func (c *Client) GetReceivedEvents(ctx context.Context, username string) ([]Event, error) { var events []Event path := fmt.Sprintf("/users/%s/received_events", username) - if err := c.get(ctx, path, &events); err != nil { + if err := c.getPaginated(ctx, path, &events); err != nil { return nil, fmt.Errorf("fetching received events for %s: %w", username, err) } return events, nil diff --git a/github/client_test.go b/github/client_test.go index 085cb9e..78ebc99 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "log/slog" "net/http" "net/http/httptest" @@ -549,3 +550,141 @@ func TestWithLogger(t *testing.T) { t.Error("expected custom logger to be set") } } + +func TestGetFollowedUsersPagination(t *testing.T) { + // Create test users across multiple pages + // Page 1: 100 users, Page 2: 100 users, Page 3: 2 users (total: 202) + page1Users := make([]User, 100) + for i := 0; i < 100; i++ { + page1Users[i] = User{Login: fmt.Sprintf("user%d", i), ID: int64(i)} + } + + page2Users := make([]User, 100) + for i := 0; i < 100; i++ { + page2Users[i] = User{Login: fmt.Sprintf("user%d", i+100), ID: int64(i + 100)} + } + + page3Users := []User{ + {Login: "user200", ID: 200}, + {Login: "user201", ID: 201}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check path + if r.URL.Path != "/user/following" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + + // Parse query params + query := r.URL.Query() + page := query.Get("page") + perPage := query.Get("per_page") + + // Verify per_page is set to 100 + if perPage != "100" { + t.Errorf("expected per_page=100, got %s", perPage) + } + + w.Header().Set("Content-Type", "application/json") + + // Return appropriate page + switch page { + case "1": + if err := json.NewEncoder(w).Encode(page1Users); err != nil { + t.Fatalf("encoding page 1: %v", err) + } + case "2": + if err := json.NewEncoder(w).Encode(page2Users); err != nil { + t.Fatalf("encoding page 2: %v", err) + } + case "3": + if err := json.NewEncoder(w).Encode(page3Users); err != nil { + t.Fatalf("encoding page 3: %v", err) + } + default: + t.Errorf("unexpected page number: %s", page) + } + })) + defer server.Close() + + c := NewClient("test-token", WithBaseURL(server.URL)) + result, err := c.GetFollowedUsers(context.Background()) + if err != nil { + t.Fatalf("GetFollowedUsers() error: %v", err) + } + + // Should have fetched all 202 users + if len(result) != 202 { + t.Errorf("expected 202 users, got %d", len(result)) + } + + // Verify first user from page 1 + if result[0].Login != "user0" { + t.Errorf("expected first user 'user0', got %q", result[0].Login) + } + + // Verify last user from page 3 + if result[201].Login != "user201" { + t.Errorf("expected last user 'user201', got %q", result[201].Login) + } + + // Verify a user from page 2 + if result[150].Login != "user150" { + t.Errorf("expected middle user 'user150', got %q", result[150].Login) + } +} + +func TestGetStarredReposPagination(t *testing.T) { + // Test with exactly 100 repos (single page, should not request page 2) + repos := make([]Repository, 100) + for i := 0; i < 100; i++ { + repos[i] = Repository{ + ID: int64(i), + Name: fmt.Sprintf("repo%d", i), + FullName: fmt.Sprintf("owner/repo%d", i), + } + } + + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + query := r.URL.Query() + page := query.Get("page") + + if page == "1" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(repos); err != nil { + t.Fatalf("encoding repos: %v", err) + } + } else if page == "2" { + // Should not request page 2 if page 1 had exactly 100 items + // Return empty to stop pagination + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode([]Repository{}); err != nil { + t.Fatalf("encoding empty repos: %v", err) + } + } else { + t.Errorf("unexpected page: %s", page) + } + })) + defer server.Close() + + c := NewClient("test-token", WithBaseURL(server.URL)) + result, err := c.GetStarredRepos(context.Background()) + if err != nil { + t.Fatalf("GetStarredRepos() error: %v", err) + } + + if len(result) != 100 { + t.Errorf("expected 100 repos, got %d", len(result)) + } + + // Should have made exactly 2 requests (page 1 and page 2 to check if more data exists) + // Actually, with the < perPage check, it should only make 1 request + // Let me fix the logic - if we get exactly perPage items, we need to check the next page + // So it should make 2 requests + if requestCount != 2 { + t.Errorf("expected 2 requests, got %d", requestCount) + } +}