diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index da25ed1..f1289e3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,73 +14,13 @@ env: IMAGE: riscv-runner PROD_URL: https://riseriscvrunnerappqdvknz9s-ghfe.functions.fnc.fr-par.scw.cloud STAGING_URL: https://riseriscvrunnerappst73ndwr0w-ghfe.functions.fnc.fr-par.scw.cloud - GO_GHFE_URL: https://riseriscvrunnerappst73ndwr0w-ghfe-go.functions.fnc.fr-par.scw.cloud jobs: test: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 # diff-cover needs the base commit - - name: Setup Python 3.12 - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: "3.12" - - name: Run tests - run: | - python -m venv .venv - source .venv/bin/activate - pip install --upgrade pip - pip install -r container/requirements.txt - pip install -r requirements-dev.txt - PYTHONPATH=${{ github.workspace }}/container pytest - - - name: Add diff coverage to step summary - if: github.event_name == 'pull_request' || github.event_name == 'push' - run: | - if [[ "${{ github.event_name }}" = "pull_request" ]]; then - BASE_SHA="${{ github.event.pull_request.base.sha }}" - BASE_REF_NAME="refs/heads/${{ github.event.pull_request.base.ref }}" - BASE_REF_URL="${{ github.event.pull_request.base.repo.html_url }}/tree/${BASE_SHA}" - elif [[ "${{ github.event_name }}" = "push" ]]; then - if [[ "${{ github.ref_name == github.event.repository.default_branch }}" = "true" ]]; then - # If we are on default branch - if [[ "${{ github.event.forced }}" = "true" ]]; then - # If we are force-pushing, we don't know what's the previous commit to compare to - echo "::error::.github/workflows/release.yml Branch ${{ github.ref_name}} was just force-pushed, can't measure diff-coverage" - exit 0 # do not fail the workflow nonetheless - fi - BASE_SHA="${{ github.event.before }}" - BASE_REF_NAME="${{ github.ref }}" - else - # If we are not on default branch, compare to default branch - git fetch origin ${{ github.event.repository.default_branch }} - BASE_SHA="$(git rev-parse origin/${{ github.event.repository.default_branch }})" - BASE_REF_NAME="${{ github.event.repository.default_branch }}" - fi - BASE_REF_URL="${{ github.event.repository.html_url }}/tree/${BASE_SHA}" - fi - if [[ -n "${BASE_SHA}" ]]; then - source .venv/bin/activate - diff-cover coverage.xml \ - --compare-branch "${BASE_SHA}" \ - --markdown-report diff-cover.md \ - --fail-under 80 - { - echo "" - echo "**Base ref: [${BASE_REF_NAME}](${BASE_REF_URL})**" - echo "" - cat diff-cover.md - } >> "$GITHUB_STEP_SUMMARY" - fi - - test-go: runs-on: ubuntu-latest defaults: run: - working-directory: container-go + working-directory: container steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: @@ -88,7 +28,7 @@ jobs: - name: Setup Go uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v6.0.0 with: - go-version-file: container-go/go.mod + go-version-file: container/go.mod - name: go vet run: go vet ./... - name: gofmt check @@ -105,9 +45,11 @@ jobs: build: needs: [test] runs-on: ubuntu-latest + defaults: + run: + working-directory: container steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - uses: rlespinasse/github-slug-action@9e7def61550737ba68c62d34a32dd31792e3f429 # v5.5.0 - name: Setup Docker Buildx uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 @@ -125,7 +67,6 @@ jobs: uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0 - name: Build ghfe image - id: ghfe uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: platforms: linux/amd64 @@ -137,11 +78,10 @@ jobs: cache-from: | type=gha,scope=docker cache-to: | - ${{ github.ref_name == github.event.repository.default_branch && 'type=gha,scope=type=gha,scope=docker' || '' }} + ${{ github.ref_name == github.event.repository.default_branch && 'type=gha,scope=docker' || '' }} push: ${{ github.repository_owner == 'riseproject-dev' }} - name: Build scheduler image - id: scheduler uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: platforms: linux/amd64 @@ -153,63 +93,13 @@ jobs: cache-from: | type=gha,scope=docker cache-to: | - ${{ github.ref_name == github.event.repository.default_branch && 'type=gha,scope=type=gha,scope=docker' || '' }} - push: ${{ github.repository_owner == 'riseproject-dev' }} - - build-go: - needs: [test-go] - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Setup Docker Buildx - uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - - - name: Login to Container Registry - if: github.repository_owner == 'riseproject-dev' - uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 - with: - registry: ${{ env.REGISTRY }} - username: nologin - password: ${{ secrets.SCW_SECRET_KEY }} - - - name: Extract metadata for Docker - id: meta - uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0 - - - name: Build ghfe-go image - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 - with: - platforms: linux/amd64 - context: container-go - file: container-go/Dockerfile - target: ghfe - tags: ${{ env.REGISTRY }}/${{ env.IMAGE }}:ghfe-sha-${{ github.sha }}-go - labels: ${{ steps.meta.outputs.labels }} - cache-from: | - type=gha,scope=docker-go - cache-to: | - ${{ github.ref_name == github.event.repository.default_branch && 'type=gha,scope=docker-go' || '' }} - push: ${{ github.repository_owner == 'riseproject-dev' }} - - - name: Build scheduler-go image - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 - with: - platforms: linux/amd64 - context: container-go - file: container-go/Dockerfile - target: scheduler - tags: ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-sha-${{ github.sha }}-go - labels: ${{ steps.meta.outputs.labels }} - cache-from: | - type=gha,scope=docker-go - cache-to: | - ${{ github.ref_name == github.event.repository.default_branch && 'type=gha,scope=docker-go' || '' }} + ${{ github.ref_name == github.event.repository.default_branch && 'type=gha,scope=docker' || '' }} push: ${{ github.repository_owner == 'riseproject-dev' }} deploy-staging: if: github.repository_owner == 'riseproject-dev' && github.ref_name == github.event.repository.default_branch name: "deploy to staging" - needs: [build, build-go] + needs: [build] runs-on: ubuntu-latest environment: staging concurrency: @@ -240,19 +130,8 @@ jobs: -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-staging \ ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-sha-${{ github.sha }} - - name: Tag ghfe-go image for staging - run: >- - docker buildx imagetools create \ - -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:ghfe-staging-go \ - ${{ env.REGISTRY }}/${{ env.IMAGE }}:ghfe-sha-${{ github.sha }}-go - - - name: Tag scheduler-go image for staging - run: >- - docker buildx imagetools create \ - -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-staging-go \ - ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-sha-${{ github.sha }}-go - - name: Deploy to Scaleway + working-directory: container run: npm ci && npx serverless deploy --stage=staging env: SCW_SECRET_KEY: ${{ secrets.SCW_SECRET_KEY }} @@ -279,7 +158,7 @@ jobs: deploy-prod: if: github.repository_owner == 'riseproject-dev' && github.ref_name == github.event.repository.default_branch name: "deploy to prod" - needs: [build, build-go, deploy-staging] + needs: [build, deploy-staging] runs-on: ubuntu-latest environment: prod concurrency: @@ -310,19 +189,8 @@ jobs: -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-prod \ ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-sha-${{ github.sha }} - - name: Tag ghfe-go image for prod - run: >- - docker buildx imagetools create \ - -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:ghfe-prod-go \ - ${{ env.REGISTRY }}/${{ env.IMAGE }}:ghfe-sha-${{ github.sha }}-go - - - name: Tag scheduler-go image for prod - run: >- - docker buildx imagetools create \ - -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-prod-go \ - ${{ env.REGISTRY }}/${{ env.IMAGE }}:scheduler-sha-${{ github.sha }}-go - - name: Deploy to Scaleway + working-directory: container run: npm ci && npx serverless@3 deploy --stage=main env: SCW_SECRET_KEY: ${{ secrets.SCW_SECRET_KEY }} diff --git a/README.md b/README.md index ff1044a..1bc62d7 100644 --- a/README.md +++ b/README.md @@ -52,13 +52,12 @@ The system is split into two containers: GitHub (workflow_job webhook) | v -ghfe (ghfe.py) - | - Proxies webhooks to staging for staging entities (prod only) +ghfe (container/cmd/ghfe) | - Verifies webhook signature | - Validates labels, determines entity type (org or personal) | - Resolves (entity_id, job_labels) -> (k8s_pool, k8s_image) | - Writes job to PostgreSQL - | - Serves /usage, /history + | - Serves /setup/{org,personal}, /trace/* | - NO GitHub API calls, NO k8s calls | v @@ -69,7 +68,7 @@ PostgreSQL (state store) | - LISTEN/NOTIFY: wakes scheduler on new jobs | v -Scheduler (scheduler.py) +Scheduler (container/cmd/scheduler) | - sync_jobs_state: sync job status with GitHub | - sync_workers_state: runs under a per-scheduler LOCK TABLE workers advisory, | 5 phases (atomic, single transaction): @@ -219,9 +218,10 @@ chronological order: Each row carries the full payload as `JSONB`, plus filter/index keys (`installation_id`, `app_id`, `entity_type`, `entity_id`, `entity_name`) -and a free-form `outcome` string. The `WebhookOutcome` enum in -`constants.py` is the canonical list of outcome values; the column itself -is `TEXT` so new outcomes don't require schema migrations. `entity_id` +and a free-form `outcome` string. The `WebhookOutcome` type in +`container/internal/contract.go` is the canonical list of outcome values; +the column itself is `TEXT` so new outcomes don't require schema +migrations. `entity_id` is the GitHub `account.id`, which is stable across renames and reinstalls — uninstalling and reinstalling the app produces a new `installation_id` but keeps the same `entity_id`. @@ -236,10 +236,11 @@ has no UNIQUE constraint on payload so a duplicate log row is acceptable (the trace endpoints can dedupe by `delivery_id` from the JSONB payload when needed). -The scheduler's `_gh_authenticate_app` wrapper logs only failures -(`gh.authenticate_app` is `@ttl_cache`-decorated, so success is the hot -path). `cachetools.func.ttl_cache` does not cache exceptions, so transient -errors don't poison subsequent calls. +The scheduler's `ghAuthenticate` wrapper +(`container/cmd/scheduler/gh_auth.go`) only records failures: the +underlying `AuthenticateApp` is TTL-cached, so success is the hot path +and would drown the log. Failures are not cached, so transient errors +don't poison subsequent calls. #### State reconstruction @@ -391,12 +392,13 @@ The scheduler iterates pending jobs in FIFO order. For each job: ### Configuration -Per-entity configuration is defined in `ENTITY_CONFIG` in `constants.py`, keyed by entity ID (org ID or user ID): +Per-entity configuration is defined in `EntityConfigs` in +`container/internal/constants.go`, keyed by entity ID (org ID or user ID): | Field | Type | Description | |-------|------|-------------| -| `max_workers` | int or None | Maximum concurrent workers across all pools. None = unlimited | -| `staging` | bool | If true, webhooks are proxied from prod to staging | +| `MaxWorkers` | `*int` | Maximum concurrent workers across all pools. `nil` = unlimited | +| `Staging` | `[]string` | Repository names whose webhooks should be proxied from prod to staging | ### HTTP routes @@ -406,8 +408,8 @@ Per-entity configuration is defined in `ENTITY_CONFIG` in `constants.py`, keyed |-------|--------|-------------| | `/` | POST | Webhook endpoint for `workflow_job` events | | `/health` | GET | Health check (returns `ok`) | -| `/usage` | GET | Human-readable view of per-pool jobs and workers | -| `/history` | GET | Job history sorted by status (pending, running, completed) then creation time | +| `/setup/org` | GET | GitHub App post-install landing page for organization installations | +| `/setup/personal` | GET | GitHub App post-install landing page for personal-account installations | | `/trace/entity/` | GET | Installation event log for an entity (requires bearer token) | | `/trace/installation/` | GET | Resolves to `entity_id` then returns its event log | | `/trace/job/` | GET | Resolves to `entity_id` via `jobs.entity_id` then returns its event log | @@ -418,38 +420,26 @@ Per-entity configuration is defined in `ENTITY_CONFIG` in `constants.py`, keyed | Route | Method | Description | |-------|--------|-------------| | `/health` | GET | Health check (returns `ok`) | +| `/usage` | GET | Human-readable view of per-pool jobs and workers (`/usage.json` for JSON) | +| `/history`, `/jobs` | GET | Job history sorted by status then creation time (`.json` variants for JSON) | +| `/workers` | GET | Worker history with `failure_info` for failed workers (`.json` variant for JSON) | ### Key files | File | Purpose | |------|---------| -| `container/constants.py` | Environment configuration, entity config, image tags | -| `container/ghfe.py` | Flask webhook handler -- validates requests, writes to PostgreSQL | -| `container/scheduler.py` | Scheduler -- GH reconciliation, demand matching, cleanup, worker status sync | -| `container/k8s.py` | Kubernetes pod provisioning, deletion, capacity checks, failure info collection | -| `container/db.py` | PostgreSQL database operations | -| `container/github.py` | GitHub API functions (auth, runner groups, JIT config, job status) | -| `container/Dockerfile` | Docker image for the ghfe and scheduler containers | -| `container-go/` | Go reimplementation of ghfe and scheduler (see `container-go/CONTRACT.md`) | +| `container/cmd/ghfe/` | Webhook handler — validates requests, writes to PostgreSQL, serves `/setup/*` and `/trace/*` | +| `container/cmd/scheduler/` | Scheduler — GH reconciliation, demand matching, cleanup, worker status sync; serves `/usage`, `/history`, `/jobs`, `/workers` | +| `container/internal/constants.go` | Environment configuration, `EntityConfigs`, timeouts, image tags | +| `container/internal/contract.go` | Shared types, `WebhookOutcome` enum, DB/GitHub/Kube interfaces | +| `container/internal/db.go` | PostgreSQL operations (pgx) | +| `container/internal/github.go` | GitHub App auth + REST client | +| `container/internal/k8s.go` | Kubernetes pod provisioning, deletion, capacity checks, failure-info collection | +| `container/internal/testutil/` | In-memory fakes shared by `cmd/` tests | +| `container/Dockerfile` | Multi-stage build producing the `ghfe` and `scheduler` images | +| `container/serverless.yml` | Scaleway Serverless deployment manifest | | `scripts/trace_installation.py` | CLI client for the `/trace/*` endpoints — chronological table + diagnosis hints | -### Go cutover - -`container-go/` ships a Go reimplementation deployed alongside the Python tree -as the `ghfe-go` and `scheduler-go` Scaleway functions. Cutover is gradual: - -- **ghfe**: set `GO_GHFE_URL` on the Python ghfe function to the Go ghfe URL, - then populate `GO_GHFE_ROUTING={"entities":[, ...]}` with the - GitHub owner ids to forward. Only `workflow_job` webhooks are routed; the - staging proxy at `container/ghfe.py:509` still runs first. Rollback is - removing entries from `GO_GHFE_ROUTING`. -- **scheduler**: single deployment. Once staging has soaked on the Go - scheduler, swap the prod `scheduler` function's image to - `scheduler-prod-go`. Rollback is the inverse image swap. - -See `container-go/CONTRACT.md` for the frozen behavioral surface the Go port -must preserve. - ### Infrastructure | Service | Product | Purpose | @@ -465,20 +455,16 @@ Production and staging each have their own k8s cluster, provisioned via the `scr ## Development -Create a python venv and install dev dependencies: -```bash -python3.12 -m venv .venv -source .venv/bin/activate -pip install --upgrade pip -pip install -r requirements-dev.txt -``` +The containers are pure Go. From `container/`: -Run tests: ```bash -source .venv/bin/activate && PYTHONPATH=container python3 -m pytest +go vet ./... +gofmt -l . # exits 0 with no output if everything is formatted +go test -race ./... ``` -Tests mock PostgreSQL and Kubernetes -- no live services are required. +Tests run against in-memory fakes for PostgreSQL, the GitHub API, and the +Kubernetes API — no live services are required. ## Deployment diff --git a/container-go/.dockerignore b/container-go/.dockerignore deleted file mode 100644 index 271bfc8..0000000 --- a/container-go/.dockerignore +++ /dev/null @@ -1,5 +0,0 @@ -*_test.go -.dockerignore -Dockerfile -README.md -CONTRACT.md diff --git a/container-go/CONTRACT.md b/container-go/CONTRACT.md deleted file mode 100644 index e0e5731..0000000 --- a/container-go/CONTRACT.md +++ /dev/null @@ -1,255 +0,0 @@ -# container-go: external behavior contract - -Frozen reference for the Go port. Source-of-truth citations are `container/*.py:LINE`. -DDL is in the root `README.md` ("Database schema"). - -## 1. HTTP surface - -### ghfe (port 8080) - -| Route | Method | Auth | Body / params | Success | Errors | -|---|---|---|---|---|---| -| `/health` | GET | — | — | 200 `ok` (text) | — | -| `/` | POST | HMAC-SHA256 + headers | webhook JSON | 200 text | 400 (bad headers/JSON), 401 (signature) | -| `/setup/org` | GET | — | `?installation_id=N` | 200 HTML | 400, 404, 502 | -| `/setup/personal` | GET | — | `?installation_id=N` | 200 HTML | 400, 404, 502 | -| `/trace/entity/` | GET | `Authorization: Bearer $TRACE_API_SECRET` | path int | 200 JSON `{"events":[...]}` | 401, 404 | -| `/trace/installation/` | GET | bearer | path int | 200 JSON `{"events":[...]}` | 401, 404 | -| `/trace/job/` | GET | bearer | path int | 200 JSON `{"events":[...]}` | 401, 404 | -| `/trace/payload/` | GET | bearer | path int (event id) | 200 JSON `{"payload":{...}}` | 401, 404 | - -Access log is **opt-in per request**: `g.print_perf_log` defaults `False` (`ghfe.py:49`) and is set `True` only inside the staging-proxy branch (`ghfe.py:513`) and once a `workflow_job` has cleared signature, entity, and label checks (`ghfe.py:568`). The `after_request` hook (`ghfe.py:53-65`) emits `"%s %s -> %d in %.1fms"` only when both `g.print_perf_log` is true and the request isn't `GET /health`. Setup, trace, ping, ignored workflow_job events, and invalid-payload responses produce no access log line. The Go port must preserve this — health checks and discarded webhooks stay silent at INFO. - -### scheduler (port 8080) - -| Route | Method | Query | Response | -|---|---|---|---| -| `/health` | GET | — | 200 text `ok` | -| `/usage`, `/usage.json` | GET | — | HTML / JSON `{"jobs":[...],"workers":[...]}` | -| `/history`, `/history.json` | GET | `start`, `end`, `page=0`, `per_page=100` | HTML / JSON array | -| `/jobs`, `/jobs.json` | GET | same as `/history` | HTML / JSON array | -| `/workers`, `/workers.json` | GET | `start`, `end`, `page=0`, `per_page=100` | HTML / JSON array | - -- `start`/`end` accept `YYYY-MM-DD` or `-Xd`. 400 on parse failure (`scheduler.py:774-787`, `821-834`). -- Paginated routes emit a GitHub-style `Link` header with `rel="first"|"prev"|"next"|"last"` (`scheduler.py:739-761`). -- Both binaries bind `0.0.0.0:8080` (`scheduler.py:888-889`). - -## 2. Webhook contract - -- Signature: `X-Hub-Signature-256: sha256=`, timing-safe compare (`ghfe.py:70-103`). -- Required headers: `X-GitHub-Event`, `X-Hub-Signature-256`, `X-GitHub-Hook-Installation-Target-Id` (int, must match `GHAPP_ORG_ID` or `GHAPP_PERSONAL_ID`). -- Accepted events: `ping`, `installation`, `installation_repositories`, `installation_target`, `workflow_job` (`ghfe.py:427-626`). -- `workflow_job` accepted actions: `queued`, `in_progress`, `completed`. Everything else → `IGNORED_ACTION` (`ghfe.py:502-505`). -- Trimmed payload (`ghfe.py:106-177`, constant `_WORKFLOW_JOB_DROP_KEYS`): - - `sender`, `repository.owner`: drop 11 `*_url` fields each. - - `repository`: drop 31 `*_url` fields plus `license`. - - `organization`: drop 8 `*_url` fields. - - `workflow_job`: drop `url`, `run_url`, `check_run_url`, `steps[]`. **Preserve `workflow_job.html_url`.** - -### Entity extraction by event type - -| Event | entity_name | entity_id | -|---|---|---| -| `ping` | (none — only `app_id` logged) | — | -| `installation` | `installation.account.login` | `installation.target_id` | -| `installation_repositories` | `installation.account.login` | `installation.target_id` | -| `installation_target` | `account.login` | `account.id` | -| `workflow_job` | `repository.owner.login` | `repository.owner.id` (both orgs and users; see `ghfe.py:authorize_entity`) | - -Every webhook delivery writes exactly one `installation_events` row carrying `source`, `event`, `outcome`, `payload` (full body), even on auth failure. - -## 3. Label → pool/image resolution - -From `match_labels_to_k8s` (`ghfe.py:226-255`). Returns `None` when no rule matches (caller emits `IGNORED_NO_LABEL`). - -| Org match | Label predicate | Pool | Image | -|---|---|---|---| -| PyTorch, or `riseproject-dev` + repo in `{pytorch, executorch}` | `linux.riscv64.xlarge` or `linux.riscv64.2xlarge` in labels | `scw-em-rv1` | `RUNNER_IMAGE_UBUNTU_24_04` | -| same | `ubuntu-24.04-riscv` in labels | `scw-em-rv1` | `RUNNER_IMAGE_UBUNTU_24_04` | -| GGML, or `riseproject-dev` + repo in `{llama.cpp, llama.cpp-validation}` | labels == `["ubuntu-24.04-riscv"]` exactly | `cloudv10x-jupiter` | `RUNNER_IMAGE_UBUNTU_24_04` | -| any other | labels == `["ubuntu-24.04-riscv"]` exactly | `scw-em-rv1` | `RUNNER_IMAGE_UBUNTU_24_04` | -| any other | anything else (`ubuntu-26.04-riscv` etc.) | — | — (returns `None`) | - -Constants: `PYTORCH_ORG_ID`, `GGML_ORG_ORG_ID`, `RISEPROJECT_DEV_ORG_ID` from `constants.py`. - -## 4. Environment variables - -Existing (unchanged): - -| Var | Purpose | -|---|---| -| `PROD` | `"true"` → prod schema and routing branch active | -| `PROD_URL`, `STAGING_URL` | self-URLs; `STAGING_URL` is the proxy target | -| `POSTGRES_URL` | DSN | -| `K8S_KUBECONFIG` | YAML body (not a path); `yaml.safe_load` then `new_client_from_config_dict` (`k8s.py:24-26`) | -| `GHAPP_ORG_ID` | `2167633` (`constants.py:29`) | -| `GHAPP_ORG_PRIVATE_KEY` | RSA PEM | -| `GHAPP_PERSONAL_ID` | `3131217` (`constants.py:31`) | -| `GHAPP_PERSONAL_PRIVATE_KEY` | RSA PEM | -| `GHAPP_WEBHOOK_SECRET` | HMAC key | -| `TRACE_API_SECRET` | bearer for `/trace/*` | -| `LOGLEVEL` | default `INFO` | - -New for the Go cutover: - -| Var | Scope | Purpose | -|---|---|---| -| `GO_GHFE_URL` | Python ghfe | base URL of `ghfe-go`; required when `GO_GHFE_ROUTING` is non-empty | -| `GO_GHFE_ROUTING` | Python ghfe only | JSON `{"entities":[, …]}`; empty / unset = route nothing | - -## 5. Routing semantics - -Routing applies to **`workflow_job` webhooks only** (same scope as the existing staging proxy at `ghfe.py:509-522`). Every other event type — `ping`, `installation`, `installation_repositories`, `installation_target`, plus anything unrecognised — is handled locally by Python ghfe. There is a single scheduler deployment at any time; it consumes every row in the DB regardless of which ghfe wrote it. - -- **Staging proxy** (unchanged): inside the `workflow_job` branch, if `PROD` is true and the `(entity_id, repo_name)` pair is listed in `STAGING_ENTITIES`, the request is forwarded to `STAGING_URL`. Repoint `STAGING_URL` at the deployed Go ghfe staging URL to route staging traffic to Go. -- **`GO_GHFE_ROUTING`**: runs immediately after the staging proxy in the same `workflow_job` branch. If `GO_GHFE_URL` is set and `entity_id` (= `repository.owner.id`) is in the routing list, the raw body + headers (drop `Host`) are forwarded to `GO_GHFE_URL` with a 30s timeout; the response is returned verbatim. Otherwise Python handles the webhook as today. -- **Routing list parsing**: JSON `{"entities":[, …]}`. Each entry is the GitHub owner id. Parsed once in `constants.py` into `frozenset[int]`. Empty / unset = nothing routed. No DB lookup. -- **Scheduler**: single deployment, no entity filter. Cutover Python → Go is one image swap on the existing function, performed independently of ghfe routing. - -## 6. Database - -DDL lives in `README.md` ("Database schema") — that section is the source of truth and the runtime no longer auto-applies it. Notes the Go code must honour: - -- `search_path` is set to `prod` or `staging` on every borrowed connection. -- `LISTEN` channel: `{schema}_queue_event` (`db.py:847`). NOTIFY payload is `str(job_id)` (`db.py:314`). -- Status enum transitions are forward-only (`pending → running → completed|failed`); every `UPDATE` includes the status precondition. -- `jobs.k8s_pod` is `COALESCE`'d, never overwritten once set (`db.py:334, 366`). -- `failure_info` is JSONB with required key `version` (1 or 2). v2 is the only shape new code writes. - -### Functions to reimplement (signatures from `db.py`) - -Reads (`SELECT`): `get_pending_jobs`, `get_active_jobs`, `get_active_jobs_and_workers`, `get_workers_for_reconcile(terminal_lookback_seconds=3600)`, `get_pool_demand(entity_id, job_labels)`, `get_total_workers_for_entity(entity_id)`, `job_exists_for_pod(pod_name)`, `get_events_by_entity_id(entity_id)`, `get_entity_id_for_installation(installation_id)`, `get_entity_id_for_job(job_id)`. - -Writes: `add_job`, `mark_job_running`, `mark_job_completed`, `mark_job_failed`, `add_worker` (raises `DuplicateRunnerNameException` on PK collision), `mark_worker_running`, `mark_worker_completed`, `mark_worker_failed`, `mark_worker_orphaned`, `add_installation_event` (returns `BIGSERIAL id`). - -Listen: `wait_for_job(timeout)` — `select()` on the LISTEN conn, drains buffered NOTIFYs (`db.py:851-865`). - -## 7. GitHub App auth (`github.py`) - -- JWT (`github.py:36-43`): RS256, claims `iat=now()`, `exp=iat+600`, `iss=app_id`. -- Installation token: `POST https://api.github.com/app/installations/{installation_id}/access_tokens`, expect 201 (`github.py:46-74`). -- Cache: TTL = `60*59` (59 min), keyed by `(installation_id, app_id)`, `maxsize=1024`, LRU. -- JIT runner config: - - Org: `POST /orgs/{name}/actions/runners/generate-jitconfig` (`github.py:141-163`). - - Repo: `POST /repos/{full_name}/actions/runners/generate-jitconfig` (`github.py:166-188`). - - Body: `{name, runner_group_id, labels, work_folder: "../../../work"}`. Returns `encoded_jit_config` (201). -- Runner group ensure (`github.py:103-138`): GET groups → if absent, POST `{name, visibility:"all", allows_public_repositories:true}`. Returns group id. -- List runners: `_paginated_get` on org-group or repo URLs (`github.py:212-221`). Walk `Link: rel="next"`. -- Delete runner: DELETE org or repo URL; treat 204 and 404 as success (`github.py:224-247`). -- Get job info: `GET /repos/{full_name}/actions/jobs/{job_id}`, return body on 200 (`github.py:250-268`). - -## 8. Kubernetes pod manifest - -`provision_runner` (`k8s.py:29-104`) produces a pod with: - -| Field | Value | -|---|---| -| labels | `app=rise-riscv-runner`, `riseproject.dev/entity_id`, `riseproject.dev/entity_name`, `riseproject.dev/board` | -| nodeSelector | `riseproject.dev/board=` | -| activeDeadlineSeconds | `525600` | -| restartPolicy | `Never` | -| hostNetwork | `true` | -| containers | one (`name=runner`); no sidecar | -| securityContext.privileged | `true` | -| env | `RUNNER_WAIT_FOR_DOCKER_IN_SECONDS=60`, `RUNNER_JITCONFIG=` | -| resources.limits | `riseproject.com/runner=1`; **also** `ephemeral-storage=90Gi` iff `k8s_pool` starts with `scw-em-` (`k8s.py:46`) | -| volumes | two `emptyDir`: `docker-graph` → `/var/lib/docker`, `k0s` → `/var/lib/k0s` | -| namespace | `default` (also used by every other `k8s.py` op) | - -Other k8s operations: - -- `ListPods()` — `list_namespaced_pod(label_selector="app=rise-riscv-runner")`. -- `GetPodEvents(pod_name)` — `list_namespaced_event(field_selector=involvedObject.name=...)`, sorted by `last_timestamp || event_time || creation_timestamp`. -- `DeletePod(pod)` — `delete_namespaced_pod`; swallow 404. -- `KillPod(pod)` — patch `spec.activeDeadlineSeconds=1` (no delete; pod transitions to `Failed:DeadlineExceeded`). -- `CollectPodFailureInfo(pod, reason: FailureReason)` — returns `{version: 2, reason, pod_reason, pod_message, containers: {name: {exit_code, reason, message, logs}}, events: [{type, reason, message, count, first_seen, last_seen}], collect_error?}`. -- `AvailableSlots(pool)` — sum allocatable `riseproject.com/runner` over nodes matching `riseproject.dev/board=`, subtract count of `Pending|Running` runner pods on the same selector. Returns `Capacity{Total, Active, Available}`. - -## 9. Reconciliation algorithm - -`scheduler.py:429-462` — one tick, all five phases inside one `LOCK TABLE workers IN EXCLUSIVE MODE` critical section (`scheduler.py:865-867`). - -1. **Orphan sweep** (`sync_workers_state` step 1, ref. line 440 → fn 239-244): workers in `pending|running` with no matching pod → `mark_worker_orphaned` (status becomes `completed`, no failure_info). -2. **Pod-phase sync** (line 443 → 247-266): map K8s pod `phase` to DB status — `Running` → `mark_worker_running`; `Succeeded` → `mark_worker_completed`; `Failed` → `mark_worker_failed` with collected `failure_info`. -3. **Health checks** (line 453 → 269-373): group active workers by GitHub runner scope, fetch runner list from GitHub, classify each pod against timeouts and kill via `KillPod` when exceeded. -4. **GitHub-side cleanup** (line 458 → 376-403): delete runners on GitHub for workers whose DB row is terminal or missing. -5. **Terminal-pod GC** (line 462 → 406-426): delete K8s pods in `Succeeded|Failed` once `finished_at` age exceeds `POD_DELETE_GRACE_SECONDS`. - -Phases operate on a snapshot taken at the start of each phase — no cross-phase mutation reuse (invariant from `be1434c`). - -### Timeouts (`constants.py:43-46`) - -| Name | Value (s) | Applied in | -|---|---|---| -| `RUNNER_REGISTRATION_TIMEOUT_SECONDS` | 120 | `scheduler.py:328, 338, 364, 367` | -| `RUNNER_PENDING_TIMEOUT_SECONDS` | 600 | `scheduler.py:347` | -| `POD_PENDING_TIMEOUT_SECONDS` | 600 | `scheduler.py:315` | -| `POD_DELETE_GRACE_SECONDS` | 21600 | `scheduler.py:421` | - -## 10. Demand match - -`demand_match` (`scheduler.py:466-581`): - -1. Pull all `pending` jobs (FIFO). For each distinct `k8s_pool`, fetch `AvailableSlots` **once**, then decrement locally per provision (`scheduler.py:485-489, 578`). -2. Skip pool when `available_slots <= 0` (note: `<= 0`, not `== 0`; concurrent runs can push below). -3. Per job: refetch row, skip if no longer `pending`. -4. Compute `(job_count, worker_count)` via `get_pool_demand(entity_id, job_labels)`. Skip if `job_count <= worker_count` ("demand met"). -5. Check `entity_worker_count >= max_workers` (`ENTITY_CONFIG.get(entity_id, {"max_workers": 20})`). Skip if cap reached. -6. Generate runner name `{RUNNER_NAME_PREFIX}{rand9}` (`[a-z0-9]{9}`). Retry up to **5** times on `DuplicateRunnerNameException`. -7. On insert success, JIT runner config + `provision_runner`. On any failure, `add_worker` already created the row → `mark_worker_failed` with `failure_info.reason=pod_allocation_failure` (invariant from `9a9d611`). - -## 11. Operational defaults - -| Setting | Value | Source | -|---|---|---| -| Ports | `8080` (both binaries) | `scheduler.py:889` | -| Exit codes | `0` normal; `1` uncaught; `2` init failure (parity goal) | — | -| Scheduler poll interval | `15s` (`POLL_INTERVAL`), implemented via `db.wait_for_job(POLL_INTERVAL)` | `scheduler.py:27, 903` | -| Webhook proxy timeout | `30s` | `ghfe.py:519` | -| K8s namespace | `default` | `k8s.py` (all ops) | - -## 12. Logging convention - -`log/slog` text handler, level from `LOGLEVEL` (default `INFO`). Translation rule: - -- Keep the English message text from Python verbatim. -- Move every positional arg to a slog attr; pick a stable attr name (`job_id`, `pod_name`, `installation_id`, `k8s_pool`, `worker_status`, `runner_status`, `reason`, `outcome`, `source`, `status_code`). -- Identity is always passed as a single `internal.Entity` attr under the key `entity`; the text handler expands it to `entity.type`, `entity.name`, `entity.id` (in that order) via `Entity.LogValue`. Don't pass the three fields separately. -- Errors trail with `"err", err`. - -Examples: - -| Python (`scheduler.py:512`) | Go | -|---|---| -| `logger.info("Demand met for entity=%s entity_id=%s entity_type=%s labels=%s, jobs_count=%d workers_count=%d", entity_name, entity_id, entity_type, labels, job_count, worker_count)` | `slog.Info("Demand met for entity", "entity", j.Entity(), "labels", labels, "jobs_count", j, "workers_count", w)` | -| `logger.warning("Runner name %s collision, regenerating", candidate)` (`scheduler.py:540`) | `slog.Warn("Runner name collision, regenerating", "runner_name", candidate)` | -| `logger.error("kill_pod failed for %s: %s", pod.metadata.name, e)` (`scheduler.py:222`) | `slog.Error("kill_pod failed", "pod_name", name, "err", err)` | -| `logger.info("Stored job ...")` (any place storing) | `slog.Info("Stored job", "job_id", id, …)` | - -Access log line stays one record per request: `slog.Info("request", "method", m, "path", p, "status", s, "elapsed_ms", e)`. Health-check requests are not logged (parity with `ghfe.py:62`). - -Every webhook outcome (`OK`, `IGNORED_*`, `INVALID_SIGNATURE`, `INVALID_PAYLOAD`, `JOB_NOT_FOUND`, `UNAUTHENTICATED`, `INTERNAL_ERROR`) emits exactly one `installation_events` row (invariant from `b909123`); slog records mirror those outcomes by attribute, not by message wording. - -## 13. Hard-won invariants → tests - -Each row becomes a named Go test in Phase B; numbers are git SHAs documenting the bug that introduced the invariant. - -| SHA | Invariant | Go test | -|---|---|---| -| f264661 | `workflow_job` payloads strip ~70 `*_url`, `license`, `steps[]`; keep `workflow_job.html_url` | `TestTrimWorkflowJobPayload_DropsURLsLicenseSteps` | -| aae3ab3 | `ignored_no_label` events log only `workflow_job.{labels,html_url}` + `repository.full_name` | `TestIgnoredNoLabel_PayloadMinimized` | -| b909123 | every webhook + auth attempt writes one `installation_events` row with `source`, `outcome` | `TestInstallationEvents_RowPerOutcome` | -| 9de4c35 | pod spec has `hostNetwork=true` on every pool | `TestProvisionRunner_UsesHostNetwork` | -| 0028278 / 653a5ba | `/var/lib/k0s` and `/var/lib/docker` are `emptyDir` volumes | `TestProvisionRunner_EmptyDirVolumes` | -| 3286cf6 | `ephemeral-storage` requests/limits only on `scw-em-*` pools | `TestProvisionRunner_DiskLimitsOnlyOnScwEM` | -| 40476b8 | `available_slots <= 0` skip (not `== 0`); two concurrent loops can't push it negative | `TestDemandMatch_SkipsWhenSlotsNonPositive` | -| 4232868 | capacity fetched once per pool per iteration, decremented locally | `TestDemandMatch_CapacityFetchedOncePerPool` | -| 9a9d611 | failed `provision_runner` writes `failure_info.reason=pod_allocation_failure` | `TestProvisionRunner_FailureMarksWorker` | -| b9c25e0 | registered-but-offline runners past `RUNNER_REGISTRATION_TIMEOUT_SECONDS` are killed | `TestPhase3_OfflineRunnerPastTimeoutFails` | -| 83469ab | online runner idle past `RUNNER_PENDING_TIMEOUT_SECONDS` → `runner_idle` failure | `TestPhase3_OnlineIdleRunnerPastTimeoutFails` | -| be1434c | phases 1–5 operate on a snapshot taken at phase start | `TestSyncWorkersState_PhasesIsolated` | -| b081af0 | `/workers` renders both v1 and v2 `failure_info` shapes | `TestRenderWorker_RendersV1AndV2FailureInfo` | -| caf0e8a | `/workers` paginates 50/page with GitHub-style `Link` header | `TestWorkers_PaginationAndLinkHeader` | -| 5c5004f | single-container pod; no dind sidecar, no docker-certs volume, no `DOCKER_*` TLS env | `TestProvisionRunner_NoSidecar` | -| 1055cc8 | `/workers` JSON & HTML field name spelling matches existing UI consumers | `TestWorkers_FieldNames` | -| new | Python ghfe forwards `workflow_job` webhooks to `GO_GHFE_URL` iff `entity_id` is in `GO_GHFE_ROUTING`; non-`workflow_job` events are never routed; staging proxy runs before the routing check. | `test_ghfe_routes_workflow_job_when_entity_in_list` (pytest) | diff --git a/container-go/Dockerfile b/container-go/Dockerfile deleted file mode 100644 index 048bd2c..0000000 --- a/container-go/Dockerfile +++ /dev/null @@ -1,18 +0,0 @@ -FROM golang:1.26-trixie AS build -WORKDIR /src -COPY go.mod go.sum ./ -RUN go mod download -COPY . . -RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o /out/ghfe ./cmd/ghfe \ - && CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o /out/scheduler ./cmd/scheduler - -FROM gcr.io/distroless/base-debian13 AS base -WORKDIR /app - -FROM base AS ghfe -COPY --from=build /out/ghfe /app/ghfe -ENTRYPOINT ["/app/ghfe"] - -FROM base AS scheduler -COPY --from=build /out/scheduler /app/scheduler -ENTRYPOINT ["/app/scheduler"] diff --git a/container-go/README.md b/container-go/README.md deleted file mode 100644 index 4adc8a0..0000000 --- a/container-go/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# container-go - -Go reimplementation of `container/`. Two binaries: - -- `cmd/ghfe` — GitHub webhook frontend (port 8080). -- `cmd/scheduler` — reconciliation loop + read-only dashboards (port 8080). - -`CONTRACT.md` is the frozen behavior reference. The root `README.md` is the -source of truth for the database schema. - -## Routing during the cutover - -`container/ghfe.py` (Python) is the entry point GitHub sends webhooks to. -When `GO_GHFE_ROUTING` lists an entity (by name or id) and `GO_GHFE_URL` is -set, Python forwards that entity's webhooks here. Go ghfe processes them -normally — it does not read either of those env vars. - -The scheduler is a single deployment; cutover from Python → Go is one -image swap. Worker rows in the DB are scheduler-agnostic. - -## Layout - -``` -container-go/ - cmd/ - ghfe/ webhook + setup + trace + health - scheduler/ reconciler (5 phases), demand_match, /usage, /history, /workers - internal/ - constants.go Config, ENTITY_CONFIG, timeouts, image tags - contract.go shared types + DB/GitHub/Kube interfaces - db.go pgx-backed DB implementation - github.go GitHub App auth + REST client - k8s.go client-go pod ops + CollectPodFailureInfo - log.go slog init - testutil/ in-memory fakes shared by cmd/ tests -``` - -## Tests - -``` -go test -race ./... -``` - -`internal/k8s.go` is tested against `k8s.io/client-go/kubernetes/fake`. -`cmd/ghfe` and `cmd/scheduler` use the fakes in `internal/testutil/`. diff --git a/container-go/cmd/ghfe/webhook_test.go b/container-go/cmd/ghfe/webhook_test.go deleted file mode 100644 index 7d179f8..0000000 --- a/container-go/cmd/ghfe/webhook_test.go +++ /dev/null @@ -1,266 +0,0 @@ -package main - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal/testutil" -) - -const webhookSecret = "test-secret" - -func newTestApp() (*App, *testutil.FakeDB) { - db := testutil.NewFakeDB() - cfg := internal.Config{ - Prod: false, - WebhookSecret: webhookSecret, - ImageUbuntu24: "img24", - ImageUbuntu26: "img26", - RunnerPrefix: "rise-riscv-runner-staging-", - } - return &App{Config: cfg, DB: db, GH: &testutil.FakeGH{}}, db -} - -func signedRequest(t *testing.T, body []byte, event, appID string) *http.Request { - t.Helper() - mac := hmac.New(sha256.New, []byte(webhookSecret)) - mac.Write(body) - sig := "sha256=" + hex.EncodeToString(mac.Sum(nil)) - r := httptest.NewRequest("POST", "/", bytes.NewReader(body)) - r.Header.Set(internal.HookSignatureHeader, sig) - r.Header.Set(internal.HookEventHeader, event) - r.Header.Set(internal.HookAppIDHeader, appID) - return r -} - -// TestWebhook_SignatureMismatch verifies an unsigned request 401s and writes -// no installation_events row (signature is checked before the row gate). -func TestWebhook_SignatureMismatch(t *testing.T) { - app, db := newTestApp() - r := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{}`))) - r.Header.Set(internal.HookSignatureHeader, "sha256=bogus") - r.Header.Set(internal.HookEventHeader, "workflow_job") - r.Header.Set(internal.HookAppIDHeader, "1") - w := httptest.NewRecorder() - app.handleWebhook(w, r) - if w.Code != 401 { - t.Fatalf("status=%d want 401", w.Code) - } - if len(db.Events) != 0 { - t.Fatalf("no event row expected, got %d", len(db.Events)) - } -} - -// TestWebhook_PingWritesEventRow covers the b909123 invariant for the ping path. -func TestWebhook_PingWritesEventRow(t *testing.T) { - app, db := newTestApp() - body := []byte(`{"zen":"hi"}`) - w := httptest.NewRecorder() - app.handleWebhook(w, signedRequest(t, body, "ping", "2167633")) - if w.Code != 200 { - t.Fatalf("status=%d", w.Code) - } - if len(db.Events) != 1 { - t.Fatalf("expected 1 event row, got %d", len(db.Events)) - } - if db.Events[0].Row.Event != "ping" || db.Events[0].Row.Outcome != string(internal.OutcomeOK) { - t.Errorf("row=%+v", db.Events[0].Row) - } -} - -// TestWebhook_InstallationEvents covers each of: installation, installation_repositories, -// installation_target, ignored_event. Each must produce exactly one row. -func TestWebhook_InstallationEvents(t *testing.T) { - cases := []struct { - event string - payload map[string]any - want string - }{ - { - event: "installation", - payload: map[string]any{ - "action": "created", - "installation": map[string]any{ - "id": float64(1), - "target_id": float64(99), - "target_type": "Organization", - "account": map[string]any{"login": "org"}, - }, - }, - want: "installation.created", - }, - { - event: "installation_repositories", - payload: map[string]any{ - "action": "added", - "installation": map[string]any{ - "id": float64(1), - "target_id": float64(99), - "target_type": "Organization", - "account": map[string]any{"login": "org"}, - }, - }, - want: "installation_repositories.added", - }, - { - event: "installation_target", - payload: map[string]any{ - "action": "renamed", - "target_type": "Organization", - "account": map[string]any{"id": float64(42), "login": "new"}, - "installation": map[string]any{"id": float64(1)}, - }, - want: "installation_target.renamed", - }, - { - event: "unknown_event", - payload: map[string]any{}, - want: "unknown_event", - }, - } - for _, tc := range cases { - t.Run(tc.event, func(t *testing.T) { - app, db := newTestApp() - body, _ := json.Marshal(tc.payload) - w := httptest.NewRecorder() - app.handleWebhook(w, signedRequest(t, body, tc.event, "2167633")) - if w.Code != 200 { - t.Fatalf("status=%d", w.Code) - } - if len(db.Events) != 1 { - t.Fatalf("expected 1 event row, got %d", len(db.Events)) - } - got := db.Events[0].Row.Event - if got != tc.want { - t.Errorf("event=%q want %q", got, tc.want) - } - }) - } -} - -// TestWebhook_WorkflowJob_IgnoredAction asserts unrecognised actions trip the -// ignored_action outcome. -func TestWebhook_WorkflowJob_IgnoredAction(t *testing.T) { - app, db := newTestApp() - body := mustJSON(map[string]any{ - "action": "waiting", - "installation": map[string]any{"id": float64(1)}, - "repository": map[string]any{ - "id": float64(2), "full_name": "x/y", - "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, - }, - "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, - }) - w := httptest.NewRecorder() - app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) - if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeIgnoredAction) { - t.Fatalf("expected ignored_action, got rows=%+v", db.Events) - } -} - -// TestIgnoredNoLabel_PayloadMinimized verifies aae3ab3: ignored_no_label -// keeps only workflow_job.{labels,html_url} + repository.full_name. -func TestIgnoredNoLabel_PayloadMinimized(t *testing.T) { - app, db := newTestApp() - body := mustJSON(map[string]any{ - "action": "queued", - "installation": map[string]any{"id": float64(1)}, - "repository": map[string]any{ - "id": float64(2), "full_name": "x/y", "url": "drop", - "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x", "url": "drop"}, - }, - "workflow_job": map[string]any{ - "id": float64(7), - "labels": []any{"ubuntu-26.04-riscv"}, - "html_url": "https://example.com", - "url": "drop", - "steps": []any{"a"}, - }, - }) - w := httptest.NewRecorder() - app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) - if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeIgnoredNoLabel) { - t.Fatalf("expected ignored_no_label, rows=%+v", db.Events) - } - var payload map[string]any - if err := json.Unmarshal(db.Events[0].Payload, &payload); err != nil { - t.Fatal(err) - } - job, _ := payload["workflow_job"].(map[string]any) - if job["url"] != nil { - t.Errorf("workflow_job.url leaked") - } - if job["html_url"] != "https://example.com" { - t.Errorf("html_url lost: %v", job["html_url"]) - } - if _, has := job["steps"]; has { - t.Errorf("steps leaked") - } - if _, has := payload["sender"]; has { - t.Errorf("sender leaked") - } - repo, _ := payload["repository"].(map[string]any) - if repo["full_name"] != "x/y" { - t.Errorf("repository.full_name lost: %v", repo["full_name"]) - } - if _, has := repo["url"]; has { - t.Errorf("repository.url leaked") - } -} - -// TestWebhook_QueuedJobStored asserts a valid queued event writes an event -// row with job_stored and persists the job. -func TestWebhook_QueuedJobStored(t *testing.T) { - app, db := newTestApp() - body := mustJSON(map[string]any{ - "action": "queued", - "installation": map[string]any{"id": float64(1)}, - "repository": map[string]any{ - "id": float64(2), "full_name": "x/y", - "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, - }, - "workflow_job": map[string]any{ - "id": float64(7), - "name": "build", - "labels": []any{"ubuntu-24.04-riscv"}, - "html_url": "https://example.com", - }, - }) - w := httptest.NewRecorder() - app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) - if w.Code != 200 { - t.Fatalf("status=%d", w.Code) - } - if len(db.Jobs) != 1 || db.Jobs[0].JobID != 7 { - t.Fatalf("job not stored: %+v", db.Jobs) - } - if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobStored) { - t.Errorf("event row mismatch: %+v", db.Events) - } -} - -// TestWebhook_BodyTooShortForSignature errors out with 400/401 not a panic. -func TestWebhook_MissingHeaders(t *testing.T) { - app, _ := newTestApp() - r := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{}`))) - w := httptest.NewRecorder() - app.handleWebhook(w, r) - if w.Code != 400 { - t.Fatalf("expected 400 on missing event header, got %d", w.Code) - } -} - -func mustJSON(v any) []byte { - b, err := json.Marshal(v) - if err != nil { - panic(err) - } - return b -} diff --git a/container-go/cmd/scheduler/handlers_test.go b/container-go/cmd/scheduler/handlers_test.go deleted file mode 100644 index 1ab47a4..0000000 --- a/container-go/cmd/scheduler/handlers_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package main - -import ( - "net/http/httptest" - "strings" - "testing" - - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" -) - -// TestWorkers_PaginationAndLinkHeader locks invariant caf0e8a: /workers.json -// emits the GitHub-style Link header with rel="next"/"prev". -func TestWorkers_PaginationAndLinkHeader(t *testing.T) { - app, db, _, _ := schedTestApp() - for i := 0; i < 250; i++ { - db.Workers = append(db.Workers, internal.Worker{PodName: "p", Status: "completed"}) - } - - // Reset GetAllWorkers paging using a shim that respects perPage so total > perPage. - db.Workers = db.Workers[:50] // We rely on FakeDB returning total=len(Workers) for any page. - - // Page 0 of 50 with per_page=10 → expect next + last links, no prev/first. - r := httptest.NewRequest("GET", "/workers.json?per_page=10&page=0", nil) - w := httptest.NewRecorder() - app.handleWorkers(w, r) - link := w.Header().Get("Link") - if !strings.Contains(link, `rel="next"`) || !strings.Contains(link, `rel="last"`) { - t.Fatalf("page 0 link header missing next/last: %q", link) - } - if strings.Contains(link, `rel="prev"`) || strings.Contains(link, `rel="first"`) { - t.Fatalf("page 0 link header should not contain prev/first: %q", link) - } - - // Page 2 of 50 with per_page=10 → expect both directions. - r2 := httptest.NewRequest("GET", "/workers.json?per_page=10&page=2", nil) - w2 := httptest.NewRecorder() - app.handleWorkers(w2, r2) - link2 := w2.Header().Get("Link") - for _, rel := range []string{`rel="first"`, `rel="prev"`, `rel="next"`, `rel="last"`} { - if !strings.Contains(link2, rel) { - t.Errorf("middle page link header missing %s: %q", rel, link2) - } - } - - // Page 4 (final) of 50 with per_page=10 → only prev/first. - r3 := httptest.NewRequest("GET", "/workers.json?per_page=10&page=4", nil) - w3 := httptest.NewRecorder() - app.handleWorkers(w3, r3) - link3 := w3.Header().Get("Link") - if !strings.Contains(link3, `rel="prev"`) || !strings.Contains(link3, `rel="first"`) { - t.Errorf("last page link header missing prev/first: %q", link3) - } - if strings.Contains(link3, `rel="next"`) { - t.Errorf("last page link header should not contain next: %q", link3) - } -} - -// TestHandlers_HealthOK is a smoke test for /health. -func TestHandlers_HealthOK(t *testing.T) { - app, _, _, _ := schedTestApp() - r := httptest.NewRequest("GET", "/health", nil) - w := httptest.NewRecorder() - app.handleHealth(w, r) - if w.Code != 200 || strings.TrimSpace(w.Body.String()) != "ok" { - t.Fatalf("health response: %d %q", w.Code, w.Body.String()) - } -} diff --git a/container-go/cmd/scheduler/sync_workers_test.go b/container-go/cmd/scheduler/sync_workers_test.go deleted file mode 100644 index db47357..0000000 --- a/container-go/cmd/scheduler/sync_workers_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "context" - "testing" - "time" - - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" -) - -// pendingWorker builds a Worker + matching Pod ready for the phase-3 tests. -func pendingWorker(name string, runningAt *time.Time) internal.Worker { - return internal.Worker{ - PodName: name, Provider: "github", - EntityID: 1, EntityName: "e", EntityType: "Organization", - InstallationID: 9, K8sPool: "scw-em-rv1", K8sImage: "img", - Status: "running", RunningAt: runningAt, - } -} - -func runningPod(name string) internal.Pod { - now := time.Now().UTC() - return internal.Pod{ - Name: name, Phase: "Running", CreationTime: now.Add(-30 * time.Minute), - Containers: []internal.ContainerStatus{{Name: "runner", Running: true, RunningStarted: &now}}, - } -} - -// TestPhase3_OfflineRunnerPastTimeoutFails covers b9c25e0: a GH runner in -// "offline" status past RUNNER_REGISTRATION_TIMEOUT_SECONDS gets killed. -func TestPhase3_OfflineRunnerPastTimeoutFails(t *testing.T) { - app, db, gh, kube := schedTestApp() - stale := time.Now().Add(-2 * internal.RunnerRegistrationTimeout) - - w := pendingWorker("rise-riscv-runner-staging-abc", &stale) - pod := runningPod(w.PodName) - db.Workers = []internal.Worker{w} - db.WorkerStatus[w.PodName] = "running" - kube.PodsByName[w.PodName] = pod - - gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { - return []internal.GHRunner{{ID: 1, Name: w.PodName, Status: "offline", Busy: false}}, nil - } - - if err := app.syncWorkersState(context.Background()); err != nil { - t.Fatalf("syncWorkersState: %v", err) - } - if len(db.MarkFailed) != 1 { - t.Fatalf("expected MarkWorkerFailed, got %v", db.MarkFailed) - } - if db.MarkFailed[0].Info.Reason != internal.ReasonRunnerNeverRegistered { - t.Errorf("reason=%q want runner_never_registered", db.MarkFailed[0].Info.Reason) - } - if len(kube.KillCalls) != 1 || kube.KillCalls[0] != w.PodName { - t.Errorf("expected KillPod for %s, got %v", w.PodName, kube.KillCalls) - } -} - -// TestPhase3_OnlineIdleRunnerPastTimeoutFails covers 83469ab: a runner -// idle past RUNNER_PENDING_TIMEOUT_SECONDS yields a runner_idle failure. -func TestPhase3_OnlineIdleRunnerPastTimeoutFails(t *testing.T) { - app, db, gh, kube := schedTestApp() - stale := time.Now().Add(-2 * internal.RunnerPendingTimeout) - - w := pendingWorker("rise-riscv-runner-staging-xyz", &stale) - pod := runningPod(w.PodName) - db.Workers = []internal.Worker{w} - db.WorkerStatus[w.PodName] = "running" - kube.PodsByName[w.PodName] = pod - - gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { - return []internal.GHRunner{{ID: 2, Name: w.PodName, Status: "online", Busy: false}}, nil - } - - if err := app.syncWorkersState(context.Background()); err != nil { - t.Fatalf("syncWorkersState: %v", err) - } - if len(db.MarkFailed) != 1 || db.MarkFailed[0].Info.Reason != internal.ReasonRunnerIdle { - t.Fatalf("expected RunnerIdle failure, got %v", db.MarkFailed) - } -} - -// TestSyncWorkersState_PhasesIsolated covers be1434c: an orphan in phase 1 -// produces a `completed` worker, then phase 2 still observes the pod-less -// view and does nothing further. (Phases re-fetch their snapshot.) -func TestSyncWorkersState_PhasesIsolated(t *testing.T) { - app, db, _, _ := schedTestApp() - w := pendingWorker("rise-riscv-runner-staging-orphan", nil) - w.Status = "running" - db.Workers = []internal.Worker{w} - db.WorkerStatus[w.PodName] = "running" - - if err := app.syncWorkersState(context.Background()); err != nil { - t.Fatalf("sync: %v", err) - } - if len(db.MarkOrphaned) != 1 || db.MarkOrphaned[0] != w.PodName { - t.Fatalf("expected exactly one MarkWorkerOrphaned, got %v", db.MarkOrphaned) - } - if len(db.MarkFailed) != 0 { - t.Errorf("phase 1 should not produce a MarkFailed call: %v", db.MarkFailed) - } -} diff --git a/container-go/cmd/scheduler/templates_test.go b/container-go/cmd/scheduler/templates_test.go deleted file mode 100644 index 42999bb..0000000 --- a/container-go/cmd/scheduler/templates_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package main - -import ( - "encoding/json" - "net/http/httptest" - "strings" - "testing" - - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" -) - -// TestRenderWorker_RendersV1AndV2FailureInfo locks b081af0: render_worker -// produces text for both shapes — v1 has no reason line, v2 includes reason, -// pod_reason/message, container exit codes/logs, events. -func TestRenderWorker_RendersV1AndV2FailureInfo(t *testing.T) { - app, _, _, _ := schedTestApp() - - v1 := json.RawMessage(`{"version":1,"message":"old"}`) - w := internal.Worker{ - PodName: "p1", - Status: "failed", - FailureInfo: v1, - } - rendered := app.renderWorker(httptest.NewRequest("GET", "/", nil), w) - joined := strings.Join(rendered, "\n") - if strings.Contains(joined, "Reason:") { - t.Errorf("v1 should not render a Reason: line:\n%s", joined) - } - - v2 := json.RawMessage(`{ - "version": 2, - "reason": "pod_failed", - "pod_reason": "OOMKilled", - "pod_message": "out of memory", - "containers": { - "runner": {"exit_code": 137, "reason": "OOMKilled", "message": "kill -9", "logs": "boom\nboom2"} - }, - "events": [{"type":"Warning","reason":"Failed","message":"oops","last_seen":"2025-01-01"}] - }`) - w2 := internal.Worker{PodName: "p2", Status: "failed", FailureInfo: v2} - rendered2 := app.renderWorker(httptest.NewRequest("GET", "/", nil), w2) - joined2 := strings.Join(rendered2, "\n") - wants := []string{"Reason: pod_failed", "Pod: OOMKilled", "Container runner: exit=137", "boom", "Failed: oops"} - for _, want := range wants { - if !strings.Contains(joined2, want) { - t.Errorf("v2 output missing %q:\n%s", want, joined2) - } - } -} - -// TestWorkers_FieldNames locks invariant 1055cc8 — the JSON serialisation of -// internal.Worker must keep the field names UI consumers expect. -func TestWorkers_FieldNames(t *testing.T) { - w := internal.Worker{PodName: "p", Status: "pending"} - b, _ := json.Marshal(w) - s := string(b) - for _, want := range []string{ - `"pod_name"`, - `"status"`, - `"job_labels"`, - `"k8s_pool"`, - `"k8s_image"`, - `"entity_id"`, - `"entity_name"`, - `"installation_id"`, - `"created_at"`, - } { - if !strings.Contains(s, want) { - t.Errorf("missing field %s in %s", want, s) - } - } -} diff --git a/container-go/internal/k8s_test.go b/container-go/internal/k8s_test.go deleted file mode 100644 index bed9525..0000000 --- a/container-go/internal/k8s_test.go +++ /dev/null @@ -1,191 +0,0 @@ -package internal - -import ( - "context" - "strings" - "testing" - - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/fake" -) - -// fakePod returns the pod the fake clientset stored under runnerName. -func fakePod(t *testing.T, k *K8sClient, runnerName string) *corev1.Pod { - t.Helper() - pod, err := k.cs.CoreV1().Pods("default").Get(context.Background(), runnerName, metav1.GetOptions{}) - if err != nil { - t.Fatalf("get pod: %v", err) - } - return pod -} - -// TestProvisionRunner_UsesHostNetwork asserts pod.spec.hostNetwork=true on every -// pool — invariant 9de4c35. -func TestProvisionRunner_UsesHostNetwork(t *testing.T) { - for _, pool := range []string{"scw-em-rv1", "cloudv10x-jupiter"} { - k := NewK8sClientFromInterface(fake.NewSimpleClientset()) - if err := k.ProvisionRunner(context.Background(), "jit", "runner-"+pool, "img", pool, Entity{ID: 1, Name: "ent"}); err != nil { - t.Fatalf("provision: %v", err) - } - p := fakePod(t, k, "runner-"+pool) - if !p.Spec.HostNetwork { - t.Errorf("pool=%s expected HostNetwork=true", pool) - } - } -} - -// TestProvisionRunner_EmptyDirVolumes asserts the two emptyDir volumes for -// /var/lib/docker and /var/lib/k0s exist on every pool (invariants 0028278/653a5ba). -func TestProvisionRunner_EmptyDirVolumes(t *testing.T) { - k := NewK8sClientFromInterface(fake.NewSimpleClientset()) - if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 1, Name: "ent"}); err != nil { - t.Fatalf("provision: %v", err) - } - p := fakePod(t, k, "r") - type mount struct { - name string - path string - } - want := []mount{{"docker-graph", "/var/lib/docker"}, {"k0s", "/var/lib/k0s"}} - if len(p.Spec.Containers) != 1 { - t.Fatalf("expected single container, got %d", len(p.Spec.Containers)) - } - for _, m := range want { - var foundVolume, foundMount bool - for _, v := range p.Spec.Volumes { - if v.Name == m.name && v.EmptyDir != nil { - foundVolume = true - } - } - for _, vm := range p.Spec.Containers[0].VolumeMounts { - if vm.Name == m.name && vm.MountPath == m.path { - foundMount = true - } - } - if !foundVolume { - t.Errorf("volume %s emptyDir not found", m.name) - } - if !foundMount { - t.Errorf("volumeMount %s at %s not found", m.name, m.path) - } - } -} - -// TestProvisionRunner_DiskLimitsOnlyOnScwEM asserts ephemeral-storage=90Gi -// only on scw-em-* pools (invariant 3286cf6). -func TestProvisionRunner_DiskLimitsOnlyOnScwEM(t *testing.T) { - tests := []struct { - pool string - wantDisk bool - }{ - {"scw-em-rv1", true}, - {"scw-em-something", true}, - {"cloudv10x-jupiter", false}, - } - for _, tc := range tests { - k := NewK8sClientFromInterface(fake.NewSimpleClientset()) - if err := k.ProvisionRunner(context.Background(), "jit", "r-"+tc.pool, "img", tc.pool, Entity{ID: 1, Name: "ent"}); err != nil { - t.Fatalf("[%s] provision: %v", tc.pool, err) - } - p := fakePod(t, k, "r-"+tc.pool) - limits := p.Spec.Containers[0].Resources.Limits - _, has := limits["ephemeral-storage"] - if has != tc.wantDisk { - t.Errorf("pool=%s ephemeral-storage present=%v want=%v", tc.pool, has, tc.wantDisk) - } - if has { - q := limits["ephemeral-storage"] - want := resource.MustParse("90Gi") - if q.Cmp(want) != 0 { - t.Errorf("pool=%s ephemeral-storage=%s want 90Gi", tc.pool, q.String()) - } - } - if _, has := limits["riseproject.com/runner"]; !has { - t.Errorf("pool=%s runner limit missing", tc.pool) - } - } -} - -// TestProvisionRunner_NoSidecar asserts pod has exactly one container, no -// docker-certs volume, no DOCKER_* env (invariant 5c5004f). -func TestProvisionRunner_NoSidecar(t *testing.T) { - k := NewK8sClientFromInterface(fake.NewSimpleClientset()) - if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 1, Name: "ent"}); err != nil { - t.Fatalf("provision: %v", err) - } - p := fakePod(t, k, "r") - if len(p.Spec.Containers) != 1 { - t.Fatalf("expected single container, got %d", len(p.Spec.Containers)) - } - for _, v := range p.Spec.Volumes { - if strings.Contains(v.Name, "docker-cert") { - t.Errorf("docker-certs volume %s leaked into spec", v.Name) - } - } - for _, e := range p.Spec.Containers[0].Env { - if strings.HasPrefix(e.Name, "DOCKER_") { - t.Errorf("DOCKER_* env leaked: %s", e.Name) - } - } - // Required env present - mustHaveEnv(t, p.Spec.Containers[0].Env, "RUNNER_WAIT_FOR_DOCKER_IN_SECONDS", "60") - mustHaveEnv(t, p.Spec.Containers[0].Env, "RUNNER_JITCONFIG", "jit") -} - -func mustHaveEnv(t *testing.T, env []corev1.EnvVar, name, value string) { - t.Helper() - for _, e := range env { - if e.Name == name { - if e.Value != value { - t.Errorf("env %s=%q want %q", name, e.Value, value) - } - return - } - } - t.Errorf("env %s missing", name) -} - -// TestProvisionRunner_Labels asserts the four pod labels are set. -func TestProvisionRunner_Labels(t *testing.T) { - k := NewK8sClientFromInterface(fake.NewSimpleClientset()) - if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 42, Name: "pytorch"}); err != nil { - t.Fatalf("provision: %v", err) - } - p := fakePod(t, k, "r") - want := map[string]string{ - "app": "rise-riscv-runner", - "riseproject.dev/entity_id": "42", - "riseproject.dev/entity_name": "pytorch", - "riseproject.dev/board": "scw-em-rv1", - } - for k, v := range want { - if p.Labels[k] != v { - t.Errorf("label %s=%q want %q", k, p.Labels[k], v) - } - } - if p.Spec.NodeSelector["riseproject.dev/board"] != "scw-em-rv1" { - t.Errorf("nodeSelector board mismatch: %v", p.Spec.NodeSelector) - } -} - -// TestProvisionRunner_TimeoutsAndPrivileged asserts the lesser invariants: -// activeDeadlineSeconds=525600, restartPolicy=Never, container privileged=true. -func TestProvisionRunner_TimeoutsAndPrivileged(t *testing.T) { - k := NewK8sClientFromInterface(fake.NewSimpleClientset()) - if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 1, Name: "ent"}); err != nil { - t.Fatalf("provision: %v", err) - } - p := fakePod(t, k, "r") - if p.Spec.ActiveDeadlineSeconds == nil || *p.Spec.ActiveDeadlineSeconds != 525600 { - t.Errorf("activeDeadlineSeconds=%v want 525600", p.Spec.ActiveDeadlineSeconds) - } - if p.Spec.RestartPolicy != corev1.RestartPolicyNever { - t.Errorf("restartPolicy=%v want Never", p.Spec.RestartPolicy) - } - sc := p.Spec.Containers[0].SecurityContext - if sc == nil || sc.Privileged == nil || !*sc.Privileged { - t.Errorf("container not privileged") - } -} diff --git a/container/.dockerignore b/container/.dockerignore new file mode 100644 index 0000000..cf950c4 --- /dev/null +++ b/container/.dockerignore @@ -0,0 +1,9 @@ +*_test.go +.dockerignore +Dockerfile +README.md +node_modules/ +.serverless/ +serverless.yml +package.json +package-lock.json diff --git a/container/Dockerfile b/container/Dockerfile index 95c5acb..048bd2c 100644 --- a/container/Dockerfile +++ b/container/Dockerfile @@ -1,18 +1,18 @@ -FROM python:3.12-slim-trixie AS base -WORKDIR /app - -RUN pip3 install --no-cache --upgrade pip +FROM golang:1.26-trixie AS build +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o /out/ghfe ./cmd/ghfe \ + && CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o /out/scheduler ./cmd/scheduler -COPY requirements.txt . -RUN pip3 install --no-cache -r requirements.txt --target . - -ENV PATH="/app/bin:${PATH}" -ENV PYTHONPATH="/app" +FROM gcr.io/distroless/base-debian13 AS base +WORKDIR /app -COPY *.py ./ +FROM base AS ghfe +COPY --from=build /out/ghfe /app/ghfe +ENTRYPOINT ["/app/ghfe"] FROM base AS scheduler -CMD [ "python3", "scheduler.py" ] - -FROM base AS ghfe -CMD [ "python3", "ghfe.py" ] +COPY --from=build /out/scheduler /app/scheduler +ENTRYPOINT ["/app/scheduler"] diff --git a/container/README.md b/container/README.md new file mode 100644 index 0000000..660bd45 --- /dev/null +++ b/container/README.md @@ -0,0 +1,37 @@ +# container + +Go implementation of the GitHub webhook frontend and scheduler. Two binaries: + +- `cmd/ghfe` — GitHub webhook frontend (port 8080). +- `cmd/scheduler` — reconciliation loop + read-only dashboards (port 8080). + +The root `README.md` is the source of truth for architecture, the database +schema, and the deployed HTTP routes. + +## Layout + +``` +container/ + cmd/ + ghfe/ webhook + setup + trace + health + scheduler/ reconciler (5 phases), demand_match, /usage, /history, /jobs, /workers + internal/ + constants.go Config, EntityConfigs, timeouts, image tags + contract.go shared types, WebhookOutcome enum, DB/GitHub/Kube interfaces + db.go pgx-backed DB implementation + github.go GitHub App auth + REST client + k8s.go client-go pod ops + CollectPodFailureInfo + log.go slog init + testutil/ in-memory fakes shared by cmd/ tests +``` + +The Go module path is `github.com/riseproject-dev/riscv-runner-app/container`. + +## Tests + +``` +go test -race ./... +``` + +`internal/k8s.go` is tested against `k8s.io/client-go/kubernetes/fake`. +`cmd/ghfe` and `cmd/scheduler` use the fakes in `internal/testutil/`. diff --git a/container-go/cmd/ghfe/main.go b/container/cmd/ghfe/main.go similarity index 96% rename from container-go/cmd/ghfe/main.go rename to container/cmd/ghfe/main.go index 511270c..c543289 100644 --- a/container-go/cmd/ghfe/main.go +++ b/container/cmd/ghfe/main.go @@ -1,5 +1,4 @@ // Command ghfe runs the GitHub webhook frontend. -// See container-go/CONTRACT.md for the full HTTP and webhook surface. package main import ( @@ -12,7 +11,7 @@ import ( "syscall" "time" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) func main() { diff --git a/container/cmd/ghfe/main_test.go b/container/cmd/ghfe/main_test.go new file mode 100644 index 0000000..1144b88 --- /dev/null +++ b/container/cmd/ghfe/main_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal/testutil" +) + +func TestHandleHealth(t *testing.T) { + app := &App{Config: internal.Config{}, DB: testutil.NewFakeDB(), GH: &testutil.FakeGH{}} + r := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + app.handleHealth(w, r) + if w.Code != 200 { + t.Fatalf("status=%d", w.Code) + } + if w.Body.String() != "ok" { + t.Errorf("body=%q", w.Body.String()) + } +} + +func TestRoutes_HealthAndSetupRegistered(t *testing.T) { + app := &App{Config: internal.Config{}, DB: testutil.NewFakeDB(), GH: &testutil.FakeGH{}} + mux := app.Routes() + + for _, c := range []struct { + method, path string + wantStatus int + }{ + {"GET", "/health", 200}, + {"GET", "/setup/org", 400}, // no installation_id → renderMissing + {"GET", "/setup/personal", 400}, // ditto + } { + r := httptest.NewRequest(c.method, c.path, nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + if w.Code != c.wantStatus { + t.Errorf("%s %s: status=%d want %d body=%q", c.method, c.path, w.Code, c.wantStatus, w.Body.String()) + } + } +} + +func TestHTTPError_StatusMapping(t *testing.T) { + cases := []struct { + status int + want string + }{ + {200, "two-hundred"}, + {400, "client"}, + {500, "server"}, + } + for _, c := range cases { + w := httptest.NewRecorder() + httpError(w, c.status, c.want) + if w.Code != c.status { + t.Errorf("status=%d want %d", w.Code, c.status) + } + if w.Body.String() != c.want { + t.Errorf("body=%q want %q", w.Body.String(), c.want) + } + if ct := w.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") { + t.Errorf("content-type=%q", ct) + } + } +} + +func TestWithPerfLog_EmitsOnlyWhenEnabled(t *testing.T) { + app := &App{Config: internal.Config{}, DB: testutil.NewFakeDB(), GH: &testutil.FakeGH{}} + + // Handler that opts in + hOn := app.withPerfLog(func(w http.ResponseWriter, r *http.Request) { + enablePerfLog(r) + w.WriteHeader(202) + _, _ = w.Write([]byte("done")) + }) + r := httptest.NewRequest("POST", "/foo", strings.NewReader("")) + w := httptest.NewRecorder() + hOn(w, r) + if w.Code != 202 || w.Body.String() != "done" { + t.Errorf("status=%d body=%q", w.Code, w.Body.String()) + } + + // Handler that does not opt in (covers the false-branch of the perf-log gate) + hOff := app.withPerfLog(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("silent")) + }) + r = httptest.NewRequest("POST", "/bar", strings.NewReader("")) + w = httptest.NewRecorder() + hOff(w, r) + if w.Body.String() != "silent" { + t.Errorf("body=%q", w.Body.String()) + } + + // /health is the noisy path that withPerfLog suppresses even when opt-in. + hHealth := app.withPerfLog(func(w http.ResponseWriter, r *http.Request) { + enablePerfLog(r) + w.WriteHeader(200) + }) + r = httptest.NewRequest("GET", "/health", nil) + w = httptest.NewRecorder() + hHealth(w, r) + if w.Code != 200 { + t.Fatalf("status=%d", w.Code) + } +} + +func TestEnablePerfLog_NoContextIsNoop(t *testing.T) { + r := httptest.NewRequest("GET", "/x", nil) + enablePerfLog(r) // must not panic when no perfLogger is in context +} diff --git a/container-go/cmd/ghfe/payload.go b/container/cmd/ghfe/payload.go similarity index 98% rename from container-go/cmd/ghfe/payload.go rename to container/cmd/ghfe/payload.go index 30955e4..af30dd0 100644 --- a/container-go/cmd/ghfe/payload.go +++ b/container/cmd/ghfe/payload.go @@ -1,7 +1,7 @@ package main import ( - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // senderDropKeys / repoDropKeys / repoOwnerDropKeys / orgDropKeys / diff --git a/container-go/cmd/ghfe/payload_test.go b/container/cmd/ghfe/payload_test.go similarity index 98% rename from container-go/cmd/ghfe/payload_test.go rename to container/cmd/ghfe/payload_test.go index be696ae..eba1ecc 100644 --- a/container-go/cmd/ghfe/payload_test.go +++ b/container/cmd/ghfe/payload_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // TestTrimWorkflowJobPayload_DropsURLsLicenseSteps locks down invariant diff --git a/container-go/cmd/ghfe/setup.go b/container/cmd/ghfe/setup.go similarity index 98% rename from container-go/cmd/ghfe/setup.go rename to container/cmd/ghfe/setup.go index 590a00f..a126006 100644 --- a/container-go/cmd/ghfe/setup.go +++ b/container/cmd/ghfe/setup.go @@ -9,7 +9,7 @@ import ( "net/http" "strconv" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) //go:embed setup.gohtml diff --git a/container-go/cmd/ghfe/setup.gohtml b/container/cmd/ghfe/setup.gohtml similarity index 100% rename from container-go/cmd/ghfe/setup.gohtml rename to container/cmd/ghfe/setup.gohtml diff --git a/container-go/cmd/ghfe/setup_test.go b/container/cmd/ghfe/setup_test.go similarity index 97% rename from container-go/cmd/ghfe/setup_test.go rename to container/cmd/ghfe/setup_test.go index f241451..650968c 100644 --- a/container-go/cmd/ghfe/setup_test.go +++ b/container/cmd/ghfe/setup_test.go @@ -6,8 +6,8 @@ import ( "strings" "testing" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal/testutil" + "github.com/riseproject-dev/riscv-runner-app/container/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal/testutil" ) func newSetupApp(gh *testutil.FakeGH) *App { diff --git a/container-go/cmd/ghfe/signature.go b/container/cmd/ghfe/signature.go similarity index 100% rename from container-go/cmd/ghfe/signature.go rename to container/cmd/ghfe/signature.go diff --git a/container-go/cmd/ghfe/trace.go b/container/cmd/ghfe/trace.go similarity index 100% rename from container-go/cmd/ghfe/trace.go rename to container/cmd/ghfe/trace.go diff --git a/container/cmd/ghfe/trace_test.go b/container/cmd/ghfe/trace_test.go new file mode 100644 index 0000000..eb46874 --- /dev/null +++ b/container/cmd/ghfe/trace_test.go @@ -0,0 +1,246 @@ +package main + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal/testutil" +) + +func newTraceApp() (*App, *testutil.FakeDB) { + db := testutil.NewFakeDB() + cfg := internal.Config{TraceSecret: "trace-secret"} + return &App{Config: cfg, DB: db, GH: &testutil.FakeGH{}}, db +} + +func authedReq(method, path, body string) *http.Request { + r := httptest.NewRequest(method, path, nil) + r.Header.Set("Authorization", "Bearer trace-secret") + return r +} + +func TestTrace_RequiresAuth(t *testing.T) { + app, _ := newTraceApp() + mux := app.Routes() + for _, path := range []string{"/trace/entity/1", "/trace/installation/1", "/trace/job/1", "/trace/payload/1"} { + r := httptest.NewRequest("GET", path, nil) // no Authorization + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + if w.Code != 401 { + t.Errorf("%s: status=%d want 401", path, w.Code) + } + } +} + +func TestTrace_EntityOK(t *testing.T) { + app, db := newTraceApp() + db.OnGetEventsByEntityID = func(id int64) ([]internal.InstallationEvent, error) { + return []internal.InstallationEvent{{ID: 1, Event: "ping", Outcome: "ok"}}, nil + } + r := authedReq("GET", "/trace/entity/42", "") + r.SetPathValue("entity_id", "42") + w := httptest.NewRecorder() + app.handleTraceEntity(w, r) + if w.Code != 200 { + t.Fatalf("status=%d body=%s", w.Code, w.Body.String()) + } + var out struct { + Events []internal.InstallationEvent `json:"events"` + } + if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil { + t.Fatal(err) + } + if len(out.Events) != 1 || out.Events[0].ID != 1 { + t.Errorf("body=%s", w.Body.String()) + } +} + +func TestTrace_EntityInvalidID(t *testing.T) { + app, _ := newTraceApp() + r := authedReq("GET", "/trace/entity/abc", "") + r.SetPathValue("entity_id", "abc") + w := httptest.NewRecorder() + app.handleTraceEntity(w, r) + if w.Code != 400 { + t.Fatalf("status=%d", w.Code) + } +} + +func TestTrace_EntityDBError(t *testing.T) { + app, db := newTraceApp() + db.OnGetEventsByEntityID = func(int64) ([]internal.InstallationEvent, error) { + return nil, errBoom + } + r := authedReq("GET", "/trace/entity/1", "") + r.SetPathValue("entity_id", "1") + w := httptest.NewRecorder() + app.handleTraceEntity(w, r) + if w.Code != 500 { + t.Fatalf("status=%d", w.Code) + } +} + +func TestTrace_InstallationFlow(t *testing.T) { + app, db := newTraceApp() + db.OnGetEntityIDInstall = func(id int64) (int64, bool, error) { return 99, true, nil } + db.OnGetEventsByEntityID = func(int64) ([]internal.InstallationEvent, error) { + return []internal.InstallationEvent{{ID: 7}}, nil + } + r := authedReq("GET", "/trace/installation/3", "") + r.SetPathValue("installation_id", "3") + w := httptest.NewRecorder() + app.handleTraceInstallation(w, r) + if w.Code != 200 { + t.Fatalf("status=%d body=%s", w.Code, w.Body.String()) + } + + // Not found + db.OnGetEntityIDInstall = func(int64) (int64, bool, error) { return 0, false, nil } + r = authedReq("GET", "/trace/installation/4", "") + r.SetPathValue("installation_id", "4") + w = httptest.NewRecorder() + app.handleTraceInstallation(w, r) + if w.Code != 404 { + t.Fatalf("expected 404 not_found, got %d", w.Code) + } + + // DB error on lookup + db.OnGetEntityIDInstall = func(int64) (int64, bool, error) { return 0, false, errBoom } + r = authedReq("GET", "/trace/installation/5", "") + r.SetPathValue("installation_id", "5") + w = httptest.NewRecorder() + app.handleTraceInstallation(w, r) + if w.Code != 500 { + t.Fatalf("expected 500, got %d", w.Code) + } + + // Invalid id + r = authedReq("GET", "/trace/installation/abc", "") + r.SetPathValue("installation_id", "abc") + w = httptest.NewRecorder() + app.handleTraceInstallation(w, r) + if w.Code != 400 { + t.Fatalf("expected 400, got %d", w.Code) + } + + // Events lookup error after ok install + db.OnGetEntityIDInstall = func(int64) (int64, bool, error) { return 99, true, nil } + db.OnGetEventsByEntityID = func(int64) ([]internal.InstallationEvent, error) { return nil, errBoom } + r = authedReq("GET", "/trace/installation/6", "") + r.SetPathValue("installation_id", "6") + w = httptest.NewRecorder() + app.handleTraceInstallation(w, r) + if w.Code != 500 { + t.Fatalf("expected 500, got %d", w.Code) + } +} + +func TestTrace_JobFlow(t *testing.T) { + app, db := newTraceApp() + db.OnGetEntityIDJob = func(int64) (int64, bool, error) { return 99, true, nil } + db.OnGetEventsByEntityID = func(int64) ([]internal.InstallationEvent, error) { + return nil, nil + } + r := authedReq("GET", "/trace/job/1", "") + r.SetPathValue("job_id", "1") + w := httptest.NewRecorder() + app.handleTraceJob(w, r) + if w.Code != 200 { + t.Fatalf("status=%d body=%s", w.Code, w.Body.String()) + } + + db.OnGetEntityIDJob = func(int64) (int64, bool, error) { return 0, false, nil } + r = authedReq("GET", "/trace/job/2", "") + r.SetPathValue("job_id", "2") + w = httptest.NewRecorder() + app.handleTraceJob(w, r) + if w.Code != 404 { + t.Fatalf("expected 404, got %d", w.Code) + } + + db.OnGetEntityIDJob = func(int64) (int64, bool, error) { return 0, false, errBoom } + r = authedReq("GET", "/trace/job/3", "") + r.SetPathValue("job_id", "3") + w = httptest.NewRecorder() + app.handleTraceJob(w, r) + if w.Code != 500 { + t.Fatalf("expected 500, got %d", w.Code) + } + + r = authedReq("GET", "/trace/job/x", "") + r.SetPathValue("job_id", "x") + w = httptest.NewRecorder() + app.handleTraceJob(w, r) + if w.Code != 400 { + t.Fatalf("expected 400, got %d", w.Code) + } + + // 200 path with downstream events error → 500 + db.OnGetEntityIDJob = func(int64) (int64, bool, error) { return 99, true, nil } + db.OnGetEventsByEntityID = func(int64) ([]internal.InstallationEvent, error) { return nil, errBoom } + r = authedReq("GET", "/trace/job/7", "") + r.SetPathValue("job_id", "7") + w = httptest.NewRecorder() + app.handleTraceJob(w, r) + if w.Code != 500 { + t.Fatalf("expected 500, got %d", w.Code) + } +} + +func TestTrace_PayloadFlow(t *testing.T) { + app, db := newTraceApp() + + // Payload found + db.OnGetPayloadByID = func(int64) ([]byte, error) { return []byte(`{"k":"v"}`), nil } + r := authedReq("GET", "/trace/payload/1", "") + r.SetPathValue("event_id", "1") + w := httptest.NewRecorder() + app.handleTracePayload(w, r) + if w.Code != 200 { + t.Fatalf("status=%d body=%s", w.Code, w.Body.String()) + } + if w.Body.String() != `{"payload":{"k":"v"}}` { + t.Errorf("body=%q", w.Body.String()) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("content-type=%q", ct) + } + + // Payload missing → 404 + db.OnGetPayloadByID = func(int64) ([]byte, error) { return nil, nil } + r = authedReq("GET", "/trace/payload/2", "") + r.SetPathValue("event_id", "2") + w = httptest.NewRecorder() + app.handleTracePayload(w, r) + if w.Code != 404 { + t.Fatalf("expected 404, got %d", w.Code) + } + + // DB error → 500 + db.OnGetPayloadByID = func(int64) ([]byte, error) { return nil, errBoom } + r = authedReq("GET", "/trace/payload/3", "") + r.SetPathValue("event_id", "3") + w = httptest.NewRecorder() + app.handleTracePayload(w, r) + if w.Code != 500 { + t.Fatalf("expected 500, got %d", w.Code) + } + + // Invalid id + r = authedReq("GET", "/trace/payload/abc", "") + r.SetPathValue("event_id", "abc") + w = httptest.NewRecorder() + app.handleTracePayload(w, r) + if w.Code != 400 { + t.Fatalf("expected 400, got %d", w.Code) + } +} + +var errBoom = &stubErr{"boom"} + +type stubErr struct{ s string } + +func (e *stubErr) Error() string { return e.s } diff --git a/container-go/cmd/ghfe/webhook.go b/container/cmd/ghfe/webhook.go similarity index 99% rename from container-go/cmd/ghfe/webhook.go rename to container/cmd/ghfe/webhook.go index f3d9c13..af63036 100644 --- a/container-go/cmd/ghfe/webhook.go +++ b/container/cmd/ghfe/webhook.go @@ -7,7 +7,7 @@ import ( "net/http" "strconv" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // webhook is POST /, the GitHub-App webhook entry point. diff --git a/container/cmd/ghfe/webhook_test.go b/container/cmd/ghfe/webhook_test.go new file mode 100644 index 0000000..83d57d8 --- /dev/null +++ b/container/cmd/ghfe/webhook_test.go @@ -0,0 +1,680 @@ +package main + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal/testutil" +) + +const webhookSecret = "test-secret" + +func newTestApp() (*App, *testutil.FakeDB) { + db := testutil.NewFakeDB() + cfg := internal.Config{ + Prod: false, + WebhookSecret: webhookSecret, + ImageUbuntu24: "img24", + ImageUbuntu26: "img26", + RunnerPrefix: "rise-riscv-runner-staging-", + } + return &App{Config: cfg, DB: db, GH: &testutil.FakeGH{}}, db +} + +func signedRequest(t *testing.T, body []byte, event, appID string) *http.Request { + t.Helper() + mac := hmac.New(sha256.New, []byte(webhookSecret)) + mac.Write(body) + sig := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + r := httptest.NewRequest("POST", "/", bytes.NewReader(body)) + r.Header.Set(internal.HookSignatureHeader, sig) + r.Header.Set(internal.HookEventHeader, event) + r.Header.Set(internal.HookAppIDHeader, appID) + return r +} + +// TestWebhook_SignatureMismatch verifies an unsigned request 401s and writes +// no installation_events row (signature is checked before the row gate). +func TestWebhook_SignatureMismatch(t *testing.T) { + app, db := newTestApp() + r := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{}`))) + r.Header.Set(internal.HookSignatureHeader, "sha256=bogus") + r.Header.Set(internal.HookEventHeader, "workflow_job") + r.Header.Set(internal.HookAppIDHeader, "1") + w := httptest.NewRecorder() + app.handleWebhook(w, r) + if w.Code != 401 { + t.Fatalf("status=%d want 401", w.Code) + } + if len(db.Events) != 0 { + t.Fatalf("no event row expected, got %d", len(db.Events)) + } +} + +// TestWebhook_PingWritesEventRow covers the b909123 invariant for the ping path. +func TestWebhook_PingWritesEventRow(t *testing.T) { + app, db := newTestApp() + body := []byte(`{"zen":"hi"}`) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, "ping", "2167633")) + if w.Code != 200 { + t.Fatalf("status=%d", w.Code) + } + if len(db.Events) != 1 { + t.Fatalf("expected 1 event row, got %d", len(db.Events)) + } + if db.Events[0].Row.Event != "ping" || db.Events[0].Row.Outcome != string(internal.OutcomeOK) { + t.Errorf("row=%+v", db.Events[0].Row) + } +} + +// TestWebhook_InstallationEvents covers each of: installation, installation_repositories, +// installation_target, ignored_event. Each must produce exactly one row, and +// each must record the right WebhookOutcome verbatim into installation_events.outcome. +func TestWebhook_InstallationEvents(t *testing.T) { + cases := []struct { + event string + payload map[string]any + wantEvent string + wantOutcome internal.WebhookOutcome + }{ + { + event: "installation", + payload: map[string]any{ + "action": "created", + "installation": map[string]any{ + "id": float64(1), + "target_id": float64(99), + "target_type": "Organization", + "account": map[string]any{"login": "org"}, + }, + }, + wantEvent: "installation.created", + wantOutcome: internal.OutcomeOK, + }, + { + event: "installation_repositories", + payload: map[string]any{ + "action": "added", + "installation": map[string]any{ + "id": float64(1), + "target_id": float64(99), + "target_type": "Organization", + "account": map[string]any{"login": "org"}, + }, + }, + wantEvent: "installation_repositories.added", + wantOutcome: internal.OutcomeOK, + }, + { + event: "installation_target", + payload: map[string]any{ + "action": "renamed", + "target_type": "Organization", + "account": map[string]any{"id": float64(42), "login": "new"}, + "installation": map[string]any{"id": float64(1)}, + }, + wantEvent: "installation_target.renamed", + wantOutcome: internal.OutcomeOK, + }, + { + event: "unknown_event", + payload: map[string]any{}, + wantEvent: "unknown_event", + wantOutcome: internal.OutcomeIgnoredEvent, + }, + } + for _, tc := range cases { + t.Run(tc.event, func(t *testing.T) { + app, db := newTestApp() + body, _ := json.Marshal(tc.payload) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, tc.event, "2167633")) + if w.Code != 200 { + t.Fatalf("status=%d", w.Code) + } + if len(db.Events) != 1 { + t.Fatalf("expected 1 event row, got %d", len(db.Events)) + } + row := db.Events[0].Row + if row.Event != tc.wantEvent { + t.Errorf("event=%q want %q", row.Event, tc.wantEvent) + } + if row.Outcome != string(tc.wantOutcome) { + t.Errorf("outcome=%q want %q", row.Outcome, tc.wantOutcome) + } + if row.Source != "webhook" { + t.Errorf("source=%q want webhook", row.Source) + } + }) + } +} + +// TestWebhook_WorkflowJob_IgnoredAction asserts unrecognised actions trip the +// ignored_action outcome. +func TestWebhook_WorkflowJob_IgnoredAction(t *testing.T) { + app, db := newTestApp() + body := mustJSON(map[string]any{ + "action": "waiting", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, + }, + "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, + }) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeIgnoredAction) { + t.Fatalf("expected ignored_action, got rows=%+v", db.Events) + } +} + +// TestIgnoredNoLabel_PayloadMinimized verifies aae3ab3: ignored_no_label +// keeps only workflow_job.{labels,html_url} + repository.full_name. +func TestIgnoredNoLabel_PayloadMinimized(t *testing.T) { + app, db := newTestApp() + body := mustJSON(map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", "url": "drop", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x", "url": "drop"}, + }, + "workflow_job": map[string]any{ + "id": float64(7), + "labels": []any{"ubuntu-26.04-riscv"}, + "html_url": "https://example.com", + "url": "drop", + "steps": []any{"a"}, + }, + }) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeIgnoredNoLabel) { + t.Fatalf("expected ignored_no_label, rows=%+v", db.Events) + } + var payload map[string]any + if err := json.Unmarshal(db.Events[0].Payload, &payload); err != nil { + t.Fatal(err) + } + job, _ := payload["workflow_job"].(map[string]any) + if job["url"] != nil { + t.Errorf("workflow_job.url leaked") + } + if job["html_url"] != "https://example.com" { + t.Errorf("html_url lost: %v", job["html_url"]) + } + if _, has := job["steps"]; has { + t.Errorf("steps leaked") + } + if _, has := payload["sender"]; has { + t.Errorf("sender leaked") + } + repo, _ := payload["repository"].(map[string]any) + if repo["full_name"] != "x/y" { + t.Errorf("repository.full_name lost: %v", repo["full_name"]) + } + if _, has := repo["url"]; has { + t.Errorf("repository.url leaked") + } +} + +// TestWebhook_QueuedJobStored asserts a valid queued event writes an event +// row with job_stored and persists the job. +func TestWebhook_QueuedJobStored(t *testing.T) { + app, db := newTestApp() + body := mustJSON(map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, + }, + "workflow_job": map[string]any{ + "id": float64(7), + "name": "build", + "labels": []any{"ubuntu-24.04-riscv"}, + "html_url": "https://example.com", + }, + }) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) + if w.Code != 200 { + t.Fatalf("status=%d", w.Code) + } + if len(db.Jobs) != 1 || db.Jobs[0].JobID != 7 { + t.Fatalf("job not stored: %+v", db.Jobs) + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobStored) { + t.Errorf("event row mismatch: %+v", db.Events) + } + // CONTRACT §2 entity-extraction table: workflow_job uses repository.owner.{id,login} + // and repository.owner.type, and installation.id is captured. + row := db.Events[0].Row + if row.EntityID == nil || *row.EntityID != 99 { + t.Errorf("entity_id=%v want 99", row.EntityID) + } + if row.EntityName == nil || *row.EntityName != "x" { + t.Errorf("entity_name=%v want x", row.EntityName) + } + if row.EntityType == nil || *row.EntityType != "Organization" { + t.Errorf("entity_type=%v want Organization", row.EntityType) + } + if row.InstallationID == nil || *row.InstallationID != 1 { + t.Errorf("installation_id=%v want 1", row.InstallationID) + } +} + +// TestWebhook_BodyTooShortForSignature errors out with 400/401 not a panic. +func TestWebhook_MissingHeaders(t *testing.T) { + app, _ := newTestApp() + r := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{}`))) + w := httptest.NewRecorder() + app.handleWebhook(w, r) + if w.Code != 400 { + t.Fatalf("expected 400 on missing event header, got %d", w.Code) + } +} + +// TestWebhook_BadHeaderAndJSONPaths covers the early-return error branches: +// missing body, bad signature, invalid app-id, malformed JSON. +func TestWebhook_BadHeaderAndJSONPaths(t *testing.T) { + app, _ := newTestApp() + + // Missing event header + r := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{}`))) + w := httptest.NewRecorder() + app.handleWebhook(w, r) + if w.Code != 400 { + t.Errorf("missing event: status=%d", w.Code) + } + + // Missing app-id header + body := []byte(`{"zen":"hi"}`) + r = signedRequest(t, body, "ping", "") + r.Header.Del(internal.HookAppIDHeader) + w = httptest.NewRecorder() + app.handleWebhook(w, r) + if w.Code != 400 { + t.Errorf("missing app id: status=%d", w.Code) + } + + // Non-numeric app id + r = signedRequest(t, body, "ping", "abc") + w = httptest.NewRecorder() + app.handleWebhook(w, r) + if w.Code != 400 { + t.Errorf("bad app id: status=%d", w.Code) + } + + // Invalid JSON body (signature still valid since we sign the literal body) + bad := []byte(`{bogus`) + r = signedRequest(t, bad, "ping", "1") + w = httptest.NewRecorder() + app.handleWebhook(w, r) + if w.Code != 400 { + t.Errorf("invalid json: status=%d", w.Code) + } +} + +// TestWebhook_WorkflowJob_MissingPayloadParts covers the missing-required-field +// branches in handleWorkflowJobEvent. +func TestWebhook_WorkflowJob_MissingPayloadParts(t *testing.T) { + app, _ := newTestApp() + + cases := []struct { + name string + payload map[string]any + want int + }{ + { + "missing workflow_job", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"id": float64(2), "full_name": "x/y", "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}}, + }, + 400, + }, + { + "missing owner", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"id": float64(2), "full_name": "x/y"}, + "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, + }, + 400, + }, + { + "missing job id", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"id": float64(2), "full_name": "x/y", "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}}, + "workflow_job": map[string]any{"labels": []any{"ubuntu-24.04-riscv"}}, + }, + 400, + }, + { + "missing repo full_name", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"id": float64(2), "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}}, + "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, + }, + 400, + }, + { + "missing repo id", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"full_name": "x/y", "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}}, + "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, + }, + 400, + }, + { + "missing owner id", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"id": float64(2), "full_name": "x/y", "owner": map[string]any{"type": "Organization", "login": "x"}}, + "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, + }, + 400, + }, + { + "unsupported entity type", + map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{"id": float64(2), "full_name": "x/y", "owner": map[string]any{"id": float64(99), "type": "Bot", "login": "x"}}, + "workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}, + }, + 400, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, mustJSON(tc.payload), "workflow_job", "2167633")) + if w.Code != tc.want { + t.Errorf("status=%d want %d body=%q", w.Code, tc.want, w.Body.String()) + } + }) + } +} + +// TestWebhook_WorkflowJob_QueuedMissingInstallOrURL covers the +// inner branches after label matching (install id / html url / entity name). +func TestWebhook_WorkflowJob_QueuedMissingInstallOrURL(t *testing.T) { + app, _ := newTestApp() + base := func(over map[string]any) []byte { + p := map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, + }, + "workflow_job": map[string]any{ + "id": float64(7), + "labels": []any{"ubuntu-24.04-riscv"}, + "html_url": "https://example.com", + }, + } + for k, v := range over { + p[k] = v + } + return mustJSON(p) + } + + for _, tc := range []struct { + name string + body []byte + }{ + { + "missing installation id", + base(map[string]any{"installation": map[string]any{}}), + }, + { + "missing html_url", + base(map[string]any{"workflow_job": map[string]any{"id": float64(7), "labels": []any{"ubuntu-24.04-riscv"}}}), + }, + { + "missing entity login", + base(map[string]any{"repository": map[string]any{"id": float64(2), "full_name": "x/y", "owner": map[string]any{"id": float64(99), "type": "Organization"}}}), + }, + } { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, tc.body, "workflow_job", "2167633")) + if w.Code != 400 { + t.Errorf("status=%d want 400 body=%q", w.Code, w.Body.String()) + } + }) + } +} + +// TestWebhook_InProgressAndCompleted exercises the in_progress / completed +// branches incl. job-not-found fallbacks. Each success / not-found case must +// record the corresponding WebhookOutcome into installation_events. +func TestWebhook_InProgressAndCompleted(t *testing.T) { + body := func(action string) []byte { + return mustJSON(map[string]any{ + "action": action, + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, + }, + "workflow_job": map[string]any{ + "id": float64(7), + "labels": []any{"ubuntu-24.04-riscv"}, + "html_url": "https://example.com", + "runner_name": "r-1", + }, + }) + } + + // in_progress + DB has pending → outcome=job_marked_running, body says "marked running". + { + app, db := newTestApp() + db.OnMarkJobRunning = func(int64, string) (string, error) { return "pending", nil } + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body("in_progress"), "workflow_job", "2167633")) + if w.Code != 200 || !strings.Contains(w.Body.String(), "marked running") { + t.Errorf("in_progress ok: status=%d body=%q", w.Code, w.Body.String()) + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobMarkedRunning) { + t.Errorf("expected outcome=job_marked_running, got %+v", db.Events) + } + } + + // in_progress + DB has no row → outcome=job_not_found. + { + app, db := newTestApp() + db.OnMarkJobRunning = func(int64, string) (string, error) { return "", nil } + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body("in_progress"), "workflow_job", "2167633")) + if w.Code != 200 || !strings.Contains(w.Body.String(), "not found") { + t.Errorf("status=%d body=%q", w.Code, w.Body.String()) + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobNotFound) { + t.Errorf("expected outcome=job_not_found, got %+v", db.Events) + } + } + + // in_progress + DB error → 500, no installation_events row (DB write failed + // before recordEvent ran). Internal 5xx errors are not anchored to an outcome. + { + app, db := newTestApp() + db.OnMarkJobRunning = func(int64, string) (string, error) { return "", errBoom } + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body("in_progress"), "workflow_job", "2167633")) + if w.Code != 500 { + t.Errorf("in_progress err: status=%d", w.Code) + } + if len(db.Events) != 0 { + t.Errorf("DB-error path should not record an installation_events row, got %+v", db.Events) + } + } + + // completed + DB had running → outcome=job_marked_completed. + { + app, db := newTestApp() + db.OnMarkJobComplete = func(int64, string) (string, error) { return "running", nil } + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body("completed"), "workflow_job", "2167633")) + if w.Code != 200 || !strings.Contains(w.Body.String(), "completed") { + t.Errorf("completed: status=%d body=%q", w.Code, w.Body.String()) + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobMarkedCompleted) { + t.Errorf("expected outcome=job_marked_completed, got %+v", db.Events) + } + } + + // completed + DB has no row → outcome=job_not_found. + { + app, db := newTestApp() + db.OnMarkJobComplete = func(int64, string) (string, error) { return "", nil } + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body("completed"), "workflow_job", "2167633")) + if w.Code != 200 || !strings.Contains(w.Body.String(), "not found") { + t.Errorf("status=%d body=%q", w.Code, w.Body.String()) + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobNotFound) { + t.Errorf("expected outcome=job_not_found, got %+v", db.Events) + } + } + + // completed + DB error → 500, no row. + { + app, db := newTestApp() + db.OnMarkJobComplete = func(int64, string) (string, error) { return "", errBoom } + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body("completed"), "workflow_job", "2167633")) + if w.Code != 500 { + t.Errorf("completed err: status=%d", w.Code) + } + if len(db.Events) != 0 { + t.Errorf("DB-error path should not record an installation_events row, got %+v", db.Events) + } + } +} + +// TestWebhook_QueuedAddJobError covers the AddJob DB-error branch. +func TestWebhook_QueuedAddJobError(t *testing.T) { + app, db := newTestApp() + db.OnAddJob = func(internal.Job, []string) (bool, error) { return false, errBoom } + body := mustJSON(map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, + }, + "workflow_job": map[string]any{ + "id": float64(7), + "labels": []any{"ubuntu-24.04-riscv"}, + "html_url": "https://example.com", + }, + }) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) + if w.Code != 500 { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } +} + +// TestWebhook_QueuedAlreadyExists covers the stored=false branch. +func TestWebhook_QueuedAlreadyExists(t *testing.T) { + app, db := newTestApp() + db.OnAddJob = func(internal.Job, []string) (bool, error) { return false, nil } + body := mustJSON(map[string]any{ + "action": "queued", + "installation": map[string]any{"id": float64(1)}, + "repository": map[string]any{ + "id": float64(2), "full_name": "x/y", + "owner": map[string]any{"id": float64(99), "type": "Organization", "login": "x"}, + }, + "workflow_job": map[string]any{ + "id": float64(7), + "labels": []any{"ubuntu-24.04-riscv"}, + "html_url": "https://example.com", + }, + }) + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, body, "workflow_job", "2167633")) + if w.Code != 200 { + t.Fatalf("status=%d", w.Code) + } + if !strings.Contains(w.Body.String(), "already exists") { + t.Errorf("body=%q", w.Body.String()) + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeJobAlreadyExists) { + t.Errorf("outcome=%+v", db.Events) + } +} + +// TestWebhook_InstallationEvent_Missing covers the bad-payload branches in +// handleInstallationEvent / handleInstallationTargetEvent. +func TestWebhook_InstallationEvent_Missing(t *testing.T) { + app, _ := newTestApp() + + cases := []struct { + event string + payload map[string]any + }{ + {"installation", map[string]any{"action": "created"}}, // no installation + {"installation", map[string]any{"action": "created", "installation": map[string]any{"id": float64(1)}}}, // no account + {"installation_target", map[string]any{"action": "renamed"}}, // no account + {"installation_target", map[string]any{"action": "renamed", "account": map[string]any{"login": "n", "id": float64(1)}}}, // no installation + } + for _, tc := range cases { + w := httptest.NewRecorder() + app.handleWebhook(w, signedRequest(t, mustJSON(tc.payload), tc.event, "2167633")) + if w.Code != 400 { + t.Errorf("%s: status=%d body=%q", tc.event, w.Code, w.Body.String()) + } + } +} + +// TestAsInt64 covers each numeric input path. +func TestAsInt64(t *testing.T) { + cases := []struct { + in any + want int64 + }{ + {float64(7), 7}, + {int64(8), 8}, + {int(9), 9}, + {"nope", 0}, + {nil, 0}, + } + for _, c := range cases { + if got := asInt64(c.in); got != c.want { + t.Errorf("asInt64(%v)=%d want %d", c.in, got, c.want) + } + } +} + +func mustJSON(v any) []byte { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return b +} diff --git a/container-go/cmd/scheduler/demand_match.go b/container/cmd/scheduler/demand_match.go similarity index 99% rename from container-go/cmd/scheduler/demand_match.go rename to container/cmd/scheduler/demand_match.go index 6fbde72..93e8e16 100644 --- a/container-go/cmd/scheduler/demand_match.go +++ b/container/cmd/scheduler/demand_match.go @@ -9,7 +9,7 @@ import ( "log/slog" "math/big" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // demandMatch iterates pending jobs FIFO, groups by k8s_pool, and provisions diff --git a/container-go/cmd/scheduler/demand_match_test.go b/container/cmd/scheduler/demand_match_test.go similarity index 96% rename from container-go/cmd/scheduler/demand_match_test.go rename to container/cmd/scheduler/demand_match_test.go index 088d42a..0b37223 100644 --- a/container-go/cmd/scheduler/demand_match_test.go +++ b/container/cmd/scheduler/demand_match_test.go @@ -5,8 +5,8 @@ import ( "errors" "testing" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal/testutil" + "github.com/riseproject-dev/riscv-runner-app/container/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal/testutil" ) // schedTestApp wires a scheduler App with fakes shared across tests. diff --git a/container-go/cmd/scheduler/gh_auth.go b/container/cmd/scheduler/gh_auth.go similarity index 98% rename from container-go/cmd/scheduler/gh_auth.go rename to container/cmd/scheduler/gh_auth.go index 40ec367..3d6a6a1 100644 --- a/container-go/cmd/scheduler/gh_auth.go +++ b/container/cmd/scheduler/gh_auth.go @@ -8,7 +8,7 @@ import ( "log/slog" "strconv" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // authPayload is the JSON body written into installation_events.payload diff --git a/container/cmd/scheduler/gh_auth_test.go b/container/cmd/scheduler/gh_auth_test.go new file mode 100644 index 0000000..fc58d95 --- /dev/null +++ b/container/cmd/scheduler/gh_auth_test.go @@ -0,0 +1,143 @@ +package main + +import ( + "context" + "encoding/json" + "testing" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" +) + +func TestI64s(t *testing.T) { + if got := i64s(42); got != "42" { + t.Errorf("got %q", got) + } +} + +func TestOrgRunnerKey_EntityAndTarget(t *testing.T) { + org := orgRunnerKey{EntityType: internal.EntityOrganization, EntityName: "acme", EntityID: 9, InstallationID: 1} + if got := org.Target(); got != "acme" { + t.Errorf("org target=%q", got) + } + if e := org.Entity(); e.Name != "acme" || e.ID != 9 || e.Type != internal.EntityOrganization { + t.Errorf("org entity=%+v", e) + } + if got := org.String(); got == "" { + t.Errorf("empty stringify") + } + + user := orgRunnerKey{EntityType: internal.EntityUser, EntityName: "luhenry", EntityID: 7, RepoFullName: "luhenry/repo"} + if got := user.Target(); got != "luhenry/repo" { + t.Errorf("user target=%q", got) + } +} + +func TestRunnerKeyForWorker(t *testing.T) { + repo := "luhenry/repo" + w := internal.Worker{ + EntityType: "User", + EntityID: 7, + EntityName: "luhenry", + InstallationID: 1, + RepoFullName: &repo, + } + k := runnerKeyForWorker(w) + if k.EntityType != internal.EntityUser || k.RepoFullName != "luhenry/repo" { + t.Errorf("got %+v", k) + } + + // Org: repo not set + w2 := internal.Worker{EntityType: "Organization", EntityID: 9, EntityName: "acme", InstallationID: 1} + k2 := runnerKeyForWorker(w2) + if k2.RepoFullName != "" { + t.Errorf("org should not carry repo: %+v", k2) + } +} + +func TestGhAuthenticate_Success(t *testing.T) { + app, _, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(instID, appID int64) (string, error) { + if appID != internal.GHAppOrgID { + t.Errorf("expected org app id, got %d", appID) + } + return "tok", nil + } + tok, err := app.ghAuthenticate(context.Background(), 1, + internal.Entity{Type: internal.EntityOrganization, Name: "acme", ID: 9}, authCtx{}) + if err != nil || tok != "tok" { + t.Fatalf("tok=%q err=%v", tok, err) + } +} + +func TestGhAuthenticate_404RecordsRow(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { + return "", &internal.GitHubAPIError{StatusCode: 404, Message: "not found"} + } + _, err := app.ghAuthenticate(context.Background(), 5, + internal.Entity{Type: internal.EntityOrganization, Name: "acme", ID: 9}, + authCtx{RepoFullName: "acme/r", JobID: 7}) + if err == nil { + t.Fatal("expected error returned") + } + if len(db.Events) != 1 { + t.Fatalf("expected one event row, got %d", len(db.Events)) + } + row := db.Events[0].Row + if row.Source != "scheduler" || row.Event != "auth_attempt.404" || row.Outcome != string(internal.OutcomeAuth404) { + t.Errorf("row=%+v", row) + } + // Payload contains repo + job + var p authPayload + if err := json.Unmarshal(db.Events[0].Payload, &p); err != nil { + t.Fatal(err) + } + if p.Repository == nil || p.Repository.FullName != "acme/r" { + t.Errorf("payload repo=%+v", p.Repository) + } + if p.WorkflowJob == nil || p.WorkflowJob.ID != 7 { + t.Errorf("payload job=%+v", p.WorkflowJob) + } +} + +func TestGhAuthenticate_OtherAPIError(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { + return "", &internal.GitHubAPIError{StatusCode: 500, Message: "boom"} + } + _, err := app.ghAuthenticate(context.Background(), 5, + internal.Entity{Type: internal.EntityUser, Name: "luhenry", ID: 7}, authCtx{}) + if err == nil { + t.Fatal("expected error") + } + if len(db.Events) != 1 || db.Events[0].Row.Outcome != string(internal.OutcomeAuthOtherError) { + t.Errorf("row=%+v", db.Events) + } +} + +func TestGhAuthenticate_NonAPIErrorBubblesUp(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "", errBoom } + _, err := app.ghAuthenticate(context.Background(), 5, + internal.Entity{Type: internal.EntityOrganization, Name: "acme", ID: 9}, authCtx{}) + if err == nil { + t.Fatal("expected error") + } + // No installation_events row is written for non-API errors + if len(db.Events) != 0 { + t.Errorf("unexpected rows: %+v", db.Events) + } +} + +func TestGhAuthenticate_AddEventErrorSwallowed(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { + return "", &internal.GitHubAPIError{StatusCode: 404, Message: "n"} + } + db.OnAddEvent = func(internal.InstallationEvent, []byte) (int64, error) { return 0, errBoom } + _, err := app.ghAuthenticate(context.Background(), 1, + internal.Entity{Type: internal.EntityOrganization, Name: "acme", ID: 9}, authCtx{}) + if err == nil { + t.Fatal("expected the 404 error to bubble out regardless of DB failure") + } +} diff --git a/container-go/cmd/scheduler/handlers.go b/container/cmd/scheduler/handlers.go similarity index 99% rename from container-go/cmd/scheduler/handlers.go rename to container/cmd/scheduler/handlers.go index 6f1eae3..453fc58 100644 --- a/container-go/cmd/scheduler/handlers.go +++ b/container/cmd/scheduler/handlers.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) func (a *App) Routes() *http.ServeMux { diff --git a/container/cmd/scheduler/handlers_test.go b/container/cmd/scheduler/handlers_test.go new file mode 100644 index 0000000..2a33496 --- /dev/null +++ b/container/cmd/scheduler/handlers_test.go @@ -0,0 +1,306 @@ +package main + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" +) + +// TestWorkers_PaginationAndLinkHeader locks invariant caf0e8a: /workers.json +// emits the GitHub-style Link header with rel="next"/"prev". +func TestWorkers_PaginationAndLinkHeader(t *testing.T) { + app, db, _, _ := schedTestApp() + for i := 0; i < 250; i++ { + db.Workers = append(db.Workers, internal.Worker{PodName: "p", Status: "completed"}) + } + + // Reset GetAllWorkers paging using a shim that respects perPage so total > perPage. + db.Workers = db.Workers[:50] // We rely on FakeDB returning total=len(Workers) for any page. + + // Page 0 of 50 with per_page=10 → expect next + last links, no prev/first. + r := httptest.NewRequest("GET", "/workers.json?per_page=10&page=0", nil) + w := httptest.NewRecorder() + app.handleWorkers(w, r) + link := w.Header().Get("Link") + if !strings.Contains(link, `rel="next"`) || !strings.Contains(link, `rel="last"`) { + t.Fatalf("page 0 link header missing next/last: %q", link) + } + if strings.Contains(link, `rel="prev"`) || strings.Contains(link, `rel="first"`) { + t.Fatalf("page 0 link header should not contain prev/first: %q", link) + } + + // Page 2 of 50 with per_page=10 → expect both directions. + r2 := httptest.NewRequest("GET", "/workers.json?per_page=10&page=2", nil) + w2 := httptest.NewRecorder() + app.handleWorkers(w2, r2) + link2 := w2.Header().Get("Link") + for _, rel := range []string{`rel="first"`, `rel="prev"`, `rel="next"`, `rel="last"`} { + if !strings.Contains(link2, rel) { + t.Errorf("middle page link header missing %s: %q", rel, link2) + } + } + + // Page 4 (final) of 50 with per_page=10 → only prev/first. + r3 := httptest.NewRequest("GET", "/workers.json?per_page=10&page=4", nil) + w3 := httptest.NewRecorder() + app.handleWorkers(w3, r3) + link3 := w3.Header().Get("Link") + if !strings.Contains(link3, `rel="prev"`) || !strings.Contains(link3, `rel="first"`) { + t.Errorf("last page link header missing prev/first: %q", link3) + } + if strings.Contains(link3, `rel="next"`) { + t.Errorf("last page link header should not contain next: %q", link3) + } +} + +// TestHandlers_HealthOK is a smoke test for /health. +func TestHandlers_HealthOK(t *testing.T) { + app, _, _, _ := schedTestApp() + r := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + app.handleHealth(w, r) + if w.Code != 200 || strings.TrimSpace(w.Body.String()) != "ok" { + t.Fatalf("health response: %d %q", w.Code, w.Body.String()) + } +} + +// TestRoutes_AllPathsServed asserts every scheduler route is wired. +func TestRoutes_AllPathsServed(t *testing.T) { + app, _, _, _ := schedTestApp() + mux := app.Routes() + for _, path := range []string{"/health", "/usage", "/usage.json", "/history", "/history.json", "/jobs", "/jobs.json", "/workers", "/workers.json"} { + r := httptest.NewRequest("GET", path, nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + if w.Code >= 500 { + t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String()) + } + } +} + +// TestWantsJSON covers the suffix decision. +func TestWantsJSON(t *testing.T) { + for _, c := range []struct { + path string + want bool + }{ + {"/usage.json", true}, + {"/usage", false}, + {"/workers.json?page=1", true}, // request URL strips query before Path + } { + r := httptest.NewRequest("GET", c.path, nil) + if got := wantsJSON(r); got != c.want { + t.Errorf("%s: got %v want %v", c.path, got, c.want) + } + } +} + +// TestUsage_JSONReturnsActiveJobsAndWorkers covers the JSON branch. +func TestUsage_JSONReturnsActiveJobsAndWorkers(t *testing.T) { + app, db, _, _ := schedTestApp() + db.Jobs = []internal.Job{{JobID: 1, EntityID: 9, EntityName: "acme", EntityType: "Organization", JobLabels: []byte(`["x"]`), K8sPool: "scw"}} + db.Workers = []internal.Worker{{PodName: "p", EntityID: 9, EntityName: "acme", EntityType: "Organization", JobLabels: []byte(`["x"]`), K8sPool: "scw"}} + r := httptest.NewRequest("GET", "/usage.json", nil) + w := httptest.NewRecorder() + app.handleUsage(w, r) + if w.Code != 200 { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + var out struct { + Jobs []internal.Job `json:"jobs"` + Workers []internal.Worker `json:"workers"` + } + if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil { + t.Fatal(err) + } + if len(out.Jobs) != 1 || len(out.Workers) != 1 { + t.Errorf("body=%s", w.Body.String()) + } +} + +// TestUsage_HTMLGroupsAndOrdering covers the HTML branch including grouping. +func TestUsage_HTMLGroupsAndOrdering(t *testing.T) { + app, db, _, _ := schedTestApp() + now := time.Now().UTC() + db.Jobs = []internal.Job{ + {JobID: 2, EntityID: 1, EntityName: "a", EntityType: "Organization", JobLabels: []byte(`["x"]`), K8sPool: "p1", CreatedAt: now.Add(-time.Minute)}, + {JobID: 1, EntityID: 1, EntityName: "a", EntityType: "Organization", JobLabels: []byte(`["x"]`), K8sPool: "p1", CreatedAt: now.Add(-2 * time.Minute)}, + } + db.Workers = []internal.Worker{ + {PodName: "p1", EntityID: 1, EntityName: "a", EntityType: "Organization", JobLabels: []byte(`["x"]`), K8sPool: "p1", Status: "running", CreatedAt: now}, + {PodName: "p2", EntityID: 2, EntityName: "b", EntityType: "Organization", JobLabels: []byte(`["y"]`), K8sPool: "p2", Status: "completed", CreatedAt: now}, + } + r := httptest.NewRequest("GET", "/usage", nil) + w := httptest.NewRecorder() + app.handleUsage(w, r) + if w.Code != 200 { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + body := w.Body.String() + for _, want := range []string{"a /", "b /", "Jobs (2):", "Workers (1):", "p1", "p2"} { + if !strings.Contains(body, want) { + t.Errorf("expected %q in:\n%s", want, body) + } + } + // Staging suffix + if !strings.Contains(body, "Staging") { + t.Errorf("expected Staging suffix: %s", body) + } +} + +// TestUsage_NoActive covers the empty branch. +func TestUsage_NoActive(t *testing.T) { + app, _, _, _ := schedTestApp() + r := httptest.NewRequest("GET", "/usage", nil) + w := httptest.NewRecorder() + app.handleUsage(w, r) + if !strings.Contains(w.Body.String(), "No active pools.") { + t.Errorf("body=%q", w.Body.String()) + } +} + +// TestUsage_DBError covers the 500 branch. +func TestUsage_DBError(t *testing.T) { + app, db, _, _ := schedTestApp() + db.OnGetActiveJobsAndWorkers = func() ([]internal.Job, []internal.Worker, error) { return nil, nil, errBoom } + r := httptest.NewRequest("GET", "/usage", nil) + w := httptest.NewRecorder() + app.handleUsage(w, r) + if w.Code != 500 { + t.Errorf("status=%d", w.Code) + } +} + +// TestHistory_RendersJobs covers /history HTML + the no-jobs branch. +func TestHistory_RendersJobs(t *testing.T) { + app, db, _, _ := schedTestApp() + db.Jobs = []internal.Job{{JobID: 7, Status: "completed", EntityName: "acme", RepoFullName: "acme/r", K8sPool: "scw-em-rv1"}} + + r := httptest.NewRequest("GET", "/history", nil) + w := httptest.NewRecorder() + app.handleJobs(w, r) + if w.Code != 200 || !strings.Contains(w.Body.String(), "acme") { + t.Errorf("body=%q", w.Body.String()) + } + + // JSON branch + r = httptest.NewRequest("GET", "/history.json", nil) + w = httptest.NewRecorder() + app.handleJobs(w, r) + var out []internal.Job + if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil { + t.Fatal(err) + } + if len(out) != 1 || out[0].JobID != 7 { + t.Errorf("body=%s", w.Body.String()) + } + + // Empty rendered output → "No jobs found." + db.Jobs = nil + r = httptest.NewRequest("GET", "/history", nil) + w = httptest.NewRecorder() + app.handleJobs(w, r) + if !strings.Contains(w.Body.String(), "No jobs found.") { + t.Errorf("body=%q", w.Body.String()) + } + + // DB error → 500 + db.OnGetAllJobs = func(string, string, int, int) ([]internal.Job, int, error) { return nil, 0, errBoom } + r = httptest.NewRequest("GET", "/jobs", nil) + w = httptest.NewRecorder() + app.handleJobs(w, r) + if w.Code != 500 { + t.Errorf("status=%d", w.Code) + } +} + +// TestWorkers_EmptyMessage covers the no-rows branch. +func TestWorkers_EmptyMessage(t *testing.T) { + app, _, _, _ := schedTestApp() + r := httptest.NewRequest("GET", "/workers", nil) + w := httptest.NewRecorder() + app.handleWorkers(w, r) + if !strings.Contains(w.Body.String(), "No workers found.") { + t.Errorf("body=%q", w.Body.String()) + } +} + +// TestWorkers_DBError covers /workers 500. +func TestWorkers_DBError(t *testing.T) { + app, db, _, _ := schedTestApp() + db.OnGetAllWorkers = func(string, string, int, int) ([]internal.Worker, int, error) { return nil, 0, errBoom } + r := httptest.NewRequest("GET", "/workers", nil) + w := httptest.NewRecorder() + app.handleWorkers(w, r) + if w.Code != 500 { + t.Errorf("status=%d", w.Code) + } +} + +// TestParsePageParams covers each error branch. +func TestParsePageParams(t *testing.T) { + cases := []struct { + q string + wantErr bool + }{ + {"", false}, + {"start=2024-01-01&end=2024-01-02&page=2&per_page=20", false}, + {"start=-7d", false}, + {"start=garbage", true}, + {"end=garbage", true}, + {"page=-1", true}, + {"page=abc", true}, + {"per_page=0", true}, + {"per_page=xx", true}, + {"start=-Xd", true}, // bad int in -Xd + } + for _, c := range cases { + r := httptest.NewRequest("GET", "/jobs?"+c.q, nil) + _, _, _, _, err := parsePageParams(r) + if (err != nil) != c.wantErr { + t.Errorf("%q: err=%v wantErr=%v", c.q, err, c.wantErr) + } + } +} + +// TestJobs_InvalidParam triggers the 400 path through the handler. +func TestJobs_InvalidParam(t *testing.T) { + app, _, _, _ := schedTestApp() + r := httptest.NewRequest("GET", "/jobs?start=junk", nil) + w := httptest.NewRecorder() + app.handleJobs(w, r) + if w.Code != 400 { + t.Errorf("status=%d", w.Code) + } +} + +func TestWorkers_InvalidParam(t *testing.T) { + app, _, _, _ := schedTestApp() + r := httptest.NewRequest("GET", "/workers?page=abc", nil) + w := httptest.NewRecorder() + app.handleWorkers(w, r) + if w.Code != 400 { + t.Errorf("status=%d", w.Code) + } +} + +// TestWritePreProdSuffix covers the Prod branch of writePre. +func TestWritePreProdSuffix(t *testing.T) { + app, _, _, _ := schedTestApp() + app.Config.Prod = true + w := httptest.NewRecorder() + app.writePre(w, "Title", []string{"line1"}) + if !strings.Contains(w.Body.String(), "Prod") { + t.Errorf("body=%q", w.Body.String()) + } +} + +var errBoom = &stubErr{"boom"} + +type stubErr struct{ s string } + +func (e *stubErr) Error() string { return e.s } diff --git a/container-go/cmd/scheduler/main.go b/container/cmd/scheduler/main.go similarity index 94% rename from container-go/cmd/scheduler/main.go rename to container/cmd/scheduler/main.go index 107b79a..0e28a4f 100644 --- a/container-go/cmd/scheduler/main.go +++ b/container/cmd/scheduler/main.go @@ -1,5 +1,5 @@ // Command scheduler runs the reconciliation loop + the read-only HTTP -// dashboards (/usage, /history, /jobs, /workers). See container-go/CONTRACT.md. +// dashboards (/usage, /history, /jobs, /workers). package main import ( @@ -13,7 +13,7 @@ import ( "syscall" "time" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) func main() { diff --git a/container-go/cmd/scheduler/sync_jobs.go b/container/cmd/scheduler/sync_jobs.go similarity index 97% rename from container-go/cmd/scheduler/sync_jobs.go rename to container/cmd/scheduler/sync_jobs.go index 37a9caa..16e36d7 100644 --- a/container-go/cmd/scheduler/sync_jobs.go +++ b/container/cmd/scheduler/sync_jobs.go @@ -6,7 +6,7 @@ import ( "fmt" "log/slog" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // syncJobsState converges each active job's DB row with GitHub's view of it. diff --git a/container/cmd/scheduler/sync_jobs_test.go b/container/cmd/scheduler/sync_jobs_test.go new file mode 100644 index 0000000..c17efe1 --- /dev/null +++ b/container/cmd/scheduler/sync_jobs_test.go @@ -0,0 +1,184 @@ +package main + +import ( + "context" + "testing" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" +) + +// TestSyncJobsState_GetActiveError covers the early error return. +func TestSyncJobsState_GetActiveError(t *testing.T) { + app, db, _, _ := schedTestApp() + db.OnGetActiveJobs = func() ([]internal.Job, error) { return nil, errBoom } + if err := app.syncJobsState(context.Background()); err == nil { + t.Fatal("expected error") + } +} + +// TestSyncJobsState_SkipsJobMissingRepo covers the early-return on no RepoFullName. +func TestSyncJobsState_SkipsJobMissingRepo(t *testing.T) { + app, db, gh, _ := schedTestApp() + called := false + gh.OnAuthenticateApp = func(int64, int64) (string, error) { called = true; return "t", nil } + db.Jobs = []internal.Job{{JobID: 1, EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if called { + t.Errorf("AuthenticateApp should not be called when RepoFullName is empty") + } +} + +// TestSyncJobsState_InvalidEntityType skips the job without auth. +func TestSyncJobsState_InvalidEntityType(t *testing.T) { + app, db, gh, _ := schedTestApp() + called := false + gh.OnAuthenticateApp = func(int64, int64) (string, error) { called = true; return "t", nil } + db.Jobs = []internal.Job{{JobID: 1, RepoFullName: "a/r", EntityName: "a", EntityType: "Bot"}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if called { + t.Errorf("auth should not be called for invalid entity type") + } +} + +// TestSyncOneJob_InstallationNotFoundMarksFailed covers the install 404 path. +func TestSyncOneJob_InstallationNotFoundMarksFailed(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { + return "", &internal.GitHubAPIError{StatusCode: 404, Message: "not found"} + } + markCalled := false + db.OnMarkJobFailed = func(id int64, info internal.FailureInfo) (string, error) { + markCalled = true + return "pending", nil + } + db.Jobs = []internal.Job{{JobID: 1, RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if !markCalled { + t.Errorf("expected MarkJobFailed to be called") + } +} + +// TestSyncOneJob_NonAPIAuthErrorLogged covers the non-404 auth error. +func TestSyncOneJob_NonAPIAuthErrorLogged(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "", errBoom } + db.Jobs = []internal.Job{{JobID: 1, RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 0 { + t.Errorf("non-API error should not mark job failed") + } +} + +// TestSyncOneJob_JobNotFoundMarksFailed covers GetJobInfo 404. +func TestSyncOneJob_JobNotFoundMarksFailed(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "tok", nil } + gh.OnGetJobInfo = func(string, string, int64) (internal.GHJob, error) { + return internal.GHJob{}, &internal.GitHubAPIError{StatusCode: 404, Message: "n"} + } + called := false + db.OnMarkJobFailed = func(int64, internal.FailureInfo) (string, error) { called = true; return "pending", nil } + db.Jobs = []internal.Job{{JobID: 1, RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if !called { + t.Errorf("expected MarkJobFailed on GetJobInfo 404") + } +} + +// TestSyncOneJob_JobInfoNonAPIError covers the non-404 GetJobInfo error. +func TestSyncOneJob_JobInfoNonAPIError(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "tok", nil } + gh.OnGetJobInfo = func(string, string, int64) (internal.GHJob, error) { return internal.GHJob{}, errBoom } + db.Jobs = []internal.Job{{JobID: 1, RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 0 { + t.Errorf("non-API error should not mark job failed") + } +} + +// TestSyncOneJob_CompletedMarksCompleted covers the completed GH status path. +func TestSyncOneJob_CompletedMarksCompleted(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "tok", nil } + gh.OnGetJobInfo = func(string, string, int64) (internal.GHJob, error) { + return internal.GHJob{Status: "completed", RunnerName: "r"}, nil + } + called := false + db.OnMarkJobComplete = func(int64, string) (string, error) { called = true; return "pending", nil } + db.Jobs = []internal.Job{{JobID: 1, Status: "pending", RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if !called { + t.Errorf("expected MarkJobCompleted") + } +} + +// TestSyncOneJob_ConclusionImpliesCompleted covers the conclusion-present branch. +func TestSyncOneJob_ConclusionImpliesCompleted(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "tok", nil } + conc := "failure" + gh.OnGetJobInfo = func(string, string, int64) (internal.GHJob, error) { + return internal.GHJob{Status: "in_progress", Conclusion: &conc, RunnerName: "r"}, nil + } + called := false + db.OnMarkJobComplete = func(int64, string) (string, error) { called = true; return "running", nil } + db.Jobs = []internal.Job{{JobID: 1, Status: "running", RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if !called { + t.Errorf("expected MarkJobCompleted for in_progress+conclusion") + } +} + +// TestSyncOneJob_InProgressFromPendingPromotesToRunning covers MarkJobRunning. +func TestSyncOneJob_InProgressFromPendingPromotesToRunning(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "tok", nil } + gh.OnGetJobInfo = func(string, string, int64) (internal.GHJob, error) { + return internal.GHJob{Status: "in_progress", RunnerName: "r"}, nil + } + called := false + db.OnMarkJobRunning = func(int64, string) (string, error) { called = true; return "pending", nil } + db.Jobs = []internal.Job{{JobID: 1, Status: "pending", RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if !called { + t.Errorf("expected MarkJobRunning") + } +} + +// TestSyncOneJob_InProgressFromRunningIsNoop covers the no-op branch when DB +// already has running. +func TestSyncOneJob_InProgressFromRunningIsNoop(t *testing.T) { + app, db, gh, _ := schedTestApp() + gh.OnAuthenticateApp = func(int64, int64) (string, error) { return "tok", nil } + gh.OnGetJobInfo = func(string, string, int64) (internal.GHJob, error) { + return internal.GHJob{Status: "in_progress", RunnerName: "r"}, nil + } + called := false + db.OnMarkJobRunning = func(int64, string) (string, error) { called = true; return "running", nil } + db.Jobs = []internal.Job{{JobID: 1, Status: "running", RepoFullName: "a/r", EntityName: "a", EntityType: "Organization", InstallationID: 9}} + if err := app.syncJobsState(context.Background()); err != nil { + t.Fatal(err) + } + if called { + t.Errorf("MarkJobRunning should be a no-op when status is already running") + } +} diff --git a/container-go/cmd/scheduler/sync_workers.go b/container/cmd/scheduler/sync_workers.go similarity index 99% rename from container-go/cmd/scheduler/sync_workers.go rename to container/cmd/scheduler/sync_workers.go index 6e05d03..9030a04 100644 --- a/container-go/cmd/scheduler/sync_workers.go +++ b/container/cmd/scheduler/sync_workers.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // syncWorkersState runs the 5 reconciliation phases. The whole call sits diff --git a/container/cmd/scheduler/sync_workers_test.go b/container/cmd/scheduler/sync_workers_test.go new file mode 100644 index 0000000..1d625f2 --- /dev/null +++ b/container/cmd/scheduler/sync_workers_test.go @@ -0,0 +1,429 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" +) + +// pendingWorker builds a Worker + matching Pod ready for the phase-3 tests. +func pendingWorker(name string, runningAt *time.Time) internal.Worker { + return internal.Worker{ + PodName: name, Provider: "github", + EntityID: 1, EntityName: "e", EntityType: "Organization", + InstallationID: 9, K8sPool: "scw-em-rv1", K8sImage: "img", + Status: "running", RunningAt: runningAt, + } +} + +func runningPod(name string) internal.Pod { + now := time.Now().UTC() + return internal.Pod{ + Name: name, Phase: "Running", CreationTime: now.Add(-30 * time.Minute), + Containers: []internal.ContainerStatus{{Name: "runner", Running: true, RunningStarted: &now}}, + } +} + +// TestPhase3_OfflineRunnerPastTimeoutFails covers b9c25e0: a GH runner in +// "offline" status past RUNNER_REGISTRATION_TIMEOUT_SECONDS gets killed. +func TestPhase3_OfflineRunnerPastTimeoutFails(t *testing.T) { + app, db, gh, kube := schedTestApp() + stale := time.Now().Add(-2 * internal.RunnerRegistrationTimeout) + + w := pendingWorker("rise-riscv-runner-staging-abc", &stale) + pod := runningPod(w.PodName) + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + kube.PodsByName[w.PodName] = pod + + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + return []internal.GHRunner{{ID: 1, Name: w.PodName, Status: "offline", Busy: false}}, nil + } + + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatalf("syncWorkersState: %v", err) + } + if len(db.MarkFailed) != 1 { + t.Fatalf("expected MarkWorkerFailed, got %v", db.MarkFailed) + } + if db.MarkFailed[0].Info.Reason != internal.ReasonRunnerNeverRegistered { + t.Errorf("reason=%q want runner_never_registered", db.MarkFailed[0].Info.Reason) + } + if len(kube.KillCalls) != 1 || kube.KillCalls[0] != w.PodName { + t.Errorf("expected KillPod for %s, got %v", w.PodName, kube.KillCalls) + } +} + +// TestPhase3_OnlineIdleRunnerPastTimeoutFails covers 83469ab: a runner +// idle past RUNNER_PENDING_TIMEOUT_SECONDS yields a runner_idle failure. +func TestPhase3_OnlineIdleRunnerPastTimeoutFails(t *testing.T) { + app, db, gh, kube := schedTestApp() + stale := time.Now().Add(-2 * internal.RunnerPendingTimeout) + + w := pendingWorker("rise-riscv-runner-staging-xyz", &stale) + pod := runningPod(w.PodName) + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + kube.PodsByName[w.PodName] = pod + + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + return []internal.GHRunner{{ID: 2, Name: w.PodName, Status: "online", Busy: false}}, nil + } + + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatalf("syncWorkersState: %v", err) + } + if len(db.MarkFailed) != 1 || db.MarkFailed[0].Info.Reason != internal.ReasonRunnerIdle { + t.Fatalf("expected RunnerIdle failure, got %v", db.MarkFailed) + } +} + +// TestPhase2_RunningPhaseTransitionsPendingToRunning covers PodPhaseSync. +func TestPhase2_RunningPhaseTransitionsPendingToRunning(t *testing.T) { + app, db, _, kube := schedTestApp() + w := pendingWorker("rise-riscv-runner-staging-p2", nil) + w.Status = "pending" + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "pending" + kube.PodsByName[w.PodName] = runningPod(w.PodName) + + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkRunning) != 1 { + t.Errorf("expected MarkWorkerRunning, got %v", db.MarkRunning) + } +} + +// TestPhase2_SucceededPodMarksWorkerCompleted covers the Succeeded branch. +func TestPhase2_SucceededPodMarksWorkerCompleted(t *testing.T) { + app, db, _, kube := schedTestApp() + w := pendingWorker("rise-riscv-runner-staging-done", nil) + w.Status = "running" + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + + t0 := time.Now() + pod := internal.Pod{ + Name: w.PodName, Phase: "Succeeded", + Containers: []internal.ContainerStatus{{Name: "runner", Terminated: true, TerminatedAt: &t0}}, + } + kube.PodsByName[w.PodName] = pod + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkComplete) != 1 { + t.Errorf("expected MarkWorkerCompleted, got %v", db.MarkComplete) + } +} + +// TestPhase2_FailedPodMarksWorkerFailed covers the Failed branch. +func TestPhase2_FailedPodMarksWorkerFailed(t *testing.T) { + app, db, _, kube := schedTestApp() + w := pendingWorker("rise-riscv-runner-staging-bad", nil) + w.Status = "running" + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + + pod := internal.Pod{Name: w.PodName, Phase: "Failed"} + kube.PodsByName[w.PodName] = pod + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 1 || db.MarkFailed[0].Info.Reason != internal.ReasonPodFailed { + t.Errorf("expected MarkWorkerFailed(pod_failed), got %v", db.MarkFailed) + } +} + +// TestPhase3_PendingPastTimeoutFailsWithStuckPending covers ReasonPodStuckPending. +func TestPhase3_PendingPastTimeoutFailsWithStuckPending(t *testing.T) { + app, db, _, kube := schedTestApp() + w := pendingWorker("rise-riscv-runner-staging-pending", nil) + w.Status = "pending" + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "pending" + + old := time.Now().Add(-2 * internal.PodPendingTimeout) + pod := internal.Pod{Name: w.PodName, Phase: "Pending", CreationTime: old} + kube.PodsByName[w.PodName] = pod + + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 1 || db.MarkFailed[0].Info.Reason != internal.ReasonPodStuckPending { + t.Errorf("got %v", db.MarkFailed) + } +} + +// TestPhase3_RunningNotGHKnownWithinTimeout is the still-may-register branch. +func TestPhase3_RunningNotGHKnownWithinTimeout(t *testing.T) { + app, db, gh, kube := schedTestApp() + recent := time.Now().Add(-10 * time.Second) + w := pendingWorker("rise-riscv-runner-staging-fresh", &recent) + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + kube.PodsByName[w.PodName] = runningPod(w.PodName) + // No GH runners returned + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + return nil, nil + } + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 0 { + t.Errorf("should not mark failed within registration timeout: %v", db.MarkFailed) + } +} + +// TestPhase3_PodHasRunJobSelfUnregistered: worker has a row in jobs.k8s_pod → +// skip even when GH no longer reports the runner. +func TestPhase3_PodHasRunJobSelfUnregistered(t *testing.T) { + app, db, gh, kube := schedTestApp() + stale := time.Now().Add(-2 * internal.RunnerRegistrationTimeout) + w := pendingWorker("rise-riscv-runner-staging-selfunreg", &stale) + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + db.JobExistsByPod[w.PodName] = true + kube.PodsByName[w.PodName] = runningPod(w.PodName) + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { return nil, nil } + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 0 { + t.Errorf("self-unregistered worker should not be failed: %v", db.MarkFailed) + } +} + +// TestPhase3_OnlineBusyIsHealthy covers the no-op branch. +func TestPhase3_OnlineBusyIsHealthy(t *testing.T) { + app, db, gh, kube := schedTestApp() + stale := time.Now().Add(-2 * internal.RunnerPendingTimeout) + w := pendingWorker("rise-riscv-runner-staging-busy", &stale) + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + kube.PodsByName[w.PodName] = runningPod(w.PodName) + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + return []internal.GHRunner{{ID: 1, Name: w.PodName, Status: "online", Busy: true}}, nil + } + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 0 { + t.Errorf("online+busy is healthy: %v", db.MarkFailed) + } +} + +// TestPhase3_UnknownRunnerStatusStillFails covers the catch-all `running` branch. +func TestPhase3_UnknownRunnerStatusStillFails(t *testing.T) { + app, db, gh, kube := schedTestApp() + stale := time.Now().Add(-2 * internal.RunnerRegistrationTimeout) + w := pendingWorker("rise-riscv-runner-staging-unknown", &stale) + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + kube.PodsByName[w.PodName] = runningPod(w.PodName) + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + return []internal.GHRunner{{ID: 5, Name: w.PodName, Status: "stuck"}}, nil + } + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(db.MarkFailed) != 1 { + t.Errorf("unknown status past timeout should fail: %v", db.MarkFailed) + } +} + +// TestPhase4_GitHubCleanup_DeletesUnknownRunners covers the !known branch. +func TestPhase4_GitHubCleanup_DeletesUnknownRunners(t *testing.T) { + app, db, gh, kube := schedTestApp() + // One healthy worker we'll leave alone + w := pendingWorker("rise-riscv-runner-staging-keep", nil) + w.Status = "completed" + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "completed" + + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + return []internal.GHRunner{ + {ID: 1, Name: w.PodName, Status: "offline"}, // matches completed worker → delete + {ID: 2, Name: "rise-riscv-runner-staging-ghost", Status: "x"}, // no worker row → delete + {ID: 3, Name: "unrelated-runner", Status: "x"}, // wrong prefix → skip + }, nil + } + deletes := []int64{} + gh.OnDeleteRunnerOrg = func(_, _ string, id int64) error { + deletes = append(deletes, id) + return nil + } + + // Force HealthChecks to populate cache: add an active worker to the + // same scope so HealthChecks visits the key. + active := pendingWorker("rise-riscv-runner-staging-active", nil) + active.Status = "running" + db.Workers = append(db.Workers, active) + db.WorkerStatus[active.PodName] = "running" + kube.PodsByName[active.PodName] = runningPod(active.PodName) + + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatal(err) + } + if len(deletes) < 2 { + t.Errorf("expected DeleteRunnerOrg for completed + ghost, got %v", deletes) + } +} + +// TestPhase5_DeleteTerminalPodsPastGrace covers DeleteTerminalPods grace branch. +func TestPhase5_DeleteTerminalPodsPastGrace(t *testing.T) { + app, _, _, kube := schedTestApp() + old := time.Now().Add(-2 * internal.PodDeleteGrace) + kube.PodsByName["stale"] = internal.Pod{ + Name: "stale", Phase: "Succeeded", + Containers: []internal.ContainerStatus{{Name: "runner", Terminated: true, TerminatedAt: &old}}, + } + now := time.Now() + kube.PodsByName["fresh"] = internal.Pod{ + Name: "fresh", Phase: "Succeeded", CreationTime: now, + Containers: []internal.ContainerStatus{{Name: "runner", Terminated: true, TerminatedAt: &now}}, + } + // Phase 5 doesn't require workers, so test it directly. + app.DeleteTerminalPods(context.Background(), kube.PodsByName) + if len(kube.DeleteCalls) != 1 || kube.DeleteCalls[0] != "stale" { + t.Errorf("expected only stale deleted, got %v", kube.DeleteCalls) + } +} + +// TestPhase5_SkipsNonTerminalAndFreshTerminal covers the early-continue branches. +func TestPhase5_SkipsNonTerminalAndFreshTerminal(t *testing.T) { + app, _, _, kube := schedTestApp() + kube.PodsByName["running"] = internal.Pod{Name: "running", Phase: "Running"} + kube.PodsByName["pending"] = internal.Pod{Name: "pending", Phase: "Pending"} + app.DeleteTerminalPods(context.Background(), kube.PodsByName) + if len(kube.DeleteCalls) != 0 { + t.Errorf("non-terminal pods should not be deleted: %v", kube.DeleteCalls) + } +} + +// TestFetchGHRunners_UserScopeFiltersPrefix covers the user/repo branch. +func TestFetchGHRunners_UserScopeFiltersPrefix(t *testing.T) { + app, _, gh, _ := schedTestApp() + gh.OnListRunnersRepo = func(string, string) ([]internal.GHRunner, error) { + return []internal.GHRunner{ + {ID: 1, Name: "rise-riscv-runner-staging-good"}, + {ID: 2, Name: "other-runner"}, + }, nil + } + cache := map[orgRunnerKey]map[string]internal.GHRunner{} + key := orgRunnerKey{EntityType: internal.EntityUser, EntityName: "luhenry", RepoFullName: "luhenry/r"} + got := app.fetchGHRunners(context.Background(), key, "tok", cache) + if len(got) != 1 { + t.Errorf("expected prefix filter, got %v", got) + } +} + +// TestFetchGHRunners_Cached avoids the GH call on repeated lookups. +func TestFetchGHRunners_Cached(t *testing.T) { + app, _, gh, _ := schedTestApp() + calls := 0 + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { + calls++ + return nil, nil + } + cache := map[orgRunnerKey]map[string]internal.GHRunner{} + key := orgRunnerKey{EntityType: internal.EntityOrganization, EntityName: "acme"} + app.fetchGHRunners(context.Background(), key, "tok", cache) + app.fetchGHRunners(context.Background(), key, "tok", cache) + if calls > 1 { + t.Errorf("expected cache hit, calls=%d", calls) + } +} + +// TestFetchGHRunners_EnsureRunnerGroupError covers the org error path. +func TestFetchGHRunners_EnsureRunnerGroupError(t *testing.T) { + app, _, gh, _ := schedTestApp() + gh.OnEnsureRunnerGroup = func(_, _, _ string) (int64, error) { return 0, errBoom } + cache := map[orgRunnerKey]map[string]internal.GHRunner{} + got := app.fetchGHRunners(context.Background(), orgRunnerKey{EntityType: internal.EntityOrganization, EntityName: "acme"}, "tok", cache) + if len(got) != 0 { + t.Errorf("expected empty on error, got %v", got) + } +} + +// TestFetchGHRunners_ListError covers the list error branch (org). +func TestFetchGHRunners_ListError(t *testing.T) { + app, _, gh, _ := schedTestApp() + gh.OnEnsureRunnerGroup = func(_, _, _ string) (int64, error) { return 1, nil } + gh.OnListRunnersOrgGroup = func(_, _ string, _ int64) ([]internal.GHRunner, error) { return nil, errBoom } + cache := map[orgRunnerKey]map[string]internal.GHRunner{} + got := app.fetchGHRunners(context.Background(), orgRunnerKey{EntityType: internal.EntityOrganization, EntityName: "acme"}, "tok", cache) + if len(got) != 0 { + t.Errorf("expected empty on list error, got %v", got) + } +} + +// TestDeleteGHRunner_UserScopeUsesRepoDelete confirms repo-scoped delete path. +func TestDeleteGHRunner_UserScopeUsesRepoDelete(t *testing.T) { + app, _, gh, _ := schedTestApp() + repoDeleted := false + gh.OnDeleteRunnerRepo = func(_, repo string, _ int64) error { + if repo == "luhenry/r" { + repoDeleted = true + } + return nil + } + ok := app.deleteGHRunner(context.Background(), "tok", + orgRunnerKey{EntityType: internal.EntityUser, EntityName: "luhenry", RepoFullName: "luhenry/r"}, + 7, "worker-1") + if !ok || !repoDeleted { + t.Errorf("expected repo-scoped delete ok=%v called=%v", ok, repoDeleted) + } +} + +// TestDeleteGHRunner_FailureReturnsFalse covers the error path. +func TestDeleteGHRunner_FailureReturnsFalse(t *testing.T) { + app, _, gh, _ := schedTestApp() + gh.OnDeleteRunnerOrg = func(_, _ string, _ int64) error { return errBoom } + ok := app.deleteGHRunner(context.Background(), "tok", + orgRunnerKey{EntityType: internal.EntityOrganization, EntityName: "acme"}, 1, "w") + if ok { + t.Error("expected false on failure") + } +} + +// TestFailAndCleanup_GitHubDeleteFailureAborts covers the "abort cleanup" path. +func TestFailAndCleanup_GitHubDeleteFailureAborts(t *testing.T) { + app, db, gh, kube := schedTestApp() + gh.OnDeleteRunnerOrg = func(_, _ string, _ int64) error { return errBoom } + w := pendingWorker("rise-riscv-runner-staging-busy", nil) + pod := runningPod(w.PodName) + kube.PodsByName[w.PodName] = pod + app.failAndCleanup(context.Background(), w, pod, "tok", + orgRunnerKey{EntityType: internal.EntityOrganization, EntityName: "e"}, + internal.GHRunner{ID: 1}, true, internal.ReasonRunnerIdle) + if len(db.MarkFailed) != 0 { + t.Errorf("MarkFailed should not be called when GH delete aborts: %v", db.MarkFailed) + } + if len(kube.KillCalls) != 0 { + t.Errorf("KillPod should not be called: %v", kube.KillCalls) + } +} + +// TestSyncWorkersState_PhasesIsolated covers be1434c: an orphan in phase 1 +// produces a `completed` worker, then phase 2 still observes the pod-less +// view and does nothing further. (Phases re-fetch their snapshot.) +func TestSyncWorkersState_PhasesIsolated(t *testing.T) { + app, db, _, _ := schedTestApp() + w := pendingWorker("rise-riscv-runner-staging-orphan", nil) + w.Status = "running" + db.Workers = []internal.Worker{w} + db.WorkerStatus[w.PodName] = "running" + + if err := app.syncWorkersState(context.Background()); err != nil { + t.Fatalf("sync: %v", err) + } + if len(db.MarkOrphaned) != 1 || db.MarkOrphaned[0] != w.PodName { + t.Fatalf("expected exactly one MarkWorkerOrphaned, got %v", db.MarkOrphaned) + } + if len(db.MarkFailed) != 0 { + t.Errorf("phase 1 should not produce a MarkFailed call: %v", db.MarkFailed) + } +} diff --git a/container-go/cmd/scheduler/templates.go b/container/cmd/scheduler/templates.go similarity index 98% rename from container-go/cmd/scheduler/templates.go rename to container/cmd/scheduler/templates.go index 17fe1c6..612387f 100644 --- a/container-go/cmd/scheduler/templates.go +++ b/container/cmd/scheduler/templates.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) var statusColors = map[string]string{ diff --git a/container/cmd/scheduler/templates_test.go b/container/cmd/scheduler/templates_test.go new file mode 100644 index 0000000..f80a60c --- /dev/null +++ b/container/cmd/scheduler/templates_test.go @@ -0,0 +1,189 @@ +package main + +import ( + "context" + "encoding/json" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/riseproject-dev/riscv-runner-app/container/internal" +) + +// TestRenderWorker_RendersV1AndV2FailureInfo locks b081af0: render_worker +// produces text for both shapes — v1 has no reason line, v2 includes reason, +// pod_reason/message, container exit codes/logs, events. +func TestRenderWorker_RendersV1AndV2FailureInfo(t *testing.T) { + app, _, _, _ := schedTestApp() + + v1 := json.RawMessage(`{"version":1,"message":"old"}`) + w := internal.Worker{ + PodName: "p1", + Status: "failed", + FailureInfo: v1, + } + rendered := app.renderWorker(httptest.NewRequest("GET", "/", nil), w) + joined := strings.Join(rendered, "\n") + if strings.Contains(joined, "Reason:") { + t.Errorf("v1 should not render a Reason: line:\n%s", joined) + } + + v2 := json.RawMessage(`{ + "version": 2, + "reason": "pod_failed", + "pod_reason": "OOMKilled", + "pod_message": "out of memory", + "containers": { + "runner": {"exit_code": 137, "reason": "OOMKilled", "message": "kill -9", "logs": "boom\nboom2"} + }, + "events": [{"type":"Warning","reason":"Failed","message":"oops","last_seen":"2025-01-01"}] + }`) + w2 := internal.Worker{PodName: "p2", Status: "failed", FailureInfo: v2} + rendered2 := app.renderWorker(httptest.NewRequest("GET", "/", nil), w2) + joined2 := strings.Join(rendered2, "\n") + wants := []string{"Reason: pod_failed", "Pod: OOMKilled", "Container runner: exit=137", "boom", "Failed: oops"} + for _, want := range wants { + if !strings.Contains(joined2, want) { + t.Errorf("v2 output missing %q:\n%s", want, joined2) + } + } +} + +// TestRenderJob_AllShapes covers renderJob across with-html-url / without / +// without-k8s-pod variants. +func TestRenderJob_AllShapes(t *testing.T) { + htmlURL := "https://example.com/r" + pod := "rise-pod-1" + j := internal.Job{ + JobID: 1, + Status: "pending", + RepoFullName: "acme/r", + JobLabels: []byte(`["ubuntu-24.04-riscv"]`), + HTMLURL: &htmlURL, + K8sPod: &pod, + CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + } + s := renderJob(j) + for _, want := range []string{`acme/r#1`, "rise-pod-1", "[ubuntu-24.04-riscv]", "[pending"} { + if !strings.Contains(s, want) { + t.Errorf("missing %q in %q", want, s) + } + } + + // Without html url / pod + j2 := internal.Job{JobID: 2, Status: "completed", RepoFullName: "x/y", JobLabels: []byte(`[]`)} + s2 := renderJob(j2) + if strings.Contains(s2, "") { + t.Errorf("should mark missing pod: %q", s2) + } +} + +// TestRenderLiveEvents covers the three branches: error, none, formatted. +func TestRenderLiveEvents(t *testing.T) { + app, _, _, kube := schedTestApp() + + // Error branch + kube.OnGetPodEvents = func(string) ([]internal.PodEvent, error) { return nil, errBoom } + got := app.renderLiveEvents(context.Background(), "p") + if len(got) != 1 || !strings.Contains(got[0], "error fetching") { + t.Errorf("got %v", got) + } + + // No events + kube.OnGetPodEvents = func(string) ([]internal.PodEvent, error) { return nil, nil } + got = app.renderLiveEvents(context.Background(), "p") + if len(got) != 1 || !strings.Contains(got[0], "(none)") { + t.Errorf("got %v", got) + } + + // Formatted with LastSeen + now := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + kube.OnGetPodEvents = func(string) ([]internal.PodEvent, error) { + return []internal.PodEvent{ + {Type: "Warning", Reason: "Pull", Message: "img", LastSeen: &now}, + }, nil + } + got = app.renderLiveEvents(context.Background(), "p") + if len(got) != 1 || !strings.Contains(got[0], "Pull: img") { + t.Errorf("got %v", got) + } + + // FirstSeen fallback + kube.OnGetPodEvents = func(string) ([]internal.PodEvent, error) { + return []internal.PodEvent{ + {Type: "Normal", Reason: "Created", Message: "", FirstSeen: &now}, + }, nil + } + got = app.renderLiveEvents(context.Background(), "p") + if len(got) != 1 || !strings.Contains(got[0], "2025-01-01") { + t.Errorf("got %v", got) + } + + // Both nil → "unknown" + kube.OnGetPodEvents = func(string) ([]internal.PodEvent, error) { + return []internal.PodEvent{{Type: "Normal", Reason: "x"}}, nil + } + got = app.renderLiveEvents(context.Background(), "p") + if !strings.Contains(got[0], "unknown") { + t.Errorf("got %v", got) + } +} + +// TestFormatHelpers covers formatStatus / formatTimestamp / formatLabels. +func TestFormatHelpers(t *testing.T) { + if s := formatStatus("pending"); !strings.Contains(s, "#ccc504") { + t.Errorf("pending color: %q", s) + } + if s := formatStatus("???"); !strings.Contains(s, "#666") { + t.Errorf("default color: %q", s) + } + if formatTimestamp(time.Time{}) != "?" { + t.Errorf("zero timestamp wrong") + } + if formatLabels(json.RawMessage(`null`)) != "" { + t.Errorf("null labels wrong") + } + if got := formatLabels(json.RawMessage(`["a","b"]`)); got != "[a, b]" { + t.Errorf("labels: %q", got) + } +} + +// TestStringOr covers each branch. +func TestStringOr(t *testing.T) { + if stringOr(nil, "d") != "d" { + t.Error("nil default") + } + if stringOr("ok", "d") != "ok" { + t.Error("string passthrough") + } + if stringOr(7, "d") != "7" { + t.Error("fmt fallback") + } +} + +// TestWorkers_FieldNames locks invariant 1055cc8 — the JSON serialisation of +// internal.Worker must keep the field names UI consumers expect. +func TestWorkers_FieldNames(t *testing.T) { + w := internal.Worker{PodName: "p", Status: "pending"} + b, _ := json.Marshal(w) + s := string(b) + for _, want := range []string{ + `"pod_name"`, + `"status"`, + `"job_labels"`, + `"k8s_pool"`, + `"k8s_image"`, + `"entity_id"`, + `"entity_name"`, + `"installation_id"`, + `"created_at"`, + } { + if !strings.Contains(s, want) { + t.Errorf("missing field %s in %s", want, s) + } + } +} diff --git a/container/constants.py b/container/constants.py deleted file mode 100644 index 74e3ef4..0000000 --- a/container/constants.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -import os -from enum import Enum - -class EntityType(str, Enum): - ORGANIZATION = "Organization" - USER = "User" - - -class WebhookOutcome(str, Enum): - """Stored verbatim in installation_events.outcome (TEXT column).""" - OK = "ok" - JOB_STORED = "job_stored" - JOB_ALREADY_EXISTS = "job_already_exists" - JOB_MARKED_RUNNING = "job_marked_running" - JOB_MARKED_COMPLETED = "job_marked_completed" - JOB_NOT_FOUND = "job_not_found" - IGNORED_ACTION = "ignored_action" - IGNORED_NO_LABEL = "ignored_no_label" - IGNORED_EVENT = "ignored_event" - AUTH_404 = "auth_404" - AUTH_OTHER_ERROR = "auth_other_error" - -PROD = os.environ["PROD"].lower() == "true" -PROD_URL = os.environ["PROD_URL"] -STAGING_URL = os.environ["STAGING_URL"] - -K8S_KUBECONFIG = os.environ["K8S_KUBECONFIG"] - -GHAPP_ORG_ID = 2167633 # https://github.com/apps/rise-risc-v-runners -GHAPP_ORG_PRIVATE_KEY = os.environ["GHAPP_ORG_PRIVATE_KEY"] # PEM-encoded private key for the org GitHub App -GHAPP_PERSONAL_ID = 3131217 # https://github.com/apps/rise-risc-v-runners-personal -GHAPP_PERSONAL_PRIVATE_KEY = os.environ["GHAPP_PERSONAL_PRIVATE_KEY"] # PEM-encoded private key for the personal GitHub App -GHAPP_WEBHOOK_SECRET = os.environ["GHAPP_WEBHOOK_SECRET"] # Secret for validating GitHub webhook signatures -TRACE_API_SECRET = os.environ["TRACE_API_SECRET"] # Bearer token gating /trace/* endpoints - -POSTGRES_URL = os.environ["POSTGRES_URL"] # postgresql://user:pass@host:5432/db?sslmode=require -POSTGRES_SCHEMA = "prod" if PROD else "staging" -POSTGRES_MAXCONN = 10 - -RUNNER_GROUP_NAME = f"RISE RISC-V Runners{'' if PROD else " (staging)"}" -RUNNER_NAME_PREFIX = f"rise-riscv-runner{'' if PROD else '-staging'}-" - -RUNNER_REGISTRATION_TIMEOUT_SECONDS = 120 # pod Running but GH never sees runner -RUNNER_PENDING_TIMEOUT_SECONDS = 600 # pod Running but GH never picks up the runner -POD_PENDING_TIMEOUT_SECONDS = 600 # pod stuck Pending (no capacity, image pull, etc.) -POD_DELETE_GRACE_SECONDS = 6 * 60 * 60 # keep terminal pods around so logs remain inspectable - -# gh api orgs/ --jq '.id' -RISEPROJECT_DEV_ORG_ID = 152654596 # github.com/riseproject-dev -PYTORCH_ORG_ID = 21003710 # github.com/pytorch -GGML_ORG_ORG_ID = 134263123 # github.com/ggml-org (for llama.cpp) -# gh api users/ --jq '.id' -LUHENRY_USER_ID = 660779 # github.com/luhenry - -ENTITY_CONFIG = { - RISEPROJECT_DEV_ORG_ID: { - "max_workers": None, - "pre_allocated": 0, - "staging": [ - "riscv-runner-sample-staging", - ], - }, - PYTORCH_ORG_ID: { - "max_workers": 20, - "pre_allocated": 0, - }, - GGML_ORG_ORG_ID: { - "max_workers": 20, - "pre_allocated": 0, - }, - LUHENRY_USER_ID: { - "max_workers": None, - "pre_allocated": 0, - }, -} - -STAGING_ENTITIES = {oid: c["staging"] for oid, c in ENTITY_CONFIG.items() if c.get("staging", False)} - -GO_GHFE_URL = os.environ.get("GO_GHFE_URL", "") -GO_GHFE_ROUTING = { - RISEPROJECT_DEV_ORG_ID, - LUHENRY_USER_ID, -} - -RUNNER_REGISTRY = "rg.fr-par.scw.cloud/funcscwriseriscvrunnerappqdvknz9s" -RUNNER_IMAGE = "riscv-runner" -RUNNER_UBUNTU_24_04_TAG = "ubuntu-24.04-latest" if PROD else "ubuntu-24.04-staging" -RUNNER_UBUNTU_26_04_TAG = "ubuntu-26.04-latest" if PROD else "ubuntu-26.04-staging" - -RUNNER_IMAGE_UBUNTU_24_04 = f"{RUNNER_REGISTRY}/{RUNNER_IMAGE}:{RUNNER_UBUNTU_24_04_TAG}" -RUNNER_IMAGE_UBUNTU_26_04 = f"{RUNNER_REGISTRY}/{RUNNER_IMAGE}:{RUNNER_UBUNTU_26_04_TAG}" diff --git a/container/db.py b/container/db.py deleted file mode 100644 index d405970..0000000 --- a/container/db.py +++ /dev/null @@ -1,865 +0,0 @@ -from __future__ import annotations - -import contextlib -import json -import logging -import select -import time -import threading -from typing import Any, Iterator - -import psycopg2 -import psycopg2.extras -from psycopg2.pool import ThreadedConnectionPool - -from constants import POSTGRES_URL, POSTGRES_SCHEMA, POSTGRES_MAXCONN - -logger = logging.getLogger(__name__) - - -class DuplicateRunnerNameException(Exception): - """Raised when add_worker() detects a pod_name collision.""" - pass - - -# --- Connection management --- -# PostgreSQL connections are 1-query-at-a-time and NOT thread-safe. -# Waitress serves webhooks with 4+ threads, so each needs its own connection. -# ThreadedConnectionPool: minconn=1, maxconn=POSTGRES_MAXCONN. Threads borrow/return connections. -# -# A semaphore gates access so threads block (instead of crashing with PoolError) -# when all connections are in use. -# -# Thread-local caching: `hold_connection()` pins one connection to the current -# thread so nested `_get_conn()` calls share the same transaction — lets -# `sync_workers_state` take `LOCK TABLE` once and have all subsequent -# `mark_worker_*` calls respect it on the same connection. - -_pool: ThreadedConnectionPool | None = None -_pool_semaphore: threading.Semaphore | None = None -_pool_lock = threading.Lock() -_thread_local = threading.local() - - -def _init_pool() -> ThreadedConnectionPool: - global _pool, _pool_semaphore - if _pool is not None: - return _pool - with _pool_lock: - if _pool is not None: - return _pool - _pool = ThreadedConnectionPool( - minconn=1, - maxconn=POSTGRES_MAXCONN, - dsn=POSTGRES_URL, - ) - _pool_semaphore = threading.Semaphore(POSTGRES_MAXCONN) - return _pool - - -class _PoolConnection: - """Context manager that borrows a connection from the pool and returns it. - - - Acquires a semaphore slot before borrowing (blocks if pool is full). - - Sets search_path on every borrowed connection. - - Auto-commits on clean exit, auto-rollbacks on exception. - - Releases the semaphore slot after returning the connection. - - If the current thread has pinned a connection via `hold_connection()`, this - short-circuits: no new connection is borrowed, no commit/rollback happens on - exit (the outer `hold_connection` owns the lifecycle), and exceptions still - propagate so the outer block rolls back. - """ - def __init__(self) -> None: - self.conn = None - self._held = False - - def __enter__(self): - held = getattr(_thread_local, "conn", None) - if held is not None: - self.conn = held - self._held = True - return self.conn - pool = _init_pool() - _pool_semaphore.acquire() - try: - self.conn = pool.getconn() - with self.conn.cursor() as cur: - cur.execute(f"SET search_path TO {POSTGRES_SCHEMA}") - except Exception: - if self.conn is not None: - pool.putconn(self.conn) - self.conn = None - _pool_semaphore.release() - raise - return self.conn - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._held: - # Even if we are in a `db.hold_connection()` block, do the commit/rollback as it would normally - if exc_type is not None: - self.conn.rollback() - else: - self.conn.commit() - return False - if self.conn is not None: - if exc_type is not None: - self.conn.rollback() - else: - self.conn.commit() - _init_pool().putconn(self.conn) - self.conn = None - _pool_semaphore.release() - return False - - -def _get_conn() -> _PoolConnection: - return _PoolConnection() - - -@contextlib.contextmanager -def hold_connection(): - """Pin one pool connection to the current thread for the duration of the block. - - Nested `_get_conn()` calls in the same thread reuse this connection so all - operations share a single transaction. COMMIT on clean exit, ROLLBACK on - exception. Used by `sync_workers_state` to hold `LOCK TABLE workers IN - EXCLUSIVE MODE` across all the `mark_worker_*` calls it makes. - """ - assert getattr(_thread_local, "conn", None) is None, "held connection already active" - pool = _init_pool() - _pool_semaphore.acquire() - conn = None - try: - conn = pool.getconn() - with conn.cursor() as cur: - cur.execute(f"SET search_path TO {POSTGRES_SCHEMA}") - _thread_local.conn = conn - yield conn - finally: - _thread_local.conn = None - if conn is not None: - pool.putconn(conn) - _pool_semaphore.release() - - -# --- Schema bootstrap --- - -def ensure_schema() -> None: - """Create schema, enum type, tables, and indexes if they don't exist. Idempotent. - - Uses a direct connection (not the pool context manager) because DDL - requires autocommit=True, which must be set before any statement runs. - The pool context manager runs SET search_path on enter, which starts a - transaction and prevents setting autocommit afterwards. - """ - pool = _init_pool() - _pool_semaphore.acquire() - conn = pool.getconn() - try: - conn.autocommit = True - with conn.cursor() as cur: - cur.execute(f"CREATE SCHEMA IF NOT EXISTS {POSTGRES_SCHEMA}") - cur.execute(f"SET search_path TO {POSTGRES_SCHEMA}") - - # Create enum types (idempotent via DO blocks) - cur.execute(""" - DO $$ BEGIN - CREATE TYPE status_enum AS ENUM ('pending', 'running', 'completed', 'failed'); - EXCEPTION - WHEN duplicate_object THEN null; - END $$ - """) - cur.execute(""" - DO $$ BEGIN - CREATE TYPE provider_enum AS ENUM ('github', 'gitlab', 'azdo'); - EXCEPTION - WHEN duplicate_object THEN null; - END $$ - """) - cur.execute(""" - DO $$ BEGIN - CREATE TYPE entity_type_enum AS ENUM ('Organization', 'User'); - EXCEPTION - WHEN duplicate_object THEN null; - END $$ - """) - - # Jobs table - cur.execute(""" - CREATE TABLE IF NOT EXISTS jobs ( - job_id BIGINT PRIMARY KEY, - status status_enum NOT NULL DEFAULT 'pending', - failure_info JSONB, - provider provider_enum NOT NULL, - entity_id BIGINT NOT NULL, - entity_name TEXT NOT NULL, - entity_type TEXT NOT NULL, - repo_full_name TEXT NOT NULL, - installation_id BIGINT NOT NULL, - job_labels JSONB NOT NULL DEFAULT '[]', - k8s_pool TEXT NOT NULL, - k8s_image TEXT NOT NULL, - html_url TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now() - ) - """) - - # Workers table - cur.execute(""" - CREATE TABLE IF NOT EXISTS workers ( - pod_name TEXT PRIMARY KEY, - provider provider_enum NOT NULL, - entity_id BIGINT NOT NULL, - entity_name TEXT NOT NULL, - entity_type TEXT NOT NULL, - installation_id BIGINT NOT NULL, - repo_full_name TEXT, - job_labels JSONB NOT NULL DEFAULT '[]', - k8s_pool TEXT NOT NULL, - k8s_image TEXT NOT NULL, - k8s_node TEXT, - status status_enum NOT NULL DEFAULT 'pending', - failure_info JSONB, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - running_at TIMESTAMPTZ, - completed_at TIMESTAMPTZ, - updated_at TIMESTAMPTZ NOT NULL DEFAULT now() - ) - """) - - # Indexes (IF NOT EXISTS for idempotency) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_jobs_active - ON jobs (entity_id, job_labels, created_at) - WHERE status != 'completed' - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_jobs_reconcile - ON jobs (installation_id) - WHERE status != 'completed' - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_jobs_created - ON jobs (created_at DESC) - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_workers_active - ON workers (entity_id, job_labels, k8s_pool) - WHERE status != 'completed' - """) - - # Append-only event log for GitHub App install/uninstall/auth lifecycle. - # See plan: explains why an installation_id can disappear before a job - # is picked up. One row per webhook delivery + one per scheduler auth - # failure. Most context lives in `payload`; only filter/index keys - # are dedicated columns. - cur.execute(""" - CREATE TABLE IF NOT EXISTS installation_events ( - id BIGSERIAL PRIMARY KEY, - source TEXT NOT NULL, - event TEXT NOT NULL, - outcome TEXT NOT NULL, - installation_id BIGINT, - app_id BIGINT, - entity_type entity_type_enum, - entity_id BIGINT, - entity_name TEXT, - payload JSONB NOT NULL, - received_at TIMESTAMPTZ NOT NULL DEFAULT now() - ) - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_install_events_installation - ON installation_events (installation_id, entity_id) - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_install_events_entity - ON installation_events (entity_id, received_at DESC) - """) - - conn.autocommit = False - finally: - pool.putconn(conn) - _pool_semaphore.release() - logger.info("Schema '%s' ensured (tables + indexes)", POSTGRES_SCHEMA) - - -# --- Job operations --- - -def add_job(job_id: int, provider: str, entity_id: int, entity_name: str, entity_type: str | Any, - repo_full_name: str, installation_id: int, labels: list[str], - k8s_pool: str, k8s_image: str, html_url: str) -> bool: - """Store a new job. Returns True if created, False if duplicate.""" - sorted_labels = json.dumps(sorted(labels)) - entity_type_val = entity_type.value if hasattr(entity_type, 'value') else str(entity_type) - now = time.time() - - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - INSERT INTO jobs (job_id, status, provider, entity_id, entity_name, entity_type, - repo_full_name, installation_id, job_labels, k8s_pool, - k8s_image, html_url, created_at, updated_at) - VALUES (%s, 'pending', %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, - to_timestamp(%s), to_timestamp(%s)) - ON CONFLICT (job_id) DO NOTHING - """, (int(job_id), provider, int(entity_id), entity_name, entity_type_val, - repo_full_name, int(installation_id), sorted_labels, k8s_pool, - k8s_image, html_url, now, now)) - created = cur.rowcount > 0 - - if created: - cur.execute(f"NOTIFY {POSTGRES_SCHEMA}_queue_event, %s", (str(job_id),)) - - if created: - logger.info("Stored job %s for entity %s pool %s", job_id, entity_name, k8s_pool) - else: - logger.debug("Job %s already exists, skipping", job_id) - return created - - -def mark_job_running(job_id: int, runner_name: str | None) -> str | None: - """Update job status to running. Returns previous status string or None. - - Only allows the transition: pending -> running. - """ - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - WITH prev AS (SELECT status FROM jobs WHERE job_id = %s) - UPDATE jobs - SET status = 'running', - k8s_pod = COALESCE(k8s_pod, %s), - updated_at = now() - WHERE job_id = %s AND status = 'pending' - RETURNING (SELECT status::text FROM prev) as prev_status - """, (int(job_id), runner_name, int(job_id))) - row = cur.fetchone() - - if row is not None: - logger.info("Job %s status updated to running (was %s)", job_id, row[0]) - return row[0] - - # UPDATE didn't match — either job doesn't exist or is already running/completed - cur.execute("SELECT status::text FROM jobs WHERE job_id = %s", (int(job_id),)) - existing = cur.fetchone() - if existing is None: - logger.debug("Job %s not found in PostgreSQL", job_id) - return None - logger.debug("Job %s not updated to running (current status: %s)", job_id, existing[0]) - return existing[0] - - -def mark_job_completed(job_id: int, runner_name: str | None) -> str | None: - """Update job status to completed. Returns previous status string or None. - - Allows transitions: pending|running -> completed. - """ - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - WITH prev AS (SELECT status FROM jobs WHERE job_id = %s) - UPDATE jobs - SET status = 'completed', - k8s_pod = COALESCE(k8s_pod, %s), - updated_at = now() - WHERE job_id = %s AND (status = 'pending' OR status = 'running') - RETURNING (SELECT status::text FROM prev) as prev_status - """, (int(job_id), runner_name, int(job_id))) - row = cur.fetchone() - - if row is not None: - logger.info("Job %s status updated to completed (was %s)", job_id, row[0]) - return row[0] - - # UPDATE didn't match — either job doesn't exist or is already completed - cur.execute("SELECT status::text FROM jobs WHERE job_id = %s", (int(job_id),)) - existing = cur.fetchone() - if existing is None: - logger.debug("Job %s not found in PostgreSQL", job_id) - return None - return existing[0] - - -def mark_job_failed(job_id: int, failure_info: dict) -> str | None: - """Update job status to failed. Returns previous status string or None. - - Allows transitions: pending|running -> failed. - """ - assert "version" in failure_info and isinstance(failure_info['version'], int), f"failure_info must have a failure_info['version'] parameter and it must be an int" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - WITH prev AS (SELECT status FROM jobs WHERE job_id = %s) - UPDATE jobs SET status = 'failed', failure_info = %s, updated_at = now() - WHERE job_id = %s AND (status = 'pending' OR status = 'running') - RETURNING (SELECT status::text FROM prev) as prev_status - """, (int(job_id), json.dumps(failure_info), int(job_id))) - row = cur.fetchone() - - if row is not None: - logger.info("Job %s status updated to completed (was %s)", job_id, row[0]) - return row[0] - - # UPDATE didn't match — either job doesn't exist or is already completed - cur.execute("SELECT status::text FROM jobs WHERE job_id = %s", (int(job_id),)) - existing = cur.fetchone() - if existing is None: - logger.debug("Job %s not found in PostgreSQL", job_id) - return None - return existing[0] - - -def job_exists_for_pod(pod_name: str) -> bool: - """Return True if any job row has k8s_pod = pod_name.""" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute( - "SELECT 1 FROM jobs WHERE k8s_pod = %s LIMIT 1", - (pod_name,), - ) - return cur.fetchone() is not None - - -# --- Worker operations --- - -def get_pool_demand(entity_id: int, job_labels: list[str]) -> tuple[int, int]: - """Return (job_count, worker_count) for an entity + label set. - - Matches demand and supply by (entity_id, job_labels) rather than (entity_id, k8s_pool). - This fixes the bug where different label sets mapping to the same pool cause stuck workers. - Labels are sorted internally for consistent JSONB equality. - """ - sorted_labels = json.dumps(sorted(job_labels)) - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - SELECT - (SELECT COUNT(*) FROM jobs - WHERE entity_id = %s AND job_labels = %s - AND (status = 'pending' OR status = 'running')) as job_count, - (SELECT COUNT(*) FROM workers - WHERE entity_id = %s AND job_labels = %s - AND (status = 'pending' OR status = 'running')) as worker_count - """, (int(entity_id), sorted_labels, int(entity_id), sorted_labels)) - row = cur.fetchone() - return row[0], row[1] - - -def get_total_workers_for_entity(entity_id: int) -> int: - """Return total worker count across all pools for an entity.""" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - SELECT COUNT(*) FROM workers - WHERE entity_id = %s AND (status = 'pending' OR status = 'running') - """, (int(entity_id),)) - row = cur.fetchone() - return row[0] - - -def get_pending_jobs() -> list[psycopg2.extras.RealDictRow]: - """Return all pending jobs in FIFO order as full row dicts. - - Consumers (demand_match) read fields via `job["job_id"]`, `job["entity_id"]`, - etc., so we return RealDictCursor rows — not raw tuples. - """ - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - cur.execute(""" - SELECT * - FROM jobs - WHERE status = 'pending' - ORDER BY created_at - """) - return cur.fetchall() - - -def add_worker(provider: str, entity_id: int, entity_name: str, entity_type: str, - installation_id: int, repo_full_name: str | None, k8s_pool: str, pod_name: str, - job_labels: list[str], k8s_image: str) -> None: - """Add a worker. Raises DuplicateRunnerNameException on pod_name collision. - - `repo_full_name` is only meaningful for user-scoped runners (personal accounts, - where runners are registered under a specific repo). Pass None for org-scoped runners. - """ - sorted_labels = json.dumps(sorted(job_labels)) - - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - INSERT INTO workers (pod_name, provider, entity_id, entity_name, entity_type, - installation_id, repo_full_name, k8s_pool, job_labels, - k8s_image, status, created_at, updated_at) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, 'pending', now(), now()) - ON CONFLICT (pod_name) DO NOTHING - """, (pod_name, provider, int(entity_id), entity_name, entity_type, - int(installation_id), repo_full_name, k8s_pool, sorted_labels, k8s_image)) - - if cur.rowcount == 0: - raise DuplicateRunnerNameException( - f"Worker pod_name '{pod_name}' already exists") - - logger.debug("Added worker %s to pool %s:%s", pod_name, entity_id, k8s_pool) - - -def mark_worker_running(pod_name: str, k8s_node: str, running_at: Any): - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - UPDATE workers - SET status = 'running', - k8s_node = %s, - running_at = COALESCE(running_at, %s, now()), - updated_at = now() - WHERE pod_name = %s AND status = 'pending' - """, (k8s_node, running_at, pod_name)) - logger.debug("Marked worker %s running", pod_name) - - -def mark_worker_completed(pod_name: str, k8s_node: str, completed_at: Any): - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - UPDATE workers - SET status = 'completed', - k8s_node = COALESCE(k8s_node, %s), - completed_at = COALESCE(completed_at, %s, now()), - updated_at = now() - WHERE pod_name = %s AND (status = 'pending' OR status = 'running') - """, (k8s_node, completed_at, pod_name)) - logger.debug("Marked worker %s completed", pod_name) - - -def mark_worker_failed(pod_name: str, k8s_node: str, failure_info: dict, completed_at: Any) -> None: - """Mark a worker as failed with failure_info and completed_at. - - Allows transitions: pending -> failed, running -> failed. - `completed_at` may be a datetime or None (DB now() fallback). - """ - assert failure_info and "version" in failure_info and isinstance(failure_info["version"], int), \ - "failure_info must be a dict with an int 'version' field" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - UPDATE workers - SET status = 'failed', - k8s_node = COALESCE(k8s_node, %s), - failure_info = %s, - completed_at = COALESCE(%s, now()), - updated_at = now() - WHERE pod_name = %s AND (status = 'pending' OR status = 'running') - """, (k8s_node, json.dumps(failure_info), completed_at, pod_name)) - logger.debug("Marked worker %s failed", pod_name) - - -def mark_worker_orphaned(pod_name: str): - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - UPDATE workers - SET status = 'completed', - completed_at = COALESCE(completed_at, now()), - updated_at = now() - WHERE pod_name = %s AND (status = 'pending' OR status = 'running') - """, (pod_name,)) - logger.debug("Marked worker %s orphaned", pod_name) - - -def get_active_jobs_and_workers() -> tuple[list[psycopg2.extras.RealDictRow], list[psycopg2.extras.RealDictRow]]: - """Return (active_jobs, active_workers) as raw rows from PostgreSQL.""" - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - cur.execute(""" - SELECT * - FROM jobs WHERE status = 'pending' OR status = 'running' - ORDER BY created_at - """) - jobs = cur.fetchall() - - cur.execute(""" - SELECT * - FROM workers WHERE status = 'pending' OR status = 'running' - ORDER BY created_at - """) - workers = cur.fetchall() - - return jobs, workers - - -def get_active_jobs() -> list[psycopg2.extras.RealDictRow]: - """Return active_jobs as raw rows from PostgreSQL.""" - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - cur.execute(""" - SELECT * - FROM jobs WHERE status = 'pending' OR status = 'running' - ORDER BY created_at - """) - return cur.fetchall() - - -def get_active_workers() -> list[psycopg2.extras.RealDictRow]: - """Return active_workers as raw rows from PostgreSQL.""" - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - cur.execute(""" - SELECT * - FROM workers WHERE status = 'pending' OR status = 'running' - ORDER BY created_at - """) - return cur.fetchall() - - -def get_all_jobs(start: str | None = None, end: str | None = None, - page: int = 0, per_page: int = 100) -> tuple[list[psycopg2.extras.RealDictRow], int]: - """Return (jobs, total_count) with optional date filtering and paging. - - Args: - start: ISO date string (YYYY-MM-DD). Only jobs created on or after this date. - end: ISO date string (YYYY-MM-DD). Only jobs created before this date. - page: Page number (0-indexed). - per_page: Number of jobs per page. - - Returns: - Tuple of (list of job dicts, total matching count for pagination). - """ - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - conditions = [] - params: list = [] - if start: - conditions.append("created_at >= %s::timestamptz") - params.append(start) - if end: - conditions.append("created_at < %s::timestamptz") - params.append(end) - where = "WHERE " + " AND ".join(conditions) if conditions else "" - - cur.execute(f"SELECT COUNT(*) AS total FROM jobs {where}", params) - total = cur.fetchone()["total"] - - page_params = params + [per_page, page * per_page] - cur.execute(f""" - SELECT * - FROM jobs - {where} - ORDER BY created_at DESC - LIMIT %s OFFSET %s - """, page_params) - rows = cur.fetchall() - return rows, total - - -def get_all_workers(start: str | None = None, end: str | None = None, - page: int = 0, per_page: int = 100) -> tuple[list[psycopg2.extras.RealDictRow], int]: - """Return (workers, total_count) with optional date filtering and paging. - - Args: - start: ISO date string (YYYY-MM-DD). Only workers created on or after this date. - end: ISO date string (YYYY-MM-DD). Only workers created before this date. - page: Page number (0-indexed). - per_page: Number of workers per page. - - Returns: - Tuple of (list of worker dicts, total matching count for pagination). - """ - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - conditions = [] - params: list = [] - if start: - conditions.append("created_at >= %s::timestamptz") - params.append(start) - if end: - conditions.append("created_at < %s::timestamptz") - params.append(end) - where = "WHERE " + " AND ".join(conditions) if conditions else "" - - cur.execute(f"SELECT COUNT(*) AS total FROM workers {where}", params) - total = cur.fetchone()["total"] - - page_params = params + [per_page, page * per_page] - cur.execute(f""" - SELECT * - FROM workers - {where} - ORDER BY created_at DESC - LIMIT %s OFFSET %s - """, page_params) - rows = cur.fetchall() - return rows, total - - -def get_workers_for_reconcile(terminal_lookback_seconds: int = 3600) -> list[psycopg2.extras.RealDictRow]: - """Return all active workers plus recently-terminal workers for reconciliation. - - Active = pending/running. Terminal = completed/failed within the lookback window. - Terminal rows are included so sync_workers_state can delete their GitHub counterparts. - """ - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - cur.execute(""" - SELECT * - FROM workers - WHERE status IN ('pending', 'running') - OR (status IN ('completed', 'failed') - AND completed_at IS NOT NULL - AND completed_at > now() - (%s || ' seconds')::interval) - """, (int(terminal_lookback_seconds),)) - return cur.fetchall() - - -# --- Installation event log --- - -def add_installation_event( - *, - source: str, - event: str, - outcome: str, - payload: dict, - installation_id: int | None = None, - app_id: int | None = None, - entity_type: str | None = None, - entity_id: int | None = None, - entity_name: str | None = None, -) -> int: - """Insert one installation_events row. Returns the new BIGSERIAL id. - - `payload` is required (the column is JSONB NOT NULL); pass {} when there's - nothing to log. Caller is responsible for calling this in its own - transaction (separate from any side-effect writes); see the webhook handler. - """ - assert payload is not None, "payload is required (pass {} for empty)" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - INSERT INTO installation_events - (source, event, outcome, - installation_id, app_id, entity_type, entity_id, - entity_name, payload) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) - RETURNING id - """, (source, event, outcome, - installation_id, app_id, entity_type, entity_id, - entity_name, json.dumps(payload))) - return cur.fetchone()[0] - - -def get_events_by_entity_id(entity_id: int) -> list[psycopg2.extras.RealDictRow]: - """Return all events for an entity, ordered by received_at. - - For workflow_job.* rows, projects payload->workflow_job.id and - payload->repository.full_name as `job_id` / `repo_full_name` so the - timeline stays readable without a payload fetch. The full payload is - NOT projected — clients fetch it via /trace/payload/ when needed. - """ - with _get_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: - cur.execute(""" - SELECT id, source, event, outcome, - installation_id, app_id, entity_type, entity_id, - entity_name, received_at, - CASE WHEN event LIKE 'workflow_job.%%' - THEN payload->'workflow_job'->>'id' END AS job_id, - CASE WHEN event LIKE 'workflow_job.%%' - THEN payload->'repository'->>'full_name' END AS repo_full_name - FROM installation_events - WHERE entity_id = %s - ORDER BY received_at - """, (int(entity_id),)) - return cur.fetchall() - - -def get_payload_by_id(event_id: int) -> dict | None: - """Return only the JSONB payload for one installation_events row, or None.""" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute( - "SELECT payload FROM installation_events WHERE id = %s", - (int(event_id),), - ) - row = cur.fetchone() - return row[0] if row else None - - -def get_entity_id_for_installation(installation_id: int) -> int | None: - """Resolve installation_id -> entity_id. - - Looks first in installation_events for the most recent row with a non-NULL - entity_id (logged installations carry it). Falls back to jobs.entity_id if - no events exist for that installation_id yet (e.g. install pre-dates the - logging change). - """ - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute(""" - SELECT entity_id FROM installation_events - WHERE installation_id = %s AND entity_id IS NOT NULL - ORDER BY received_at DESC - LIMIT 1 - """, (int(installation_id),)) - row = cur.fetchone() - if row is not None: - return row[0] - cur.execute(""" - SELECT entity_id FROM jobs - WHERE installation_id = %s - ORDER BY created_at DESC - LIMIT 1 - """, (int(installation_id),)) - row = cur.fetchone() - return row[0] if row else None - - -def get_entity_id_for_job(job_id: int) -> int | None: - """Resolve job_id -> entity_id via the jobs table (one query).""" - with _get_conn() as conn: - with conn.cursor() as cur: - cur.execute( - "SELECT entity_id FROM jobs WHERE job_id = %s", - (int(job_id),), - ) - row = cur.fetchone() - return row[0] if row else None - - -# --- Pub/Sub --- - -_listen_conn = None -_listen_lock = threading.Lock() - - -def _get_listen_conn(): - """Get or create a dedicated AUTOCOMMIT connection for LISTEN/NOTIFY.""" - global _listen_conn - if _listen_conn is not None and _listen_conn.closed == 0: - return _listen_conn - with _listen_lock: - if _listen_conn is not None and _listen_conn.closed == 0: - return _listen_conn - _listen_conn = psycopg2.connect(POSTGRES_URL) - _listen_conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) - with _listen_conn.cursor() as cur: - cur.execute(f"SET search_path TO {POSTGRES_SCHEMA}") - cur.execute(f"LISTEN {POSTGRES_SCHEMA}_queue_event") - return _listen_conn - - -def wait_for_job(timeout: int) -> None: - """Block until a new job is published or timeout expires. - - Drains all buffered notifications after waking so the scheduler isn't - woken again for events that arrived while it was processing. - """ - assert timeout - conn = _get_listen_conn() - ready = select.select([conn], [], [], timeout) - if ready[0]: - conn.poll() - if conn.notifies: - logger.debug("Woken by PG queue event: %d notifications", len(conn.notifies)) - # Drain all buffered notifications - conn.notifies.clear() diff --git a/container/ghfe.py b/container/ghfe.py deleted file mode 100644 index 50036fd..0000000 --- a/container/ghfe.py +++ /dev/null @@ -1,650 +0,0 @@ -import hashlib -import hmac -import json -import logging -import requests -import time - -from flask import Flask, g, request, make_response -from flask.json import dumps as json_dumps - -import db -import github as gh -from constants import * - - -ORG_APP_INSTALL_URL = "https://github.com/apps/rise-risc-v-runners/installations/new" -PERSONAL_APP_INSTALL_URL = "https://github.com/apps/rise-risc-v-runners-personal/installations/new" - -app = Flask(__name__) - -logger = logging.getLogger(__name__) - - -class WebhookError(Exception): - """Exception raised during webhook processing.""" - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - super().__init__(message) - - -@app.errorhandler(WebhookError) -def handle_webhook_error(e): - if e.status_code == 200: - logger.debug(e.message) - else: - logger.warning(e.message) - return make_response(e.message, e.status_code) - - -@app.errorhandler(AssertionError) -def handle_assertion_error(e): - logger.info(e) - return make_response(str(e), 400) - - -@app.before_request -def _start_timer(): - g.print_perf_log = False - g.request_start = time.perf_counter() - - -@app.after_request -def _log_duration(response): - if request.method == "GET" and request.path == "/health": - pass - elif not g.print_perf_log: - pass - else: - elapsed_ms = (time.perf_counter() - g.request_start) * 1000 - logger.info( - "%s %s -> %d in %.1fms", - request.method, request.path, response.status_code, elapsed_ms, - ) - return response - - -# --- Webhook validation --- - -def compute_signature(body, secret): - return hmac.new(secret.encode('utf-8'), msg=body.encode('utf-8'), digestmod=hashlib.sha256) - - -def verify_signature(body, signature, secret): - """Verify that the body was sent from GitHub by validating the signature.""" - if not signature: - return False, "X-Hub-Signature-256 header is missing!" - - hash = compute_signature(body, secret) - expected_signature = "sha256=" + hash.hexdigest() - - if not hmac.compare_digest(expected_signature, signature): - return False, f"Request signatures didn't match! Expected: {expected_signature}, Got: {signature}" - - return True, "Signatures match" - - -def check_webhook_signature(headers, body): - """Verify the webhook signature.""" - if not "X-Github-Event" in request.headers: - raise WebhookError(400, "Missing X-Github-Event header") - event = headers["X-Github-Event"] - - if not "X-Hub-Signature-256" in request.headers: - raise WebhookError(400, "Missing X-Hub-Signature-256 header") - signature = headers["X-Hub-Signature-256"] - - is_valid, message = verify_signature(body, signature, GHAPP_WEBHOOK_SECRET) - if not is_valid: - logger.warning("Webhook signature verification failed: %s", message) - raise WebhookError(401, message) - - return event, body - - -# Per-section keys to drop from a workflow_job webhook payload before logging -# it. The `sender`, `repository`, `organization`, and `workflow_job` objects -# carry dozens of redundant URL fields (`*_url`) plus a few large secondary -# fields (`license`, `steps[]`) that we never use for diagnostics. The only -# URL we keep is `workflow_job.html_url` — the clickable link to the run on -# GitHub.com that operators use during investigations. -_WORKFLOW_JOB_DROP_KEYS = { - "sender": frozenset({ - "url", "html_url", - "gists_url", "repos_url", "avatar_url", "events_url", "starred_url", - "followers_url", "following_url", "organizations_url", - "subscriptions_url", "received_events_url", - }), - "repository": frozenset({ - "url", "license", - "git_url", "ssh_url", "svn_url", "html_url", - "keys_url", "tags_url", "blobs_url", "clone_url", "forks_url", - "hooks_url", "pulls_url", "teams_url", "trees_url", "events_url", - "issues_url", "labels_url", "merges_url", "mirror_url", "archive_url", - "commits_url", "compare_url", "branches_url", "comments_url", - "contents_url", "git_refs_url", "git_tags_url", "releases_url", - "statuses_url", "assignees_url", "downloads_url", "languages_url", - "milestones_url", "stargazers_url", "deployments_url", - "git_commits_url", "subscribers_url", "contributors_url", - "issue_events_url", "subscription_url", "collaborators_url", - "issue_comment_url", "notifications_url", - }), - "repository.owner": frozenset({ - "url", "html_url", - "gists_url", "repos_url", "avatar_url", "events_url", "starred_url", - "followers_url", "following_url", "organizations_url", - "subscriptions_url", "received_events_url", - }), - "organization": frozenset({ - "url", - "hooks_url", "repos_url", "avatar_url", "events_url", "issues_url", - "members_url", "public_members_url", - }), - "workflow_job": frozenset({ - "url", "run_url", "check_run_url", - "steps", - }), -} - - -def _drop_keys(d, keys): - """Return a shallow copy of d with `keys` removed (no-op if d isn't a dict).""" - if not isinstance(d, dict): - return d - return {k: v for k, v in d.items() if k not in keys} - - -def _trim_workflow_job_payload(payload): - """Drop the noisy fields listed in _WORKFLOW_JOB_DROP_KEYS from a - workflow_job payload before persisting it. - - Cuts ~70 redundant URL fields and the `steps[]` array. The only URL we - keep is workflow_job.html_url (used as the operational link to the run). - """ - trimmed = dict(payload) - if isinstance(trimmed.get("sender"), dict): - trimmed["sender"] = _drop_keys(trimmed["sender"], _WORKFLOW_JOB_DROP_KEYS["sender"]) - if isinstance(trimmed.get("repository"), dict): - repo = dict(trimmed["repository"]) - if isinstance(repo.get("owner"), dict): - repo["owner"] = _drop_keys(repo["owner"], _WORKFLOW_JOB_DROP_KEYS["repository.owner"]) - trimmed["repository"] = _drop_keys(repo, _WORKFLOW_JOB_DROP_KEYS["repository"]) - if isinstance(trimmed.get("organization"), dict): - trimmed["organization"] = _drop_keys(trimmed["organization"], _WORKFLOW_JOB_DROP_KEYS["organization"]) - if isinstance(trimmed.get("workflow_job"), dict): - trimmed["workflow_job"] = _drop_keys(trimmed["workflow_job"], _WORKFLOW_JOB_DROP_KEYS["workflow_job"]) - return trimmed - - -def _log_webhook_event( - *, - event: str, - outcome: WebhookOutcome, - payload: dict, - app_id: int, - installation_id: int | None = None, - entity_type: str | None = None, - entity_id: int | None = None, - entity_name: str | None = None, -) -> None: - try: - db.add_installation_event( - source="webhook", - event=event, - outcome=outcome, - payload=payload, - app_id=app_id, - installation_id=installation_id, - entity_type=entity_type, - entity_id=entity_id, - entity_name=entity_name, - ) - except Exception: - logger.exception("Failed to record installation_events row event=%s outcome=%s", event, outcome) - raise - - -def authorize_entity(payload): - """Authorize the repository owner (organization or personal account).""" - owner = payload["repository"]["owner"] - owner_id = owner["id"] - if not owner_id: - raise WebhookError(400, "Owner ID is missing in payload") - - owner_type = owner["type"] - if not owner_type: - raise WebhookError(400, "Owner Type is missing in payload") - if owner_type not in (EntityType.ORGANIZATION, EntityType.USER): - raise WebhookError(400, f"Unsupported owner type: {owner_type}") - - return owner_id, EntityType(owner_type) - - -def match_labels_to_k8s(org_id, repo_full_name, job_labels): - """ - Map workflow job labels to a k8s pool name and container image. - - Returns (k8s_pool, k8s_image) on a match, or None if no rule matches. - """ - # Special case(s) for PyTorch org - if org_id == PYTORCH_ORG_ID or (org_id == RISEPROJECT_DEV_ORG_ID and repo_full_name in ["riseproject-dev/pytorch", "riseproject-dev/executorch"]): - if any("linux.riscv64.xlarge" in job_label or "linux.riscv64.2xlarge" in job_label for job_label in job_labels): - return "scw-em-rv1", RUNNER_IMAGE_UBUNTU_24_04 - elif "ubuntu-24.04-riscv" in job_labels: - return "scw-em-rv1", RUNNER_IMAGE_UBUNTU_24_04 - else: - return None - - # Special case(s) for GGML org - elif org_id == GGML_ORG_ORG_ID or (org_id == RISEPROJECT_DEV_ORG_ID and repo_full_name in ["riseproject-dev/llama.cpp", "riseproject-dev/llama.cpp-validation"]): - if job_labels == ["ubuntu-24.04-riscv"]: - return "cloudv10x-jupiter", RUNNER_IMAGE_UBUNTU_24_04 - else: - return None - - # General cases - elif job_labels == ["ubuntu-24.04-riscv"]: - return "scw-em-rv1", RUNNER_IMAGE_UBUNTU_24_04 - # FIXME: there is no hardware that supports 26.04 (RVA23) just yet - # elif job_labels == ["ubuntu-26.04-riscv"]: - # return "scw-em-rv1", RUNNER_IMAGE_UBUNTU_26_04 - - return None - - -# --- Routes --- - -@app.route("/health", methods=['GET']) -def health(): - return "ok" - - -def _setup_page(title, body_html, status=200): - html = f""" -{title} - -{body_html}""" - return make_response(html, status) - - -def _render_setup(expected): - installation_id = request.args.get("installation_id") - if not installation_id: - return _setup_page( - "RISE RISC-V Runners — Setup", - f"""

Missing installation id

-

This page is the post-install redirect target for the RISE RISC-V Runners GitHub Apps. It expects an installation_id query parameter, which GitHub normally appends after installation.

-

If you got here by mistake, you can (re-)install one of the apps:

-

Install on an organization Install on a personal account

""", - status=400, - ) - - try: - installation = gh.get_installation(installation_id, entity_type=expected) - except gh.GitHubAPIError as e: - if e.status_code == 404: - wrong_app_name = "personal" if expected == EntityType.ORGANIZATION else "organization" - right_url = PERSONAL_APP_INSTALL_URL if expected == EntityType.ORGANIZATION else ORG_APP_INSTALL_URL - return _setup_page( - "RISE RISC-V Runners — Wrong app", - f"""

Installation not found for this app

-

We couldn't find installation {installation_id} under the app you just installed. The most likely cause is that you installed the {"organization" if expected == EntityType.ORGANIZATION else "personal"} app on a {wrong_app_name} account — these two must match.

-

Please uninstall it from your GitHub settings and install the correct app:

-

Install the {wrong_app_name} app

""", - status=404, - ) - logger.error("Unexpected error fetching installation %s: %s", installation_id, e) - return _setup_page( - "RISE RISC-V Runners — Setup error", - f"""

Something went wrong

-

GitHub returned an error while validating your installation ({e.status_code}). Please try again in a minute, or contact the RISE team if the problem persists.

""", - status=502, - ) - - account = installation.get("account") or {} - account_type = account.get("type") - account_login = account.get("login", "(unknown)") - - if account_type == expected.value: - return _setup_page( - "RISE RISC-V Runners — Installed", - f"""

All set, {account_login}!

-

The RISE RISC-V Runners {"organization" if expected == EntityType.ORGANIZATION else "personal"} app is correctly installed on {account_login}.

-

You can now trigger GitHub Actions jobs with the ubuntu-24.04-riscv label and they will be picked up automatically.

""", - ) - - # Mismatch: user installed this app on the wrong account type. - if expected == EntityType.ORGANIZATION: - logger.info("Entity %s installed Personal Account app on Organization, account_type=%s account_login=%s", account_login, account_type, account_login) - return _setup_page( - "RISE RISC-V Runners — Wrong account type", - f"""

You installed the organization app on a personal account

-

The RISE RISC-V Runners (organization) app was installed on personal account {account_login}. It only works on GitHub organizations.

-

For personal accounts, install the dedicated personal app instead:

-

Install the personal app

-

You should also uninstall the organization app from {account_login}'s GitHub settings to avoid confusion.

""", - status=400, - ) - else: - logger.info("Entity %s installed Organization app on Personal Account, account_type=%s account_login=%s", account_login, account_type, account_login) - return _setup_page( - "RISE RISC-V Runners — Wrong account type", - f"""

You installed the personal app on an organization

-

The RISE RISC-V Runners (personal) app was installed on organization {account_login}. It only works on personal GitHub accounts.

-

For organizations, install the dedicated organization app instead:

-

Install the organization app

-

You should also uninstall the personal app from {account_login}'s GitHub settings to avoid confusion.

""", - status=400, - ) - - -@app.route("/setup/org", methods=["GET"]) -def setup_org(): - return _render_setup(expected=EntityType.ORGANIZATION) - - -@app.route("/setup/personal", methods=["GET"]) -def setup_personal(): - return _render_setup(expected=EntityType.USER) - - -# --- /trace endpoints --- - -def _check_trace_auth(): - """401 unless the Bearer token matches TRACE_API_SECRET. Plain equality check.""" - auth = request.headers.get("Authorization", "") - if auth != f"Bearer {TRACE_API_SECRET}": - raise WebhookError(401, "Unauthorized") - - -def _json_response(data): - return make_response(json_dumps(data, default=str), 200, {"Content-Type": "application/json"}) - - -@app.route("/trace/entity/", methods=["GET"]) -def trace_entity(entity_id): - _check_trace_auth() - events = db.get_events_by_entity_id(entity_id) - return _json_response({"events": events}) - - -@app.route("/trace/installation/", methods=["GET"]) -def trace_installation(installation_id): - _check_trace_auth() - entity_id = db.get_entity_id_for_installation(installation_id) - if entity_id is None: - raise WebhookError(404, "Entity not found") - events = db.get_events_by_entity_id(entity_id) - return _json_response({"events": events}) - - -@app.route("/trace/job/", methods=["GET"]) -def trace_job(job_id): - _check_trace_auth() - entity_id = db.get_entity_id_for_job(job_id) - if entity_id is None: - raise WebhookError(404, "Entity not found") - events = db.get_events_by_entity_id(entity_id) - return _json_response({"events": events}) - - -@app.route("/trace/payload/", methods=["GET"]) -def trace_payload(event_id): - _check_trace_auth() - payload = db.get_payload_by_id(event_id) - if payload is None: - raise WebhookError(404, "Payload not found") - return _json_response({"payload": payload}) - - -@app.route("/", methods=['POST']) -def webhook(): - event, body = check_webhook_signature(request.headers, request.get_data(as_text=True)) - - try: - payload = json.loads(body) - except json.JSONDecodeError: - logger.debug("Invalid JSON payload") - raise WebhookError(400, "Invalid JSON payload") - - if not "X-GitHub-Hook-Installation-Target-Id" in request.headers: - raise WebhookError(400, "Missing X-GitHub-Hook-Installation-Target-Id header") - try: - app_id = int(request.headers["X-GitHub-Hook-Installation-Target-Id"]) - except ValueError: - raise WebhookError(400, "Invalid X-GitHub-Hook-Installation-Target-Id header") - - if event == "ping": - _log_webhook_event(event="ping", outcome=WebhookOutcome.OK, payload=payload, app_id=app_id) - return f"pong" - - elif event == "installation": - action = payload["action"] - install = payload["installation"] - account = install["account"] - _log_webhook_event( - event=f"{event}.{action}", - outcome=WebhookOutcome.OK, - payload=payload, - app_id=app_id, - installation_id=install["id"], - entity_type=install["target_type"], - entity_id=install["target_id"], - entity_name=account["login"], - ) - return f"{event}.{action} logged" - - elif event == "installation_repositories": - action = payload["action"] - install = payload["installation"] - account = install["account"] - _log_webhook_event( - event=f"{event}.{action}", - outcome=WebhookOutcome.OK, - payload=payload, - app_id=app_id, - installation_id=install["id"], - entity_type=install["target_type"], - entity_id=install["target_id"], - entity_name=account["login"], - ) - return f"{event}.{action} logged" - - elif event == "installation_target": - action = payload["action"] - # `installation_target.renamed` carries the new account at top level; - # `installation.account` would be the pre-rename name. - account = payload["account"] - install = payload["installation"] - _log_webhook_event( - event=f"{event}.{action}", - outcome=WebhookOutcome.OK, - payload=payload, - app_id=app_id, - installation_id=install["id"], - entity_type=payload["target_type"], - entity_id=account["id"], - entity_name=account["login"], - ) - return f"{event}.{action} logged" - - elif event == "workflow_job": - action = payload["action"] - - # workflow_job's `installation` object is just `{id, node_id}` — pull - # the entity from `repository.owner` instead. - install = payload["installation"] - owner = payload["repository"]["owner"] - # Drop the noisy URL/license/steps fields before logging. The - # ignored_no_label branch overrides `payload` below with an even - # tighter dict, so this only affects the processed-job outcomes - # (job_stored, job_marked_running, etc.). - log_fields = dict( - payload=_trim_workflow_job_payload(payload), - app_id=app_id, - installation_id=install["id"], - entity_type=owner["type"], - entity_id=owner["id"], - entity_name=owner["login"], - ) - - # Ignore workflow_job actions we don't process (e.g. 'waiting'). - if action not in ("queued", "in_progress", "completed"): - _log_webhook_event(event=f"{event}.{action}", outcome=WebhookOutcome.IGNORED_ACTION, **log_fields) - logger.debug("Ignoring action: %s", action) - return f"Ignoring action: {action}" - - entity_id, entity_type = authorize_entity(payload) - - # Check if we should redirect to staging - if PROD: - repo_name = payload["repository"].get("name") - if entity_id in STAGING_ENTITIES and repo_name and repo_name in STAGING_ENTITIES[entity_id]: - g.print_perf_log = True - logger.debug("Proxying request for entity=%s repo=%s to staging (%s)", entity_id, repo_name, STAGING_URL) - resp = requests.post( - STAGING_URL, - data=request.get_data(), - headers={k: v for k, v in request.headers if k.lower() != "host"}, - timeout=30, - ) - logger.info("Proxied request for entity=%s repo=%s to staging, status=%s", entity_id, repo_name, resp.status_code) - return make_response(resp.content, resp.status_code) - - if GO_GHFE_URL and entity_id in GO_GHFE_ROUTING: - g.print_perf_log = True - logger.debug("Proxying request for entity=%s to Go ghfe (%s)", entity_id, GO_GHFE_URL) - resp = requests.post( - GO_GHFE_URL, - data=request.get_data(), - headers={k: v for k, v in request.headers if k.lower() != "host"}, - timeout=30, - ) - logger.info("Proxied request for entity=%s to Go ghfe, status=%s", entity_id, resp.status_code) - return make_response(resp.content, resp.status_code) - - job_id = payload["workflow_job"]["id"] - if not job_id: - raise WebhookError(400, "Job ID is missing in payload") - - # labels may be missing when no labels are defined - job_labels = payload["workflow_job"]["labels"] or [] - - repo_full_name = payload["repository"]["full_name"] - if not repo_full_name: - raise WebhookError(400, "Repository full name is missing in payload") - - repo_id = payload["repository"]["id"] - if not repo_id: - raise WebhookError(400, "Repository ID is missing in payload") - - # Filter out unsupported jobs early. - match = match_labels_to_k8s(entity_id, repo_full_name, job_labels) - if match is None: - # ignored_no_label is by far the highest-volume row; keep only the - # fields a human needs to diagnose "user used an unsupported label" - # (which labels they tried, which repo, link to the run on GitHub). - log_fields["payload"] = { - "workflow_job": { - "labels": job_labels, - "html_url": payload["workflow_job"].get("html_url"), - }, - "repository": {"full_name": repo_full_name}, - } - _log_webhook_event(event=f"{event}.{action}", outcome=WebhookOutcome.IGNORED_NO_LABEL, **log_fields) - raise WebhookError(200, f"Ignoring job: missing required platform label (got {job_labels})") - k8s_pool, k8s_image = match - - logger.info("Received %s workflow_job id=%s name=%s repo=%s labels=%s entity_type=%s", - action, job_id, payload["workflow_job"]["name"], - payload["repository"]["full_name"], - payload["workflow_job"]["labels"], - entity_type.value) - - # Only enable printing if we know we care for that webhook - g.print_perf_log = True - - if action == "queued": - installation_id = payload["installation"]["id"] - if not installation_id: - raise WebhookError(400, "Installation ID is missing in payload") - - entity_name = payload["repository"]["owner"]["login"] - if not entity_name: - raise WebhookError(400, "Entity name is missing in payload") - - html_url = payload["workflow_job"]["html_url"] - if not html_url: - raise WebhookError(400, "HTML URL is missing in payload") - - stored = db.add_job( - job_id=job_id, - provider="github", - entity_id=entity_id, - entity_name=entity_name, - entity_type=entity_type, - repo_full_name=repo_full_name, - installation_id=installation_id, - labels=job_labels, - k8s_pool=k8s_pool, - k8s_image=k8s_image, - html_url=html_url, - ) - - outcome = WebhookOutcome.JOB_STORED if stored else WebhookOutcome.JOB_ALREADY_EXISTS - _log_webhook_event(event=f"{event}.{action}", outcome=outcome, **log_fields) - - if stored: - return f"Job {job_id} stored." - else: - return f"Job {job_id} already exists." - - elif action == "in_progress": - prev_status = db.mark_job_running(job_id, payload["workflow_job"].get("runner_name")) - outcome = WebhookOutcome.JOB_NOT_FOUND if prev_status is None else WebhookOutcome.JOB_MARKED_RUNNING - _log_webhook_event(event=f"{event}.{action}", outcome=outcome, **log_fields) - if prev_status is None: - logger.warning("Job %s not found on in_progress event", job_id) - return f"Job {job_id} not found." - logger.info("Job %s marked running (was %s)", job_id, prev_status) - return f"Job {job_id} marked running (was {prev_status})." - - elif action == "completed": - prev_status = db.mark_job_completed(job_id, payload["workflow_job"].get("runner_name")) - outcome = WebhookOutcome.JOB_NOT_FOUND if prev_status is None else WebhookOutcome.JOB_MARKED_COMPLETED - _log_webhook_event(event=f"{event}.{action}", outcome=outcome, **log_fields) - if prev_status is None: - logger.warning("Job %s not found on completed event", job_id) - return f"Job {job_id} not found." - return f"Job {job_id} completed (was {prev_status})." - - else: - _log_webhook_event(event=event, outcome=WebhookOutcome.IGNORED_EVENT, payload=payload, app_id=app_id) - return f"Ignoring {event} event" - -if __name__ == "__main__": - # Set the logging level for all loggers to INFO - logging.basicConfig( - level=logging.getLevelNamesMapping()[os.environ.get("LOGLEVEL", "INFO")], - format='%(pathname)s:%(lineno)d::%(funcName)s: [%(levelname)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - - # Ensure PostgreSQL schema/tables exist - db.ensure_schema() - - from waitress import serve - - HOST = "0.0.0.0" - PORT = 8080 - - print(f"Starting server on http://{HOST}:{PORT}") - serve(app, host=HOST, port=PORT, threads=8) # it's pretty much only IO and CPU is at ~5% diff --git a/container/github.py b/container/github.py deleted file mode 100644 index d576184..0000000 --- a/container/github.py +++ /dev/null @@ -1,268 +0,0 @@ -import functools -import logging -import time - -import jwt -import requests - -from cachetools.func import ttl_cache -from constants import * - -logger = logging.getLogger(__name__) - - -class GitHubAPIError(Exception): - """Exception raised for GitHub API errors.""" - def __init__(self, status_code: int, message: str): - self.status_code = int(status_code) - self.message = message - super().__init__(message) - - -@functools.lru_cache(maxsize=1) -def init_ghapp_private_key_org(): - private_key = jwt.jwk_from_pem(GHAPP_ORG_PRIVATE_KEY.encode('utf-8')) - assert private_key, "Failed to load private key from GHAPP_ORG_PRIVATE_KEY" - return private_key - - -@functools.lru_cache(maxsize=1) -def init_ghapp_private_key_personal(): - private_key = jwt.jwk_from_pem(GHAPP_PERSONAL_PRIVATE_KEY.encode('utf-8')) - assert private_key, "Failed to load private key from GHAPP_PERSONAL_PRIVATE_KEY" - return private_key - - -def generate_jwt(app_id, private_key): - """Generate a JWT for GitHub App authentication.""" - payload = { - "iat": int(time.time()), - "exp": int(time.time()) + (10 * 60), - "iss": app_id, - } - return jwt.JWT().encode(payload, private_key, alg="RS256") - - -@ttl_cache(maxsize=1024, ttl=60*59) # Authentication Token lifetime is 1 hour -def authenticate_app(installation_id, app_id): - """Authenticate the app and get an installation token. - - `app_id` selects which app's private key signs the JWT — pass - `GHAPP_PERSONAL_ID` for user installations, `GHAPP_ORG_ID` for org - installations. - """ - if app_id == GHAPP_PERSONAL_ID: - jwt_token = generate_jwt(GHAPP_PERSONAL_ID, init_ghapp_private_key_personal()) - elif app_id == GHAPP_ORG_ID: - jwt_token = generate_jwt(GHAPP_ORG_ID, init_ghapp_private_key_org()) - else: - raise ValueError(f"Unknown app_id: {app_id}") - - headers = { - "Authorization": f"Bearer {jwt_token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/app/installations/{installation_id}/access_tokens" - response = requests.post(url, headers=headers, json={}) - - if response.status_code == 201: - logger.debug("Obtained installation access token for installation %s", installation_id) - return response.json().get("token") - else: - error = response.json().get("message") - logger.error("Failed to get installation access token for installation %s: %s", installation_id, error) - raise GitHubAPIError(response.status_code, f"Failed to get installation access token: {error}") - - -def get_installation(installation_id, entity_type): - """Fetch installation metadata for the app matching entity_type. - - Returns the parsed JSON (includes account.type, account.login, app_id). - """ - assert entity_type is not None - if entity_type == EntityType.USER: - jwt_token = generate_jwt(GHAPP_PERSONAL_ID, init_ghapp_private_key_personal()) - else: - jwt_token = generate_jwt(GHAPP_ORG_ID, init_ghapp_private_key_org()) - - headers = { - "Authorization": f"Bearer {jwt_token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/app/installations/{installation_id}" - response = requests.get(url, headers=headers) - - if response.status_code == 200: - return response.json() - else: - error = response.json().get("message") - logger.error("Failed to get installation %s: %s", installation_id, error) - raise GitHubAPIError(response.status_code, f"Failed to get installation: {error}") - - -def ensure_runner_group(entity_name, token, group_name): - """Ensure the runner group exists and return its ID.""" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - - list_url = f"https://api.github.com/orgs/{entity_name}/actions/runner-groups" - response = requests.get(list_url, headers=headers) - if response.status_code != 200: - error = response.json() - logger.error("Failed to list runner groups for org %s: %s", entity_name, error) - raise GitHubAPIError(response.status_code, f"Failed to list runner groups: {error}") - - for group in response.json().get("runner_groups", []): - if group.get("name") == group_name: - logger.debug("Found existing runner group '%s' (id=%s) for org %s", - group_name, group["id"], entity_name) - return group["id"] - - create_body = { - "name": group_name, - "visibility": "all", - "allows_public_repositories": True, - } - response = requests.post(list_url, headers=headers, json=create_body) - if response.status_code == 201: - runner_group_id = response.json().get("id") - logger.debug("Created runner group '%s' (id=%s) for org %s", - group_name, runner_group_id, entity_name) - return runner_group_id - else: - error = response.json() - logger.error("Failed to create runner group '%s' for org %s: %s", - group_name, entity_name, error) - raise GitHubAPIError(response.status_code, f"Failed to create runner group: {error}") - - -def create_jit_runner_config_org(token, group_id, labels, entity_name, runner_name): - """Create a JIT runner configuration for a new ephemeral runner (org-scoped).""" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/orgs/{entity_name}/actions/runners/generate-jitconfig" - body = { - "name": runner_name, - "runner_group_id": group_id, - "labels": labels, - "work_folder": "../../../work", # /home/runner/actions-runner/cached/X.Y.Z/bin/../ -> /home/runner/work - } - response = requests.post(url, headers=headers, json=body) - - if response.status_code == 201: - jit_config = response.json().get("encoded_jit_config") - logger.debug("Created JIT runner config for org %s, runner name=%s", entity_name, runner_name) - return jit_config - else: - error = response.json() - logger.error("Failed to create JIT runner config for org %s: %s", entity_name, error) - raise GitHubAPIError(response.status_code, f"Failed to create JIT runner config: {error}") - - -def create_jit_runner_config_repo(token, labels, repo_full_name, runner_name): - """Create a JIT runner configuration for a new ephemeral runner (repo-scoped, for personal accounts).""" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/repos/{repo_full_name}/actions/runners/generate-jitconfig" - body = { - "name": runner_name, - "runner_group_id": 1, # default runner group for repos - "labels": labels, - "work_folder": "../../../work", # /home/runner/actions-runner/cached/X.Y.Z/bin/../ -> /home/runner/work - } - response = requests.post(url, headers=headers, json=body) - - if response.status_code == 201: - jit_config = response.json().get("encoded_jit_config") - logger.debug("Created JIT runner config for repo %s, runner name=%s", repo_full_name, runner_name) - return jit_config - else: - error = response.json() - logger.error("Failed to create JIT runner config for repo %s: %s", repo_full_name, error) - raise GitHubAPIError(response.status_code, f"Failed to create JIT runner config: {error}") - - -def _paginated_get(url, token, collection_key): - """GET a paginated GitHub list endpoint, following the `Link: rel="next"` header. - - Returns the concatenated list under `collection_key` in each response body. - Raises GitHubAPIError on non-2xx. - """ - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - results = [] - next_url = f"{url}?per_page=100" - while next_url: - response = requests.get(next_url, headers=headers) - if response.status_code != 200: - raise GitHubAPIError(response.status_code, f"GET {next_url}: {response.text}") - results.extend(response.json().get(collection_key, [])) - next_url = response.links.get("next", {}).get("url") - return results - - -def list_runners_org_group(token, entity_name, group_id): - """List all runners registered under a specific org runner group.""" - url = f"https://api.github.com/orgs/{entity_name}/actions/runner-groups/{group_id}/runners" - return _paginated_get(url, token, "runners") - - -def list_runners_repo(token, repo_full_name): - """List all runners registered under a specific repo.""" - url = f"https://api.github.com/repos/{repo_full_name}/actions/runners" - return _paginated_get(url, token, "runners") - - -def delete_runner_org(token, entity_name, runner_id): - """DELETE an org-scoped runner. 404 is swallowed (already gone).""" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/orgs/{entity_name}/actions/runners/{runner_id}" - response = requests.delete(url, headers=headers) - if response.status_code in (204, 404): - return - raise GitHubAPIError(response.status_code, f"DELETE {url}: {response.text}") - - -def delete_runner_repo(token, repo_full_name, runner_id): - """DELETE a repo-scoped runner. 404 is swallowed.""" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/repos/{repo_full_name}/actions/runners/{runner_id}" - response = requests.delete(url, headers=headers) - if response.status_code in (204, 404): - return - raise GitHubAPIError(response.status_code, f"DELETE {url}: {response.text}") - - -def get_job_info(repo_full_name, job_id, token): - """Get the effective status of a workflow job from GitHub API. - - GitHub can return status="in_progress" with conclusion="cancelled" (or other - terminal conclusions). When a conclusion is present, the job is effectively - completed regardless of the status field. - """ - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - url = f"https://api.github.com/repos/{repo_full_name}/actions/jobs/{job_id}" - response = requests.get(url, headers=headers) - - if response.status_code == 200: - return response.json() - else: - logger.error("Failed to get job status for %s job %s: %s", repo_full_name, job_id, response.status_code) - raise GitHubAPIError(response.status_code, f"Failed to get job status: {response.text}") diff --git a/container-go/go.mod b/container/go.mod similarity index 95% rename from container-go/go.mod rename to container/go.mod index 9e5df28..3f76964 100644 --- a/container-go/go.mod +++ b/container/go.mod @@ -1,4 +1,4 @@ -module github.com/riseproject-dev/riscv-runner-app/container-go +module github.com/riseproject-dev/riscv-runner-app/container go 1.26.0 @@ -29,6 +29,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pashagolub/pgxmock/v4 v4.9.0 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/x448/float16 v0.8.4 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect diff --git a/container-go/go.sum b/container/go.sum similarity index 98% rename from container-go/go.sum rename to container/go.sum index 76c76d7..76cd8dc 100644 --- a/container-go/go.sum +++ b/container/go.sum @@ -55,6 +55,8 @@ github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFd github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pashagolub/pgxmock/v4 v4.9.0 h1:itlO8nrVRnzkdMBXLs8pWUyyB2PC3Gku0WGIj/gGl7I= +github.com/pashagolub/pgxmock/v4 v4.9.0/go.mod h1:9L57pC193h2aKRHVyiiE817avasIPZnPwPlw3JczWvM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/container-go/internal/constants.go b/container/internal/constants.go similarity index 97% rename from container-go/internal/constants.go rename to container/internal/constants.go index 883a6e5..a804a7b 100644 --- a/container-go/internal/constants.go +++ b/container/internal/constants.go @@ -1,6 +1,5 @@ // Package internal hosts shared types, configuration, and clients -// used by both cmd/ghfe and cmd/scheduler. See container-go/CONTRACT.md -// for the external behavior this package implements. +// used by both cmd/ghfe and cmd/scheduler. package internal import ( diff --git a/container/internal/constants_test.go b/container/internal/constants_test.go new file mode 100644 index 0000000..af507c8 --- /dev/null +++ b/container/internal/constants_test.go @@ -0,0 +1,103 @@ +package internal + +import ( + "strings" + "testing" +) + +func TestOrDefault(t *testing.T) { + if got := orDefault("", "fallback"); got != "fallback" { + t.Errorf("empty: got %q", got) + } + if got := orDefault("v", "fallback"); got != "v" { + t.Errorf("non-empty: got %q", got) + } +} + +func TestLoadConfig_AllRequiredPresentProd(t *testing.T) { + env := map[string]string{ + "PROD": "true", + "PROD_URL": "https://prod", + "STAGING_URL": "https://stg", + "POSTGRES_URL": "postgres://x", + "K8S_KUBECONFIG": "kube: yaml", + "LOGLEVEL": "DEBUG", + "TRACE_API_SECRET": "tsec", + "GHAPP_WEBHOOK_SECRET": "wsec", + "GHAPP_ORG_PRIVATE_KEY": "okey", + "GHAPP_PERSONAL_PRIVATE_KEY": "pkey", + } + cfg, err := LoadConfig(func(k string) string { return env[k] }) + if err != nil { + t.Fatalf("err: %v", err) + } + if !cfg.Prod || cfg.PostgresSchema != "prod" { + t.Errorf("prod path off: %+v", cfg) + } + if cfg.RunnerPrefix != "rise-riscv-runner-" { + t.Errorf("runner prefix: %q", cfg.RunnerPrefix) + } + if !strings.Contains(cfg.ImageUbuntu24, "ubuntu-24.04-latest") { + t.Errorf("image24: %q", cfg.ImageUbuntu24) + } + if !strings.Contains(cfg.ImageUbuntu26, "ubuntu-26.04-latest") { + t.Errorf("image26: %q", cfg.ImageUbuntu26) + } + if cfg.LogLevel != "DEBUG" { + t.Errorf("loglevel: %q", cfg.LogLevel) + } +} + +func TestLoadConfig_StagingDefaults(t *testing.T) { + env := map[string]string{ + "PROD_URL": "https://prod", + "STAGING_URL": "https://stg", + "POSTGRES_URL": "postgres://x", + "K8S_KUBECONFIG": "kube: yaml", + "TRACE_API_SECRET": "tsec", + "GHAPP_WEBHOOK_SECRET": "wsec", + "GHAPP_ORG_PRIVATE_KEY": "okey", + "GHAPP_PERSONAL_PRIVATE_KEY": "pkey", + } + cfg, err := LoadConfig(func(k string) string { return env[k] }) + if err != nil { + t.Fatalf("err: %v", err) + } + if cfg.Prod || cfg.PostgresSchema != "staging" { + t.Errorf("expected staging: %+v", cfg) + } + if cfg.RunnerPrefix != "rise-riscv-runner-staging-" { + t.Errorf("staging prefix: %q", cfg.RunnerPrefix) + } + if cfg.LogLevel != "INFO" { + t.Errorf("default loglevel should be INFO: %q", cfg.LogLevel) + } + if !strings.Contains(cfg.ImageUbuntu24, "ubuntu-24.04-staging") { + t.Errorf("staging image24: %q", cfg.ImageUbuntu24) + } +} + +func TestLoadConfig_MissingRequiredListed(t *testing.T) { + cfg, err := LoadConfig(func(string) string { return "" }) + if err == nil { + t.Fatal("expected error") + } + if cfg.PostgresSchema != "" { + t.Errorf("cfg should be zero: %+v", cfg) + } + for _, k := range []string{"PROD_URL", "STAGING_URL", "POSTGRES_URL", "K8S_KUBECONFIG", "TRACE_API_SECRET", "GHAPP_WEBHOOK_SECRET", "GHAPP_ORG_PRIVATE_KEY", "GHAPP_PERSONAL_PRIVATE_KEY"} { + if !strings.Contains(err.Error(), k) { + t.Errorf("missing key %q not mentioned: %v", k, err) + } + } +} + +func TestLoadConfigFromEnv_NoEnv(t *testing.T) { + // Unset every required var so the env-backed loader fails cleanly. + for _, k := range []string{"PROD_URL", "STAGING_URL", "POSTGRES_URL", "K8S_KUBECONFIG", "TRACE_API_SECRET", "GHAPP_WEBHOOK_SECRET", "GHAPP_ORG_PRIVATE_KEY", "GHAPP_PERSONAL_PRIVATE_KEY", "PROD", "LOGLEVEL"} { + t.Setenv(k, "") + } + if _, err := LoadConfigFromEnv(); err == nil { + t.Fatal("expected error with empty env") + } +} diff --git a/container-go/internal/contract.go b/container/internal/contract.go similarity index 100% rename from container-go/internal/contract.go rename to container/internal/contract.go diff --git a/container/internal/contract_test.go b/container/internal/contract_test.go new file mode 100644 index 0000000..4bd5592 --- /dev/null +++ b/container/internal/contract_test.go @@ -0,0 +1,91 @@ +package internal + +import ( + "bytes" + "log/slog" + "math" + "strings" + "testing" + "time" +) + +func TestParseEntityType(t *testing.T) { + cases := []struct { + in string + want EntityType + wantErr bool + }{ + {"Organization", EntityOrganization, false}, + {"User", EntityUser, false}, + {"", "", true}, + {"organization", "", true}, + {"Bot", "", true}, + } + for _, c := range cases { + got, err := ParseEntityType(c.in) + if c.wantErr { + if err == nil { + t.Errorf("%q: expected error", c.in) + } + continue + } + if err != nil { + t.Errorf("%q: unexpected err %v", c.in, err) + } + if got != c.want { + t.Errorf("%q: got %v want %v", c.in, got, c.want) + } + } +} + +func TestEntity_LogValue(t *testing.T) { + var buf bytes.Buffer + h := slog.NewTextHandler(&buf, nil) + slog.New(h).Info("hi", "entity", Entity{Type: EntityOrganization, Name: "acme", ID: 42}) + s := buf.String() + if !strings.Contains(s, "entity.type=Organization") { + t.Errorf("type missing: %q", s) + } + if !strings.Contains(s, "entity.name=acme") { + t.Errorf("name missing: %q", s) + } + if !strings.Contains(s, "entity.id=42") { + t.Errorf("id missing: %q", s) + } +} + +func TestJob_Entity(t *testing.T) { + j := Job{EntityID: 7, EntityName: "acme", EntityType: "Organization"} + e := j.Entity() + if e.ID != 7 || e.Name != "acme" || e.Type != EntityOrganization { + t.Errorf("got %+v", e) + } +} + +func TestWorker_Entity(t *testing.T) { + w := Worker{EntityID: 9, EntityName: "luhenry", EntityType: "User"} + e := w.Entity() + if e.ID != 9 || e.Name != "luhenry" || e.Type != EntityUser { + t.Errorf("got %+v", e) + } +} + +func TestGitHubAPIError_Error(t *testing.T) { + e := &GitHubAPIError{StatusCode: 404, Message: "Not Found"} + if e.Error() != "Not Found" { + t.Errorf("got %q", e.Error()) + } +} + +func TestAgeSeconds(t *testing.T) { + // nil pointer → +Inf-ish + if got := AgeSeconds(nil); got < 1e100 { + t.Errorf("nil should be huge, got %v", got) + } + + t0 := time.Now().Add(-30 * time.Second) + got := AgeSeconds(&t0) + if math.Abs(got-30) > 5 { + t.Errorf("expected ~30, got %v", got) + } +} diff --git a/container-go/internal/db.go b/container/internal/db.go similarity index 95% rename from container-go/internal/db.go rename to container/internal/db.go index cad697f..aa091b2 100644 --- a/container-go/internal/db.go +++ b/container/internal/db.go @@ -13,16 +13,33 @@ import ( "github.com/jackc/pgx/v5/pgxpool" ) -// pgDB wires the DB interface to a pgxpool.Pool. A separate connection -// (listenConn) holds the LISTEN session for the {schema}_queue_event channel. +// pgDB wires the DB interface to a pool. A separate connection (listenConn) +// holds the LISTEN session for the {schema}_queue_event channel. type pgDB struct { - pool *pgxpool.Pool + pool pgxPool schema string - listenConn *pgx.Conn + listenConn listenConn } -// queryer is the subset of pgx methods we use; matches both *pgxpool.Pool -// and pgx.Tx so writes/reads can run inside WithWorkerLock's transaction. +// pgxPool is the subset of *pgxpool.Pool used by pgDB. pgxmock.PgxPoolIface +// implements the same surface so tests can drive every query path without a +// live Postgres. +type pgxPool interface { + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + BeginTx(ctx context.Context, opts pgx.TxOptions) (pgx.Tx, error) + Close() +} + +// listenConn is the subset of *pgx.Conn used for LISTEN/NOTIFY. Tests stub it. +type listenConn interface { + WaitForNotification(ctx context.Context) (*pgconn.Notification, error) + Close(ctx context.Context) error +} + +// queryer is the subset of pgx methods we use; matches both pgxPool and +// pgx.Tx so writes/reads can run inside WithWorkerLock's transaction. type queryer interface { Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) diff --git a/container/internal/db_test.go b/container/internal/db_test.go new file mode 100644 index 0000000..79a045d --- /dev/null +++ b/container/internal/db_test.go @@ -0,0 +1,644 @@ +package internal + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/pashagolub/pgxmock/v4" +) + +// newMockDB returns a pgDB wired to a pgxmock pool with regex query matching. +// Tests register expectations on the returned mock and call methods on pgDB. +func newMockDB(t *testing.T) (*pgDB, pgxmock.PgxPoolIface) { + t.Helper() + mock, err := pgxmock.NewPool(pgxmock.QueryMatcherOption(pgxmock.QueryMatcherRegexp)) + if err != nil { + t.Fatalf("NewPool: %v", err) + } + t.Cleanup(mock.Close) + return &pgDB{pool: mock, schema: "staging"}, mock +} + +// anyN returns n pgxmock.AnyArg() placeholders for tests that don't care +// about the literal arg values. +func anyN(n int) []any { + out := make([]any, n) + for i := range out { + out[i] = pgxmock.AnyArg() + } + return out +} + +// jobScanRow returns one row's worth of jobColumns-shaped values for ScanJobs. +func jobScanRow() []any { + return []any{ + int64(1), "pending", []byte(`{}`), "github", int64(99), "acme", + "Organization", "acme/r", int64(7), []byte(`["x"]`), "scw-em-rv1", + "img", nil, nil, time.Now(), time.Now(), + } +} + +func workerScanRow(name string, status string) []any { + return []any{ + name, "github", int64(99), "acme", "Organization", int64(7), nil, + []byte(`["x"]`), "scw-em-rv1", "img", nil, status, nil, + time.Now(), nil, nil, time.Now(), + } +} + +func jobColumns() []string { + return []string{"job_id", "status", "failure_info", "provider", "entity_id", + "entity_name", "entity_type", "repo_full_name", "installation_id", + "job_labels", "k8s_pool", "k8s_image", "k8s_pod", "html_url", + "created_at", "updated_at"} +} + +func workerColumns() []string { + return []string{"pod_name", "provider", "entity_id", "entity_name", "entity_type", + "installation_id", "repo_full_name", "job_labels", "k8s_pool", "k8s_image", "k8s_node", + "status", "failure_info", "created_at", "running_at", "completed_at", "updated_at"} +} + +func TestAddJob_InsertedAndNotifies(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`INSERT INTO jobs`).WithArgs(anyN(12)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)) + mock.ExpectExec(`NOTIFY staging_queue_event`).WithArgs(pgxmock.AnyArg()). + WillReturnResult(pgxmock.NewResult("NOTIFY", 0)) + + got, err := db.AddJob(context.Background(), Job{JobID: 1, Provider: "github", EntityID: 99, + EntityName: "acme", EntityType: "Organization", RepoFullName: "acme/r", + InstallationID: 7, K8sPool: "scw-em-rv1", K8sImage: "img"}, []string{"x"}) + if err != nil || !got { + t.Fatalf("AddJob: got=%v err=%v", got, err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet: %v", err) + } +} + +func TestAddJob_DuplicateReturnsFalse(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`INSERT INTO jobs`).WithArgs(anyN(12)...).WillReturnResult(pgxmock.NewResult("INSERT", 0)) + got, err := db.AddJob(context.Background(), Job{JobID: 1, EntityType: "User"}, nil) + if err != nil || got { + t.Fatalf("AddJob duplicate: got=%v err=%v", got, err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet: %v", err) + } +} + +func TestAddJob_PropagatesError(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`INSERT INTO jobs`).WithArgs(anyN(12)...).WillReturnError(errors.New("dial")) + _, err := db.AddJob(context.Background(), Job{JobID: 1}, nil) + if err == nil { + t.Fatal("expected error") + } +} + +func TestMarkJobRunning_ReturnsPrevStatus(t *testing.T) { + db, mock := newMockDB(t) + prev := "pending" + mock.ExpectQuery(`WITH prev AS .*UPDATE jobs.*status = 'running'`). + WithArgs(int64(1), "runner-x"). + WillReturnRows(pgxmock.NewRows([]string{"prev_status"}).AddRow(&prev)) + prev, err := db.MarkJobRunning(context.Background(), 1, "runner-x") + if err != nil || prev != "pending" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobRunning_NoOpReadsCurrent(t *testing.T) { + db, mock := newMockDB(t) + cur := "completed" + mock.ExpectQuery(`UPDATE jobs.*status = 'running'`). + WithArgs(int64(2), "rn"). + WillReturnRows(pgxmock.NewRows([]string{"prev_status"}).AddRow((*string)(nil))) + mock.ExpectQuery(`SELECT status::text FROM jobs WHERE job_id`). + WithArgs(int64(2)). + WillReturnRows(pgxmock.NewRows([]string{"status"}).AddRow(&cur)) + prev, err := db.MarkJobRunning(context.Background(), 2, "rn") + if err != nil || prev != "completed" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobRunning_NotFoundReturnsEmpty(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`UPDATE jobs.*status = 'running'`). + WithArgs(int64(3), ""). + WillReturnRows(pgxmock.NewRows([]string{"prev_status"}).AddRow((*string)(nil))) + mock.ExpectQuery(`SELECT status::text FROM jobs WHERE job_id`). + WithArgs(int64(3)). + WillReturnError(pgx.ErrNoRows) + prev, err := db.MarkJobRunning(context.Background(), 3, "") + if err != nil || prev != "" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobCompleted_AcceptsRunningOrPending(t *testing.T) { + db, mock := newMockDB(t) + prev := "running" + mock.ExpectQuery(`UPDATE jobs.*status = 'completed'`). + WithArgs(int64(1), "rn"). + WillReturnRows(pgxmock.NewRows([]string{"prev_status"}).AddRow(&prev)) + prev, err := db.MarkJobCompleted(context.Background(), 1, "rn") + if err != nil || prev != "running" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobFailed_RequiresVersion(t *testing.T) { + db, _ := newMockDB(t) + _, err := db.MarkJobFailed(context.Background(), 1, FailureInfo{}) + if err == nil { + t.Fatal("expected error on zero version") + } +} + +func TestMarkJobFailed_Success(t *testing.T) { + db, mock := newMockDB(t) + prev := "running" + mock.ExpectQuery(`UPDATE jobs SET status = 'failed'`). + WithArgs(int64(1), pgxmock.AnyArg()). + WillReturnRows(pgxmock.NewRows([]string{"prev_status"}).AddRow(&prev)) + prev, err := db.MarkJobFailed(context.Background(), 1, FailureInfo{Version: 2, Reason: ReasonPodFailed}) + if err != nil || prev != "running" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobFailed_NoMatchFallbackReadsCurrent(t *testing.T) { + db, mock := newMockDB(t) + // UPDATE matches nothing → first QueryRow returns nil-prev (no rows). + mock.ExpectQuery(`UPDATE jobs SET status = 'failed'`). + WithArgs(int64(1), pgxmock.AnyArg()). + WillReturnError(pgx.ErrNoRows) + cur := "completed" + mock.ExpectQuery(`SELECT status::text FROM jobs WHERE job_id`). + WithArgs(int64(1)). + WillReturnRows(pgxmock.NewRows([]string{"status"}).AddRow(&cur)) + prev, err := db.MarkJobFailed(context.Background(), 1, FailureInfo{Version: 2, Reason: ReasonPodFailed}) + if err != nil || prev != "completed" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobFailed_NotFoundReturnsEmpty(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`UPDATE jobs SET status = 'failed'`). + WithArgs(int64(99), pgxmock.AnyArg()). + WillReturnError(pgx.ErrNoRows) + mock.ExpectQuery(`SELECT status::text FROM jobs WHERE job_id`). + WithArgs(int64(99)). + WillReturnError(pgx.ErrNoRows) + prev, err := db.MarkJobFailed(context.Background(), 99, FailureInfo{Version: 2, Reason: ReasonPodFailed}) + if err != nil || prev != "" { + t.Fatalf("prev=%q err=%v", prev, err) + } +} + +func TestMarkJobFailed_UpdateErrorPropagates(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`UPDATE jobs SET status = 'failed'`). + WithArgs(int64(1), pgxmock.AnyArg()). + WillReturnError(errors.New("boom")) + _, err := db.MarkJobFailed(context.Background(), 1, FailureInfo{Version: 2, Reason: ReasonPodFailed}) + if err == nil { + t.Fatal("expected error") + } +} + +func TestJobExistsForPod(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT 1 FROM jobs WHERE k8s_pod`). + WithArgs("pod-1"). + WillReturnRows(pgxmock.NewRows([]string{"x"}).AddRow(1)) + got, err := db.JobExistsForPod(context.Background(), "pod-1") + if err != nil || !got { + t.Fatalf("got=%v err=%v", got, err) + } + + mock.ExpectQuery(`SELECT 1 FROM jobs WHERE k8s_pod`). + WithArgs("nope"). + WillReturnError(pgx.ErrNoRows) + got, err = db.JobExistsForPod(context.Background(), "nope") + if err != nil || got { + t.Fatalf("got=%v err=%v", got, err) + } +} + +func TestGetActiveJobs(t *testing.T) { + db, mock := newMockDB(t) + rows := pgxmock.NewRows(jobColumns()).AddRow(jobScanRow()...) + mock.ExpectQuery(`SELECT .* FROM jobs.*status = 'pending' OR status = 'running'`).WillReturnRows(rows) + out, err := db.GetActiveJobs(context.Background()) + if err != nil || len(out) != 1 || out[0].JobID != 1 { + t.Fatalf("got %+v err=%v", out, err) + } +} + +func TestGetPendingJobs(t *testing.T) { + db, mock := newMockDB(t) + rows := pgxmock.NewRows(jobColumns()).AddRow(jobScanRow()...) + mock.ExpectQuery(`SELECT .* FROM jobs.*WHERE status = 'pending' ORDER BY created_at`).WillReturnRows(rows) + out, err := db.GetPendingJobs(context.Background()) + if err != nil || len(out) != 1 { + t.Fatalf("got %+v err=%v", out, err) + } +} + +func TestGetAllJobs_Paginated(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM jobs`). + WillReturnRows(pgxmock.NewRows([]string{"c"}).AddRow(1)) + rows := pgxmock.NewRows(jobColumns()).AddRow(jobScanRow()...) + mock.ExpectQuery(`SELECT .* FROM jobs.*LIMIT \$1 OFFSET \$2`). + WithArgs(10, 0).WillReturnRows(rows) + out, total, err := db.GetAllJobs(context.Background(), "", "", 0, 10) + if err != nil || total != 1 || len(out) != 1 { + t.Fatalf("got total=%d out=%+v err=%v", total, out, err) + } +} + +func TestGetAllJobs_WithDateFilter(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM jobs WHERE created_at >= \$1::timestamptz AND created_at < \$2::timestamptz`). + WithArgs("2026-01-01", "2026-02-01"). + WillReturnRows(pgxmock.NewRows([]string{"c"}).AddRow(0)) + rows := pgxmock.NewRows(jobColumns()) + mock.ExpectQuery(`SELECT .* FROM jobs WHERE .*LIMIT \$3 OFFSET \$4`). + WithArgs("2026-01-01", "2026-02-01", 50, 100). + WillReturnRows(rows) + _, total, err := db.GetAllJobs(context.Background(), "2026-01-01", "2026-02-01", 2, 50) + if err != nil || total != 0 { + t.Fatalf("total=%d err=%v", total, err) + } +} + +func TestGetPoolDemand(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT.*job_count.*worker_count`). + WithArgs(int64(7), `["x"]`). + WillReturnRows(pgxmock.NewRows([]string{"j", "w"}).AddRow(3, 1)) + j, w, err := db.GetPoolDemand(context.Background(), 7, []string{"x"}) + if err != nil || j != 3 || w != 1 { + t.Fatalf("j=%d w=%d err=%v", j, w, err) + } +} + +func TestGetTotalWorkersForEntity(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workers`). + WithArgs(int64(7)). + WillReturnRows(pgxmock.NewRows([]string{"c"}).AddRow(4)) + got, err := db.GetTotalWorkersForEntity(context.Background(), 7) + if err != nil || got != 4 { + t.Fatalf("got=%d err=%v", got, err) + } +} + +func TestAddWorker_DuplicateError(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`INSERT INTO workers`).WithArgs(anyN(10)...). + WillReturnResult(pgxmock.NewResult("INSERT", 0)) + err := db.AddWorker(context.Background(), Worker{PodName: "p"}, nil) + if !errors.Is(err, ErrDuplicatePodName) { + t.Fatalf("expected ErrDuplicatePodName, got %v", err) + } +} + +func TestAddWorker_Success(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`INSERT INTO workers`).WithArgs(anyN(10)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)) + if err := db.AddWorker(context.Background(), Worker{PodName: "p"}, nil); err != nil { + t.Fatalf("AddWorker: %v", err) + } +} + +func TestMarkWorkerRunning(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`UPDATE workers.*status = 'running'`).WithArgs(anyN(3)...). + WillReturnResult(pgxmock.NewResult("UPDATE", 1)) + now := time.Now() + if err := db.MarkWorkerRunning(context.Background(), "p", "node-1", &now); err != nil { + t.Fatalf("MarkWorkerRunning: %v", err) + } +} + +func TestMarkWorkerCompleted(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`UPDATE workers.*status = 'completed'`).WithArgs(anyN(3)...). + WillReturnResult(pgxmock.NewResult("UPDATE", 1)) + if err := db.MarkWorkerCompleted(context.Background(), "p", "node-1", nil); err != nil { + t.Fatalf("MarkWorkerCompleted: %v", err) + } +} + +func TestMarkWorkerFailed_RequiresVersion(t *testing.T) { + db, _ := newMockDB(t) + if err := db.MarkWorkerFailed(context.Background(), "p", "n", FailureInfo{}, nil); err == nil { + t.Fatal("expected error on zero version") + } +} + +func TestMarkWorkerFailed_Success(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`UPDATE workers.*status = 'failed'`).WithArgs(anyN(4)...). + WillReturnResult(pgxmock.NewResult("UPDATE", 1)) + if err := db.MarkWorkerFailed(context.Background(), "p", "n", + FailureInfo{Version: 2, Reason: ReasonPodFailed}, nil); err != nil { + t.Fatalf("MarkWorkerFailed: %v", err) + } +} + +func TestMarkWorkerOrphaned(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectExec(`UPDATE workers\s+SET status = 'completed'`).WithArgs("p"). + WillReturnResult(pgxmock.NewResult("UPDATE", 1)) + if err := db.MarkWorkerOrphaned(context.Background(), "p"); err != nil { + t.Fatalf("MarkWorkerOrphaned: %v", err) + } +} + +func TestGetActiveJobsAndWorkers(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT .* FROM jobs`). + WillReturnRows(pgxmock.NewRows(jobColumns()).AddRow(jobScanRow()...)) + mock.ExpectQuery(`SELECT .* FROM workers`). + WillReturnRows(pgxmock.NewRows(workerColumns()).AddRow(workerScanRow("p", "pending")...)) + jobs, workers, err := db.GetActiveJobsAndWorkers(context.Background()) + if err != nil || len(jobs) != 1 || len(workers) != 1 { + t.Fatalf("jobs=%d workers=%d err=%v", len(jobs), len(workers), err) + } +} + +func TestGetActiveWorkers(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT .* FROM workers`). + WillReturnRows(pgxmock.NewRows(workerColumns()).AddRow(workerScanRow("p", "running")...)) + out, err := db.GetActiveWorkers(context.Background()) + if err != nil || len(out) != 1 { + t.Fatalf("got %+v err=%v", out, err) + } +} + +func TestGetAllWorkers_Paginated(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workers`). + WillReturnRows(pgxmock.NewRows([]string{"c"}).AddRow(0)) + mock.ExpectQuery(`SELECT .* FROM workers`).WithArgs(10, 0). + WillReturnRows(pgxmock.NewRows(workerColumns())) + _, total, err := db.GetAllWorkers(context.Background(), "", "", 0, 10) + if err != nil || total != 0 { + t.Fatalf("total=%d err=%v", total, err) + } +} + +func TestGetWorkersForReconcile(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT .* FROM workers\s+WHERE status IN`). + WithArgs(3600). + WillReturnRows(pgxmock.NewRows(workerColumns()).AddRow(workerScanRow("p", "running")...)) + out, err := db.GetWorkersForReconcile(context.Background(), time.Hour) + if err != nil || len(out) != 1 { + t.Fatalf("got %+v err=%v", out, err) + } +} + +func TestAddInstallationEvent(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`INSERT INTO installation_events`).WithArgs(anyN(9)...). + WillReturnRows(pgxmock.NewRows([]string{"id"}).AddRow(int64(99))) + id, err := db.AddInstallationEvent(context.Background(), + InstallationEvent{Source: "webhook", Event: "ping", Outcome: "ok"}, + []byte(`{"zen":"hi"}`)) + if err != nil || id != 99 { + t.Fatalf("id=%d err=%v", id, err) + } +} + +func TestAddInstallationEvent_DefaultEmptyPayload(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`INSERT INTO installation_events`). + WithArgs("webhook", "x", "ok", (*int64)(nil), (*int64)(nil), + (*string)(nil), (*int64)(nil), (*string)(nil), "{}"). + WillReturnRows(pgxmock.NewRows([]string{"id"}).AddRow(int64(1))) + _, err := db.AddInstallationEvent(context.Background(), + InstallationEvent{Source: "webhook", Event: "x", Outcome: "ok"}, nil) + if err != nil { + t.Fatalf("AddInstallationEvent: %v", err) + } +} + +func TestGetEventsByEntityID(t *testing.T) { + db, mock := newMockDB(t) + now := time.Now() + jobIDStr, repo := "42", "acme/r" + mock.ExpectQuery(`SELECT.*FROM installation_events.*WHERE entity_id`). + WithArgs(int64(7)). + WillReturnRows(pgxmock.NewRows([]string{"id", "source", "event", "outcome", + "installation_id", "app_id", "entity_type", "entity_id", "entity_name", + "received_at", "job_id", "repo_full_name"}). + AddRow(int64(1), "webhook", "workflow_job.queued", "job_stored", + (*int64)(nil), (*int64)(nil), (*string)(nil), (*int64)(nil), (*string)(nil), + now, &jobIDStr, &repo)) + out, err := db.GetEventsByEntityID(context.Background(), 7) + if err != nil || len(out) != 1 || *out[0].JobID != "42" { + t.Fatalf("got %+v err=%v", out, err) + } +} + +func TestGetPayloadByID(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT payload FROM installation_events`). + WithArgs(int64(5)). + WillReturnRows(pgxmock.NewRows([]string{"payload"}).AddRow([]byte(`{"hi":1}`))) + body, err := db.GetPayloadByID(context.Background(), 5) + if err != nil || string(body) != `{"hi":1}` { + t.Fatalf("got %s err=%v", body, err) + } + + mock.ExpectQuery(`SELECT payload FROM installation_events`). + WithArgs(int64(0)).WillReturnError(pgx.ErrNoRows) + body, err = db.GetPayloadByID(context.Background(), 0) + if err != nil || body != nil { + t.Fatalf("got %s err=%v", body, err) + } +} + +func TestGetEntityIDForInstallation_FromEvents(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT entity_id FROM installation_events`). + WithArgs(int64(11)). + WillReturnRows(pgxmock.NewRows([]string{"entity_id"}).AddRow(int64(99))) + id, ok, err := db.GetEntityIDForInstallation(context.Background(), 11) + if err != nil || !ok || id != 99 { + t.Fatalf("got id=%d ok=%v err=%v", id, ok, err) + } +} + +func TestGetEntityIDForInstallation_FallsBackToJobs(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT entity_id FROM installation_events`). + WithArgs(int64(11)). + WillReturnError(pgx.ErrNoRows) + mock.ExpectQuery(`SELECT entity_id FROM jobs`). + WithArgs(int64(11)). + WillReturnRows(pgxmock.NewRows([]string{"entity_id"}).AddRow(int64(99))) + id, ok, err := db.GetEntityIDForInstallation(context.Background(), 11) + if err != nil || !ok || id != 99 { + t.Fatalf("got id=%d ok=%v err=%v", id, ok, err) + } +} + +func TestGetEntityIDForInstallation_NotFound(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT entity_id FROM installation_events`). + WithArgs(int64(0)). + WillReturnError(pgx.ErrNoRows) + mock.ExpectQuery(`SELECT entity_id FROM jobs`). + WithArgs(int64(0)). + WillReturnError(pgx.ErrNoRows) + _, ok, err := db.GetEntityIDForInstallation(context.Background(), 0) + if err != nil || ok { + t.Fatalf("expected not found, got ok=%v err=%v", ok, err) + } +} + +func TestGetEntityIDForJob(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectQuery(`SELECT entity_id FROM jobs WHERE job_id`). + WithArgs(int64(1)). + WillReturnRows(pgxmock.NewRows([]string{"entity_id"}).AddRow(int64(7))) + id, ok, err := db.GetEntityIDForJob(context.Background(), 1) + if err != nil || !ok || id != 7 { + t.Fatalf("got id=%d ok=%v err=%v", id, ok, err) + } + + mock.ExpectQuery(`SELECT entity_id FROM jobs WHERE job_id`). + WithArgs(int64(0)). + WillReturnError(pgx.ErrNoRows) + _, ok, err = db.GetEntityIDForJob(context.Background(), 0) + if err != nil || ok { + t.Fatalf("expected not found") + } +} + +func TestWithWorkerLock_CommitsOnSuccess(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectBegin() + mock.ExpectExec(`LOCK TABLE workers IN EXCLUSIVE MODE`). + WillReturnResult(pgxmock.NewResult("LOCK", 0)) + mock.ExpectCommit() + + called := false + err := db.WithWorkerLock(context.Background(), func(ctx context.Context) error { + called = true + if _, ok := ctx.Value(txCtxKey{}).(pgx.Tx); !ok { + t.Error("tx not attached to ctx") + } + return nil + }) + if err != nil || !called { + t.Fatalf("WithWorkerLock: called=%v err=%v", called, err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet: %v", err) + } +} + +func TestWithWorkerLock_RollsBackOnFnError(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectBegin() + mock.ExpectExec(`LOCK TABLE workers`).WillReturnResult(pgxmock.NewResult("LOCK", 0)) + mock.ExpectRollback() + err := db.WithWorkerLock(context.Background(), func(ctx context.Context) error { + return errors.New("fn boom") + }) + if err == nil || err.Error() != "fn boom" { + t.Fatalf("expected fn boom, got %v", err) + } +} + +func TestWithWorkerLock_LockFailureBubbles(t *testing.T) { + db, mock := newMockDB(t) + mock.ExpectBegin() + mock.ExpectExec(`LOCK TABLE workers`).WillReturnError(errors.New("locked")) + mock.ExpectRollback() + err := db.WithWorkerLock(context.Background(), func(ctx context.Context) error { + t.Fatal("fn should not run") + return nil + }) + if err == nil { + t.Fatal("expected lock error") + } +} + +// fakeListenConn lets WaitForJob run without a real Postgres LISTEN socket. +type fakeListenConn struct { + wait func(ctx context.Context) (*pgconn.Notification, error) + closeOk bool +} + +func (f *fakeListenConn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) { + return f.wait(ctx) +} +func (f *fakeListenConn) Close(ctx context.Context) error { + f.closeOk = true + return nil +} + +func TestWaitForJob_DeadlineIsSuccess(t *testing.T) { + db, _ := newMockDB(t) + db.listenConn = &fakeListenConn{wait: func(ctx context.Context) (*pgconn.Notification, error) { + <-ctx.Done() + return nil, context.DeadlineExceeded + }} + if err := db.WaitForJob(context.Background(), 5*time.Millisecond); err != nil { + t.Errorf("DeadlineExceeded should be nil, got %v", err) + } +} + +func TestWaitForJob_NotificationReturns(t *testing.T) { + db, _ := newMockDB(t) + calls := 0 + db.listenConn = &fakeListenConn{wait: func(ctx context.Context) (*pgconn.Notification, error) { + calls++ + if calls == 1 { + return &pgconn.Notification{Channel: "staging_queue_event", Payload: "1"}, nil + } + <-ctx.Done() + return nil, context.DeadlineExceeded + }} + if err := db.WaitForJob(context.Background(), 50*time.Millisecond); err != nil { + t.Errorf("expected nil, got %v", err) + } + if calls < 1 { + t.Errorf("expected at least one wait call, got %d", calls) + } +} + +func TestClose_TolerantOfMissingListenConn(t *testing.T) { + db, _ := newMockDB(t) + db.listenConn = &fakeListenConn{wait: func(ctx context.Context) (*pgconn.Notification, error) { return nil, nil }} + db.Close() // shouldn't panic +} + +func TestSortedJSON_StableOrder(t *testing.T) { + got := SortedJSON([]string{"b", "a", "c"}) + var arr []string + _ = json.Unmarshal([]byte(got), &arr) + if len(arr) != 3 || arr[0] != "a" || arr[1] != "b" || arr[2] != "c" { + t.Errorf("not sorted: %v", arr) + } +} diff --git a/container-go/internal/github.go b/container/internal/github.go similarity index 100% rename from container-go/internal/github.go rename to container/internal/github.go diff --git a/container/internal/github_test.go b/container/internal/github_test.go new file mode 100644 index 0000000..f3157c0 --- /dev/null +++ b/container/internal/github_test.go @@ -0,0 +1,477 @@ +package internal + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// testKeyPEM generates a fresh RSA private key as PEM. Used once per test +// run; the key is throwaway, doesn't touch any real GitHub. +func testKeyPEM(t *testing.T) string { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + der := x509.MarshalPKCS1PrivateKey(key) + return string(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})) +} + +// newTestClient wires a GHClient to point at a httptest.Server. Callers +// register handlers on the returned mux. +func newTestClient(t *testing.T) (*GHClient, *http.ServeMux, *httptest.Server) { + t.Helper() + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + cfg := Config{GHAppOrgKey: testKeyPEM(t), GHAppPersonalKey: testKeyPEM(t)} + gh, err := NewGHClient(cfg) + if err != nil { + t.Fatalf("NewGHClient: %v", err) + } + gh.BaseURL = srv.URL + gh.HTTP = srv.Client() + return gh, mux, srv +} + +func writeJSON(w http.ResponseWriter, status int, body any) { + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func TestNewGHClient_RejectsBadPEM(t *testing.T) { + if _, err := NewGHClient(Config{GHAppOrgKey: "not pem", GHAppPersonalKey: testKeyPEM(t)}); err == nil { + t.Error("expected error on bad org key") + } + if _, err := NewGHClient(Config{GHAppOrgKey: testKeyPEM(t), GHAppPersonalKey: "not pem"}); err == nil { + t.Error("expected error on bad personal key") + } +} + +func TestGenerateJWT_SignsAndIncludesClaims(t *testing.T) { + gh, _, _ := newTestClient(t) + fixed := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC) + gh.Now = func() time.Time { return fixed } + + tok, err := gh.GenerateJWT(gh.OrgAppID) + if err != nil { + t.Fatalf("GenerateJWT: %v", err) + } + parsed, err := jwt.Parse(tok, func(t *jwt.Token) (any, error) { + return &gh.OrgKey.PublicKey, nil + }) + if err != nil || !parsed.Valid { + t.Fatalf("parse JWT: %v", err) + } + claims := parsed.Claims.(jwt.MapClaims) + if int64(claims["iat"].(float64)) != fixed.Unix() { + t.Errorf("iat=%v want %v", claims["iat"], fixed.Unix()) + } + if int64(claims["exp"].(float64)) != fixed.Add(10*time.Minute).Unix() { + t.Errorf("exp wrong") + } + if int64(claims["iss"].(float64)) != gh.OrgAppID { + t.Errorf("iss=%v want %d", claims["iss"], gh.OrgAppID) + } +} + +func TestGenerateJWT_UnknownAppID(t *testing.T) { + gh, _, _ := newTestClient(t) + if _, err := gh.GenerateJWT(99999); err == nil { + t.Error("expected error on unknown app_id") + } +} + +func TestAuthenticateApp_SuccessAndCache(t *testing.T) { + gh, mux, _ := newTestClient(t) + var calls int + mux.HandleFunc("/app/installations/42/access_tokens", func(w http.ResponseWriter, r *http.Request) { + calls++ + if r.Method != "POST" { + t.Errorf("method=%s", r.Method) + } + if got := r.Header.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") { + t.Errorf("missing Bearer auth: %q", got) + } + writeJSON(w, 201, map[string]string{"token": "ghs_xyz"}) + }) + + tok, err := gh.AuthenticateApp(context.Background(), 42, gh.OrgAppID) + if err != nil { + t.Fatalf("first call: %v", err) + } + if tok != "ghs_xyz" { + t.Errorf("token=%q", tok) + } + // Second call within TTL is served from cache. + tok2, err := gh.AuthenticateApp(context.Background(), 42, gh.OrgAppID) + if err != nil || tok2 != tok || calls != 1 { + t.Errorf("expected cache hit; calls=%d tok2=%q", calls, tok2) + } +} + +func TestAuthenticateApp_ExpiryRefetches(t *testing.T) { + gh, mux, _ := newTestClient(t) + now := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC) + gh.Now = func() time.Time { return now } + gh.TokenTTL = 5 * time.Minute + + var calls int + mux.HandleFunc("/app/installations/1/access_tokens", func(w http.ResponseWriter, r *http.Request) { + calls++ + writeJSON(w, 201, map[string]string{"token": fmt.Sprintf("t%d", calls)}) + }) + + if _, err := gh.AuthenticateApp(context.Background(), 1, gh.OrgAppID); err != nil { + t.Fatal(err) + } + now = now.Add(10 * time.Minute) // TTL expired + t2, err := gh.AuthenticateApp(context.Background(), 1, gh.OrgAppID) + if err != nil { + t.Fatal(err) + } + if t2 != "t2" || calls != 2 { + t.Errorf("expected refetch; calls=%d tok=%s", calls, t2) + } +} + +func TestAuthenticateApp_404ReturnsAPIError(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/app/installations/99/access_tokens", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 404, map[string]string{"message": "Not Found"}) + }) + _, err := gh.AuthenticateApp(context.Background(), 99, gh.OrgAppID) + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 404 { + t.Fatalf("expected GitHubAPIError{404}, got %v", err) + } +} + +func TestAuthenticateApp_LRUEvictsOldest(t *testing.T) { + gh, mux, _ := newTestClient(t) + gh.TokenMax = 2 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 201, map[string]string{"token": "t"}) + }) + for i := int64(1); i <= 3; i++ { + if _, err := gh.AuthenticateApp(context.Background(), i, gh.OrgAppID); err != nil { + t.Fatalf("auth %d: %v", i, err) + } + } + if len(gh.tokens) != 2 { + t.Errorf("expected 2 tokens after LRU, got %d", len(gh.tokens)) + } +} + +func TestGetInstallation_OrgVsUserPicksApp(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/app/installations/10", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 200, map[string]any{ + "id": 10, + "account": map[string]any{"login": "acme", "type": "Organization", "id": 999}, + }) + }) + inst, err := gh.GetInstallation(context.Background(), 10, EntityOrganization) + if err != nil { + t.Fatalf("GetInstallation: %v", err) + } + if inst.Account.Login != "acme" || inst.Account.Type != "Organization" || inst.Account.ID != 999 { + t.Errorf("account decoded wrong: %+v", inst.Account) + } + if len(inst.Raw) == 0 { + t.Error("Raw not preserved") + } +} + +func TestGetInstallation_404(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/app/installations/0", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 404, map[string]string{"message": "nope"}) + }) + _, err := gh.GetInstallation(context.Background(), 0, EntityUser) + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 404 { + t.Errorf("expected 404 API err, got %v", err) + } +} + +func TestEnsureRunnerGroup_ReturnsExisting(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runner-groups", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("expected GET, got %s", r.Method) + } + writeJSON(w, 200, map[string]any{ + "runner_groups": []map[string]any{ + {"id": 7, "name": "other"}, + {"id": 13, "name": "RISE RISC-V Runners"}, + }, + }) + }) + id, err := gh.EnsureRunnerGroup(context.Background(), "tok", "acme", "RISE RISC-V Runners") + if err != nil || id != 13 { + t.Errorf("got id=%d err=%v want 13", id, err) + } +} + +func TestEnsureRunnerGroup_CreatesWhenMissing(t *testing.T) { + gh, mux, _ := newTestClient(t) + var calls []string + mux.HandleFunc("/orgs/acme/actions/runner-groups", func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, r.Method) + if r.Method == "GET" { + writeJSON(w, 200, map[string]any{"runner_groups": []map[string]any{{"id": 1, "name": "other"}}}) + return + } + // POST body has visibility:"all", allows_public_repositories:true + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["visibility"] != "all" || body["allows_public_repositories"] != true { + t.Errorf("create body wrong: %+v", body) + } + writeJSON(w, 201, map[string]any{"id": 42}) + }) + id, err := gh.EnsureRunnerGroup(context.Background(), "tok", "acme", "RISE RISC-V Runners") + if err != nil || id != 42 { + t.Errorf("got id=%d err=%v want 42", id, err) + } + if len(calls) != 2 || calls[0] != "GET" || calls[1] != "POST" { + t.Errorf("call order: %v", calls) + } +} + +func TestEnsureRunnerGroup_ListErrorPropagates(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runner-groups", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 500, map[string]string{"message": "boom"}) + }) + _, err := gh.EnsureRunnerGroup(context.Background(), "tok", "acme", "g") + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 500 { + t.Errorf("expected 500 API err, got %v", err) + } +} + +func TestEnsureRunnerGroup_CreateErrorPropagates(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runner-groups", func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + writeJSON(w, 200, map[string]any{"runner_groups": []any{}}) + return + } + writeJSON(w, 422, map[string]string{"message": "bad"}) + }) + _, err := gh.EnsureRunnerGroup(context.Background(), "tok", "acme", "g") + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 422 { + t.Errorf("expected 422 API err, got %v", err) + } +} + +func TestCreateJITRunnerConfigOrg_PostsBody(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runners/generate-jitconfig", func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "r-1" || body["runner_group_id"].(float64) != 13 || + body["work_folder"] != "../../../work" { + t.Errorf("body wrong: %+v", body) + } + writeJSON(w, 201, map[string]string{"encoded_jit_config": "ENC"}) + }) + got, err := gh.CreateJITRunnerConfigOrg(context.Background(), "tok", "acme", "r-1", 13, []string{"ubuntu-24.04-riscv"}) + if err != nil || got != "ENC" { + t.Errorf("got %q err=%v", got, err) + } +} + +func TestCreateJITRunnerConfigRepo_UsesGroupID1(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/repos/user/proj/actions/runners/generate-jitconfig", func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["runner_group_id"].(float64) != 1 { + t.Errorf("repo JIT should use group_id=1, got %v", body["runner_group_id"]) + } + writeJSON(w, 201, map[string]string{"encoded_jit_config": "ENC"}) + }) + got, err := gh.CreateJITRunnerConfigRepo(context.Background(), "tok", "user/proj", "r-1", []string{"x"}) + if err != nil || got != "ENC" { + t.Errorf("got %q err=%v", got, err) + } +} + +func TestCreateJITRunnerConfig_NonCreatedIsError(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runners/generate-jitconfig", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 422, map[string]string{"message": "bad"}) + }) + _, err := gh.CreateJITRunnerConfigOrg(context.Background(), "tok", "acme", "r", 1, nil) + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 422 { + t.Errorf("expected 422 API err, got %v", err) + } +} + +func TestListRunnersOrgGroup_FollowsPagination(t *testing.T) { + gh, mux, srv := newTestClient(t) + base := "/orgs/acme/actions/runner-groups/7/runners" + mux.HandleFunc(base, func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + if page == "" { + w.Header().Set("Link", fmt.Sprintf(`<%s%s?page=2&per_page=100>; rel="next"`, srv.URL, base)) + writeJSON(w, 200, map[string]any{"runners": []map[string]any{ + {"id": 1, "name": "a", "status": "online", "busy": false}, + }}) + return + } + writeJSON(w, 200, map[string]any{"runners": []map[string]any{ + {"id": 2, "name": "b", "status": "offline", "busy": false}, + }}) + }) + out, err := gh.ListRunnersOrgGroup(context.Background(), "tok", "acme", 7) + if err != nil { + t.Fatalf("list: %v", err) + } + if len(out) != 2 || out[0].Name != "a" || out[1].Name != "b" { + t.Errorf("pagination didn't concatenate: %+v", out) + } +} + +func TestListRunnersRepo_NoPaginationHeader(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/repos/user/proj/actions/runners", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 200, map[string]any{"runners": []map[string]any{ + {"id": 9, "name": "rise-riscv-runner-staging-abc", "status": "online", "busy": true}, + }}) + }) + out, err := gh.ListRunnersRepo(context.Background(), "tok", "user/proj") + if err != nil || len(out) != 1 || !out[0].Busy { + t.Errorf("got %+v err=%v", out, err) + } +} + +func TestListRunners_NonOKErrors(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runner-groups/1/runners", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 500, map[string]string{"message": "boom"}) + }) + _, err := gh.ListRunnersOrgGroup(context.Background(), "tok", "acme", 1) + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 500 { + t.Errorf("expected 500 API err, got %v", err) + } +} + +func TestDeleteRunner_204And404Succeed(t *testing.T) { + gh, mux, _ := newTestClient(t) + for _, status := range []int{204, 404} { + path := fmt.Sprintf("/orgs/acme/actions/runners/%d", status) + mux.HandleFunc(path, func(s int) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(s) } + }(status)) + if err := gh.DeleteRunnerOrg(context.Background(), "tok", "acme", int64(status)); err != nil { + t.Errorf("status %d should be success: %v", status, err) + } + } +} + +func TestDeleteRunnerOrg_5xxIsError(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/orgs/acme/actions/runners/1", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + }) + err := gh.DeleteRunnerOrg(context.Background(), "tok", "acme", 1) + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 500 { + t.Errorf("expected 500 API err, got %v", err) + } +} + +func TestDeleteRunnerRepo_204(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/repos/user/proj/actions/runners/5", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(204) + }) + if err := gh.DeleteRunnerRepo(context.Background(), "tok", "user/proj", 5); err != nil { + t.Errorf("expected success: %v", err) + } +} + +func TestGetJobInfo_Success(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/repos/user/proj/actions/jobs/77", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 200, map[string]any{"status": "in_progress", "conclusion": "cancelled", "runner_name": "rn"}) + }) + j, err := gh.GetJobInfo(context.Background(), "tok", "user/proj", 77) + if err != nil { + t.Fatalf("GetJobInfo: %v", err) + } + if j.Status != "in_progress" || j.Conclusion == nil || *j.Conclusion != "cancelled" || j.RunnerName != "rn" { + t.Errorf("decoded wrong: %+v", j) + } +} + +func TestGetJobInfo_404(t *testing.T) { + gh, mux, _ := newTestClient(t) + mux.HandleFunc("/repos/user/proj/actions/jobs/0", func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, 404, map[string]string{"message": "missing"}) + }) + _, err := gh.GetJobInfo(context.Background(), "tok", "user/proj", 0) + var apiErr *GitHubAPIError + if !errors.As(err, &apiErr) || apiErr.StatusCode != 404 { + t.Errorf("expected 404 API err, got %v", err) + } +} + +func TestNextLink(t *testing.T) { + cases := []struct{ header, want string }{ + {"", ""}, + {`; rel="next", ; rel="last"`, "https://api/x?page=2"}, + {`; rel="first"`, ""}, + } + for i, c := range cases { + if got := nextLink(c.header); got != c.want { + t.Errorf("case %d: nextLink(%q)=%q want %q", i, c.header, got, c.want) + } + } +} + +func TestApiMessage_PrefersStructured(t *testing.T) { + got := apiMessage([]byte(`{"message":"hi"}`)) + if got != "hi" { + t.Errorf("got %q want hi", got) + } + got = apiMessage([]byte("plain")) + if got != "plain" { + t.Errorf("got %q want fall-through", got) + } +} + +func TestDoJSON_TransportError(t *testing.T) { + gh := &GHClient{ + BaseURL: "http://127.0.0.1:1", // unreachable + HTTP: &http.Client{Timeout: 100 * time.Millisecond}, + } + _, _, err := gh.doJSON(context.Background(), "GET", "/x", nil, "tok") + if err == nil { + t.Error("expected transport error") + } +} + +// silence unused-import lint when imports rotate +var _ = io.Discard diff --git a/container-go/internal/k8s.go b/container/internal/k8s.go similarity index 98% rename from container-go/internal/k8s.go rename to container/internal/k8s.go index 2944447..0a12893 100644 --- a/container-go/internal/k8s.go +++ b/container/internal/k8s.go @@ -43,8 +43,8 @@ func NewK8sClientFromInterface(cs kubernetes.Interface) *K8sClient { // ProvisionRunner creates the runner pod. The exact shape (host-network, // privileged, two emptyDir volumes, single container, RUNNER_JITCONFIG env, -// ephemeral-storage limit on scw-em-* only) is load-bearing — multiple -// invariants in CONTRACT.md pin it down. Don't tweak without a test. +// ephemeral-storage limit on scw-em-* only) is load-bearing. Don't tweak +// without a test. func (k *K8sClient) ProvisionRunner(ctx context.Context, jitConfig, runnerName, image, pool string, entity Entity) error { limits := corev1.ResourceList{ "riseproject.com/runner": resource.MustParse("1"), diff --git a/container/internal/k8s_test.go b/container/internal/k8s_test.go new file mode 100644 index 0000000..d39e1c4 --- /dev/null +++ b/container/internal/k8s_test.go @@ -0,0 +1,547 @@ +package internal + +import ( + "context" + "strings" + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +// fakePod returns the pod the fake clientset stored under runnerName. +func fakePod(t *testing.T, k *K8sClient, runnerName string) *corev1.Pod { + t.Helper() + pod, err := k.cs.CoreV1().Pods("default").Get(context.Background(), runnerName, metav1.GetOptions{}) + if err != nil { + t.Fatalf("get pod: %v", err) + } + return pod +} + +// TestProvisionRunner_UsesHostNetwork asserts pod.spec.hostNetwork=true on every +// pool — invariant 9de4c35. +func TestProvisionRunner_UsesHostNetwork(t *testing.T) { + for _, pool := range []string{"scw-em-rv1", "cloudv10x-jupiter"} { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + if err := k.ProvisionRunner(context.Background(), "jit", "runner-"+pool, "img", pool, Entity{ID: 1, Name: "ent"}); err != nil { + t.Fatalf("provision: %v", err) + } + p := fakePod(t, k, "runner-"+pool) + if !p.Spec.HostNetwork { + t.Errorf("pool=%s expected HostNetwork=true", pool) + } + } +} + +// TestProvisionRunner_EmptyDirVolumes asserts the two emptyDir volumes for +// /var/lib/docker and /var/lib/k0s exist on every pool (invariants 0028278/653a5ba). +func TestProvisionRunner_EmptyDirVolumes(t *testing.T) { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 1, Name: "ent"}); err != nil { + t.Fatalf("provision: %v", err) + } + p := fakePod(t, k, "r") + type mount struct { + name string + path string + } + want := []mount{{"docker-graph", "/var/lib/docker"}, {"k0s", "/var/lib/k0s"}} + if len(p.Spec.Containers) != 1 { + t.Fatalf("expected single container, got %d", len(p.Spec.Containers)) + } + for _, m := range want { + var foundVolume, foundMount bool + for _, v := range p.Spec.Volumes { + if v.Name == m.name && v.EmptyDir != nil { + foundVolume = true + } + } + for _, vm := range p.Spec.Containers[0].VolumeMounts { + if vm.Name == m.name && vm.MountPath == m.path { + foundMount = true + } + } + if !foundVolume { + t.Errorf("volume %s emptyDir not found", m.name) + } + if !foundMount { + t.Errorf("volumeMount %s at %s not found", m.name, m.path) + } + } +} + +// TestProvisionRunner_DiskLimitsOnlyOnScwEM asserts ephemeral-storage=90Gi +// only on scw-em-* pools (invariant 3286cf6). +func TestProvisionRunner_DiskLimitsOnlyOnScwEM(t *testing.T) { + tests := []struct { + pool string + wantDisk bool + }{ + {"scw-em-rv1", true}, + {"scw-em-something", true}, + {"cloudv10x-jupiter", false}, + } + for _, tc := range tests { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + if err := k.ProvisionRunner(context.Background(), "jit", "r-"+tc.pool, "img", tc.pool, Entity{ID: 1, Name: "ent"}); err != nil { + t.Fatalf("[%s] provision: %v", tc.pool, err) + } + p := fakePod(t, k, "r-"+tc.pool) + limits := p.Spec.Containers[0].Resources.Limits + _, has := limits["ephemeral-storage"] + if has != tc.wantDisk { + t.Errorf("pool=%s ephemeral-storage present=%v want=%v", tc.pool, has, tc.wantDisk) + } + if has { + q := limits["ephemeral-storage"] + want := resource.MustParse("90Gi") + if q.Cmp(want) != 0 { + t.Errorf("pool=%s ephemeral-storage=%s want 90Gi", tc.pool, q.String()) + } + } + if _, has := limits["riseproject.com/runner"]; !has { + t.Errorf("pool=%s runner limit missing", tc.pool) + } + } +} + +// TestProvisionRunner_NoSidecar asserts pod has exactly one container, no +// docker-certs volume, no DOCKER_* env (invariant 5c5004f). +func TestProvisionRunner_NoSidecar(t *testing.T) { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 1, Name: "ent"}); err != nil { + t.Fatalf("provision: %v", err) + } + p := fakePod(t, k, "r") + if len(p.Spec.Containers) != 1 { + t.Fatalf("expected single container, got %d", len(p.Spec.Containers)) + } + for _, v := range p.Spec.Volumes { + if strings.Contains(v.Name, "docker-cert") { + t.Errorf("docker-certs volume %s leaked into spec", v.Name) + } + } + for _, e := range p.Spec.Containers[0].Env { + if strings.HasPrefix(e.Name, "DOCKER_") { + t.Errorf("DOCKER_* env leaked: %s", e.Name) + } + } + // Required env present + mustHaveEnv(t, p.Spec.Containers[0].Env, "RUNNER_WAIT_FOR_DOCKER_IN_SECONDS", "60") + mustHaveEnv(t, p.Spec.Containers[0].Env, "RUNNER_JITCONFIG", "jit") +} + +func mustHaveEnv(t *testing.T, env []corev1.EnvVar, name, value string) { + t.Helper() + for _, e := range env { + if e.Name == name { + if e.Value != value { + t.Errorf("env %s=%q want %q", name, e.Value, value) + } + return + } + } + t.Errorf("env %s missing", name) +} + +// TestProvisionRunner_Labels asserts the four pod labels are set. +func TestProvisionRunner_Labels(t *testing.T) { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 42, Name: "pytorch"}); err != nil { + t.Fatalf("provision: %v", err) + } + p := fakePod(t, k, "r") + want := map[string]string{ + "app": "rise-riscv-runner", + "riseproject.dev/entity_id": "42", + "riseproject.dev/entity_name": "pytorch", + "riseproject.dev/board": "scw-em-rv1", + } + for k, v := range want { + if p.Labels[k] != v { + t.Errorf("label %s=%q want %q", k, p.Labels[k], v) + } + } + if p.Spec.NodeSelector["riseproject.dev/board"] != "scw-em-rv1" { + t.Errorf("nodeSelector board mismatch: %v", p.Spec.NodeSelector) + } +} + +// TestProvisionRunner_TimeoutsAndPrivileged asserts the lesser invariants: +// activeDeadlineSeconds=525600, restartPolicy=Never, container privileged=true. +func TestProvisionRunner_TimeoutsAndPrivileged(t *testing.T) { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + if err := k.ProvisionRunner(context.Background(), "jit", "r", "img", "scw-em-rv1", Entity{ID: 1, Name: "ent"}); err != nil { + t.Fatalf("provision: %v", err) + } + p := fakePod(t, k, "r") + if p.Spec.ActiveDeadlineSeconds == nil || *p.Spec.ActiveDeadlineSeconds != 525600 { + t.Errorf("activeDeadlineSeconds=%v want 525600", p.Spec.ActiveDeadlineSeconds) + } + if p.Spec.RestartPolicy != corev1.RestartPolicyNever { + t.Errorf("restartPolicy=%v want Never", p.Spec.RestartPolicy) + } + sc := p.Spec.Containers[0].SecurityContext + if sc == nil || sc.Privileged == nil || !*sc.Privileged { + t.Errorf("container not privileged") + } +} + +// makeNode returns a fake node with allocatable runner capacity and the +// given board label. +func makeNode(name, board string, capacity int64) *corev1.Node { + return &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: map[string]string{"riseproject.dev/board": board}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{ + "riseproject.com/runner": *resource.NewQuantity(capacity, resource.DecimalSI), + }, + }, + } +} + +// makePod returns a fake runner pod in the given phase on the given board. +func makePod(name, board string, phase corev1.PodPhase) *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "default", + Labels: map[string]string{ + "app": "rise-riscv-runner", + "riseproject.dev/board": board, + }, + }, + Status: corev1.PodStatus{Phase: phase}, + } +} + +func TestAvailableSlots_TotalMinusActive(t *testing.T) { + cs := fake.NewSimpleClientset( + makeNode("n1", "scw-em-rv1", 3), + makeNode("n2", "scw-em-rv1", 2), + makeNode("other", "cloudv10x-jupiter", 5), // different board, ignored + makePod("p1", "scw-em-rv1", corev1.PodPending), + makePod("p2", "scw-em-rv1", corev1.PodRunning), + makePod("p3", "scw-em-rv1", corev1.PodSucceeded), // terminal, doesn't count + makePod("po", "cloudv10x-jupiter", corev1.PodRunning), + ) + k := NewK8sClientFromInterface(cs) + cap, err := k.AvailableSlots(context.Background(), "scw-em-rv1") + if err != nil { + t.Fatalf("AvailableSlots: %v", err) + } + if cap.Total != 5 || cap.Active != 2 || cap.Available != 3 { + t.Errorf("got %+v, want {Total:5 Active:2 Available:3}", cap) + } +} + +func TestAvailableSlots_NoMatchingNodes(t *testing.T) { + k := NewK8sClientFromInterface(fake.NewSimpleClientset()) + cap, err := k.AvailableSlots(context.Background(), "scw-em-rv1") + if err != nil { + t.Fatalf("AvailableSlots: %v", err) + } + if cap.Total != 0 || cap.Active != 0 || cap.Available != 0 { + t.Errorf("got %+v, want zero", cap) + } +} + +func TestListPods_FiltersByAppLabel(t *testing.T) { + cs := fake.NewSimpleClientset( + makePod("r1", "scw-em-rv1", corev1.PodRunning), + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "noise", Namespace: "default", + Labels: map[string]string{"app": "something-else"}, + }, + }, + ) + k := NewK8sClientFromInterface(cs) + pods, err := k.ListPods(context.Background()) + if err != nil { + t.Fatalf("ListPods: %v", err) + } + if len(pods) != 1 || pods[0].Name != "r1" { + t.Errorf("filter failed: %+v", pods) + } +} + +func TestDeletePod_404IsSilentSuccess(t *testing.T) { + cs := fake.NewSimpleClientset() + k := NewK8sClientFromInterface(cs) + if err := k.DeletePod(context.Background(), "nope"); err != nil { + t.Fatalf("expected nil on 404, got %v", err) + } +} + +func TestDeletePod_Deletes(t *testing.T) { + cs := fake.NewSimpleClientset(makePod("p", "scw-em-rv1", corev1.PodSucceeded)) + k := NewK8sClientFromInterface(cs) + if err := k.DeletePod(context.Background(), "p"); err != nil { + t.Fatalf("DeletePod: %v", err) + } + if _, err := cs.CoreV1().Pods("default").Get(context.Background(), "p", metav1.GetOptions{}); err == nil { + t.Fatalf("pod should be gone") + } +} + +func TestKillPod_PatchesActiveDeadlineSeconds(t *testing.T) { + cs := fake.NewSimpleClientset(makePod("p", "scw-em-rv1", corev1.PodRunning)) + k := NewK8sClientFromInterface(cs) + if err := k.KillPod(context.Background(), "p"); err != nil { + t.Fatalf("KillPod: %v", err) + } + got, err := cs.CoreV1().Pods("default").Get(context.Background(), "p", metav1.GetOptions{}) + if err != nil { + t.Fatal(err) + } + if got.Spec.ActiveDeadlineSeconds == nil || *got.Spec.ActiveDeadlineSeconds != 1 { + t.Errorf("activeDeadlineSeconds=%v want 1", got.Spec.ActiveDeadlineSeconds) + } +} + +func TestKillPod_404IsSilentSuccess(t *testing.T) { + cs := fake.NewSimpleClientset() + k := NewK8sClientFromInterface(cs) + if err := k.KillPod(context.Background(), "nope"); err != nil { + t.Fatalf("expected nil on 404, got %v", err) + } +} + +func TestGetPodEvents_SortedAscending(t *testing.T) { + t1 := time.Now().Add(-10 * time.Minute) + t2 := t1.Add(5 * time.Minute) + // fake clientset doesn't honour the FieldSelector — only seed events + // that involve the target pod so we exercise the sort path cleanly. + cs := fake.NewSimpleClientset( + &corev1.Event{ + ObjectMeta: metav1.ObjectMeta{Name: "e2", Namespace: "default"}, + InvolvedObject: corev1.ObjectReference{Name: "p"}, + Type: "Warning", + Reason: "Late", + LastTimestamp: metav1.Time{Time: t2}, + }, + &corev1.Event{ + ObjectMeta: metav1.ObjectMeta{Name: "e1", Namespace: "default"}, + InvolvedObject: corev1.ObjectReference{Name: "p"}, + Type: "Normal", + Reason: "Early", + LastTimestamp: metav1.Time{Time: t1}, + }, + ) + k := NewK8sClientFromInterface(cs) + evs, err := k.GetPodEvents(context.Background(), "p") + if err != nil { + t.Fatalf("GetPodEvents: %v", err) + } + if len(evs) != 2 { + t.Fatalf("len=%d want 2", len(evs)) + } + if evs[0].Reason != "Early" || evs[1].Reason != "Late" { + t.Errorf("not sorted ascending: %v", evs) + } +} + +func TestGetPodEvents_FallsBackToEventTime(t *testing.T) { + t1 := time.Now().Add(-1 * time.Minute) + cs := fake.NewSimpleClientset( + &corev1.Event{ + ObjectMeta: metav1.ObjectMeta{Name: "e", Namespace: "default"}, + InvolvedObject: corev1.ObjectReference{Name: "p"}, + EventTime: metav1.MicroTime{Time: t1}, + Type: "Normal", + Reason: "X", + }, + ) + k := NewK8sClientFromInterface(cs) + evs, err := k.GetPodEvents(context.Background(), "p") + if err != nil { + t.Fatal(err) + } + if len(evs) != 1 || evs[0].LastSeen == nil { + t.Fatalf("expected EventTime to become LastSeen: %+v", evs) + } +} + +func TestConvertPod_CapturesContainerStates(t *testing.T) { + now := metav1.Now() + finished := metav1.NewTime(now.Time.Add(1 * time.Minute)) + exit := int32(137) + p := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "p", Namespace: "default"}, + Spec: corev1.PodSpec{NodeName: "node-1"}, + Status: corev1.PodStatus{ + Phase: corev1.PodFailed, + Message: "container failed", + Reason: "OOMKilled", + ContainerStatuses: []corev1.ContainerStatus{{ + Name: "runner", + State: corev1.ContainerState{ + Terminated: &corev1.ContainerStateTerminated{ + ExitCode: exit, + Reason: "OOMKilled", + Message: "out of memory", + FinishedAt: finished, + }, + }, + }, { + Name: "init", + State: corev1.ContainerState{ + Waiting: &corev1.ContainerStateWaiting{ + Reason: "ImagePullBackOff", + Message: "backoff", + }, + }, + }, { + Name: "running", + State: corev1.ContainerState{ + Running: &corev1.ContainerStateRunning{StartedAt: now}, + }, + }}, + Conditions: []corev1.PodCondition{{ + Type: corev1.PodReady, Status: corev1.ConditionTrue, + LastTransitionTime: now, + }}, + }, + } + out := convertPod(p) + if out.Phase != "Failed" || out.NodeName != "node-1" || out.Message != "container failed" { + t.Errorf("pod fields wrong: %+v", out) + } + if out.ReadyTransition == nil { + t.Errorf("ReadyTransition not captured") + } + if len(out.Containers) != 3 { + t.Fatalf("expected 3 containers, got %d", len(out.Containers)) + } + for _, cs := range out.Containers { + switch cs.Name { + case "runner": + if !cs.Terminated || cs.ExitCode == nil || *cs.ExitCode != 137 || cs.Reason != "OOMKilled" { + t.Errorf("runner state wrong: %+v", cs) + } + case "init": + if !cs.Waiting || cs.WaitingReason != "ImagePullBackOff" { + t.Errorf("init state wrong: %+v", cs) + } + case "running": + if !cs.Running || cs.RunningStarted == nil { + t.Errorf("running state wrong: %+v", cs) + } + } + } +} + +func TestPod_FinishedAtReturnsLatest(t *testing.T) { + t1 := time.Now().Add(-5 * time.Minute) + t2 := t1.Add(2 * time.Minute) // later + p := Pod{Containers: []ContainerStatus{ + {Name: "a", Terminated: true, TerminatedAt: &t1}, + {Name: "b", Terminated: true, TerminatedAt: &t2}, + }} + got := p.FinishedAt() + if got == nil || !got.Equal(t2) { + t.Errorf("FinishedAt=%v want %v", got, t2) + } +} + +func TestPod_FinishedAtNilWhenNotTerminated(t *testing.T) { + p := Pod{Containers: []ContainerStatus{{Name: "a", Running: true}}} + if p.FinishedAt() != nil { + t.Errorf("expected nil") + } +} + +func TestPod_RunnerStartedAtFromRunner(t *testing.T) { + start := time.Now().Add(-1 * time.Minute) + p := Pod{Containers: []ContainerStatus{ + {Name: "init", Running: false}, + {Name: "runner", Running: true, RunningStarted: &start}, + }} + got := p.RunnerStartedAt() + if got == nil || !got.Equal(start) { + t.Errorf("RunnerStartedAt=%v want %v", got, start) + } +} + +func TestPod_RunnerStartedAtFallsBackToReady(t *testing.T) { + ready := time.Now() + p := Pod{ReadyTransition: &ready} + got := p.RunnerStartedAt() + if got == nil || !got.Equal(ready) { + t.Errorf("RunnerStartedAt=%v want %v", got, ready) + } +} + +func TestCollectPodFailureInfo_BuildsV2Shape(t *testing.T) { + exit := int32(1) + terminatedAt := time.Now() + pod := Pod{ + Name: "p", Message: "msg", Reason: "BadStuff", + Containers: []ContainerStatus{{ + Name: "runner", Terminated: true, TerminatedAt: &terminatedAt, + ExitCode: &exit, Reason: "Error", Message: "boom", + }, { + Name: "waiting", Waiting: true, + WaitingReason: "ImagePull", WaitingMessage: "pulling", + }}, + } + cs := fake.NewSimpleClientset( + &corev1.Event{ + ObjectMeta: metav1.ObjectMeta{Name: "e1", Namespace: "default"}, + InvolvedObject: corev1.ObjectReference{Name: "p"}, + Type: "Warning", Reason: "Failed", Message: "uh-oh", Count: 1, + }, + ) + k := NewK8sClientFromInterface(cs) + info := k.CollectPodFailureInfo(context.Background(), pod, ReasonPodFailed) + if info.Version != 2 || info.Reason != ReasonPodFailed { + t.Errorf("header fields wrong: %+v", info) + } + if info.PodMessage != "msg" || info.PodReason != "BadStuff" { + t.Errorf("pod fields lost: %+v", info) + } + runner, ok := info.Containers["runner"] + if !ok || runner.ExitCode == nil || *runner.ExitCode != 1 || runner.Reason != "Error" { + t.Errorf("runner container wrong: %+v", runner) + } + waiting := info.Containers["waiting"] + if waiting.Reason != "ImagePull" || waiting.Message != "pulling" { + t.Errorf("waiting container falls back to waiting fields: %+v", waiting) + } + if len(info.Events) != 1 || info.Events[0].Reason != "Failed" { + t.Errorf("events not collected: %+v", info.Events) + } +} + +func TestEventTime_PrefersLastOverFirst(t *testing.T) { + t1 := time.Now() + t2 := t1.Add(time.Minute) + ev := PodEvent{FirstSeen: &t1, LastSeen: &t2} + got := eventTime(ev) + if got == nil || !got.Equal(t2) { + t.Errorf("got %v want %v", got, t2) + } +} + +func TestEventTime_FirstWhenLastNil(t *testing.T) { + t1 := time.Now() + ev := PodEvent{FirstSeen: &t1} + got := eventTime(ev) + if got == nil || !got.Equal(t1) { + t.Errorf("got %v want %v", got, t1) + } +} + +func TestEventTime_NilWhenBothMissing(t *testing.T) { + if eventTime(PodEvent{}) != nil { + t.Error("expected nil") + } +} diff --git a/container-go/internal/log.go b/container/internal/log.go similarity index 100% rename from container-go/internal/log.go rename to container/internal/log.go diff --git a/container/internal/log_test.go b/container/internal/log_test.go new file mode 100644 index 0000000..1fd021e --- /dev/null +++ b/container/internal/log_test.go @@ -0,0 +1,44 @@ +package internal + +import ( + "bytes" + "log/slog" + "strings" + "testing" +) + +func TestInitSlog_Levels(t *testing.T) { + cases := []struct { + level string + emitDbg bool + emitWrn bool + }{ + {"DEBUG", true, true}, + {"debug", true, true}, + {"INFO", false, true}, + {"", false, true}, + {"warn", false, true}, + {"WARNING", false, true}, + {"ERROR", false, false}, + {"garbage", false, true}, // default → INFO + } + for _, c := range cases { + t.Run(c.level, func(t *testing.T) { + var buf bytes.Buffer + InitSlog(c.level, &buf) + slog.Debug("dbg-line") + slog.Warn("wrn-line") + slog.Error("err-line") + s := buf.String() + if strings.Contains(s, "dbg-line") != c.emitDbg { + t.Errorf("debug emit: got=%v want=%v out=%q", !c.emitDbg, c.emitDbg, s) + } + if strings.Contains(s, "wrn-line") != c.emitWrn { + t.Errorf("warn emit: got=%v want=%v out=%q", !c.emitWrn, c.emitWrn, s) + } + if !strings.Contains(s, "err-line") { + t.Errorf("error always emits: %q", s) + } + }) + } +} diff --git a/container-go/internal/testutil/fakes.go b/container/internal/testutil/fakes.go similarity index 91% rename from container-go/internal/testutil/fakes.go rename to container/internal/testutil/fakes.go index 37b684e..0542ec1 100644 --- a/container-go/internal/testutil/fakes.go +++ b/container/internal/testutil/fakes.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/riseproject-dev/riscv-runner-app/container-go/internal" + "github.com/riseproject-dev/riscv-runner-app/container/internal" ) // FakeDB satisfies internal.DB. Public fields are inspected by tests; methods @@ -30,13 +30,16 @@ type FakeDB struct { OnMarkJobComplete func(int64, string) (string, error) OnMarkJobFailed func(int64, internal.FailureInfo) (string, error) - OnGetActiveJobs func() ([]internal.Job, error) - OnGetPendingJobs func() ([]internal.Job, error) - OnGetWorkersReconcile func(time.Duration) ([]internal.Worker, error) - OnGetEntityIDInstall func(int64) (int64, bool, error) - OnGetEntityIDJob func(int64) (int64, bool, error) - OnGetEventsByEntityID func(int64) ([]internal.InstallationEvent, error) - OnGetPayloadByID func(int64) ([]byte, error) + OnGetActiveJobs func() ([]internal.Job, error) + OnGetPendingJobs func() ([]internal.Job, error) + OnGetWorkersReconcile func(time.Duration) ([]internal.Worker, error) + OnGetEntityIDInstall func(int64) (int64, bool, error) + OnGetEntityIDJob func(int64) (int64, bool, error) + OnGetEventsByEntityID func(int64) ([]internal.InstallationEvent, error) + OnGetPayloadByID func(int64) ([]byte, error) + OnGetActiveJobsAndWorkers func() ([]internal.Job, []internal.Worker, error) + OnGetAllJobs func(start, end string, page, perPage int) ([]internal.Job, int, error) + OnGetAllWorkers func(start, end string, page, perPage int) ([]internal.Worker, int, error) WorkerStatus map[string]string // last-known status; tests poke this MarkRunning []string @@ -148,6 +151,9 @@ func (f *FakeDB) GetPendingJobs(ctx context.Context) ([]internal.Job, error) { } func (f *FakeDB) GetAllJobs(ctx context.Context, start, end string, page, perPage int) ([]internal.Job, int, error) { + if f.OnGetAllJobs != nil { + return f.OnGetAllJobs(start, end, page, perPage) + } return f.Jobs, len(f.Jobs), nil } @@ -212,6 +218,9 @@ func (f *FakeDB) MarkWorkerOrphaned(ctx context.Context, pod string) error { } func (f *FakeDB) GetActiveJobsAndWorkers(ctx context.Context) ([]internal.Job, []internal.Worker, error) { + if f.OnGetActiveJobsAndWorkers != nil { + return f.OnGetActiveJobsAndWorkers() + } return f.Jobs, f.Workers, nil } @@ -220,6 +229,9 @@ func (f *FakeDB) GetActiveWorkers(ctx context.Context) ([]internal.Worker, error } func (f *FakeDB) GetAllWorkers(ctx context.Context, start, end string, page, perPage int) ([]internal.Worker, int, error) { + if f.OnGetAllWorkers != nil { + return f.OnGetAllWorkers(start, end, page, perPage) + } return f.Workers, len(f.Workers), nil } @@ -386,6 +398,7 @@ type FakeKube struct { KillCalls []string OnProvisionRunner func(jit, name, image, pool string, entity internal.Entity) error + OnGetPodEvents func(podName string) ([]internal.PodEvent, error) } // NewFakeKube allocates the maps so callers can mutate them directly. @@ -420,6 +433,9 @@ func (f *FakeKube) ListPods(ctx context.Context) ([]internal.Pod, error) { } func (f *FakeKube) GetPodEvents(ctx context.Context, podName string) ([]internal.PodEvent, error) { + if f.OnGetPodEvents != nil { + return f.OnGetPodEvents(podName) + } return f.EventsByPod[podName], nil } diff --git a/container/k8s.py b/container/k8s.py deleted file mode 100644 index 947198f..0000000 --- a/container/k8s.py +++ /dev/null @@ -1,282 +0,0 @@ -from __future__ import annotations - -import datetime -import functools -import logging -from enum import Enum -import kubernetes as k8s -import yaml - -from constants import * - -logger = logging.getLogger(__name__) - - -class FailureReason(str, Enum): - POD_ALLOCATION_FAILURE = "pod_allocation_failure" - POD_FAILED = "pod_failed" - POD_STUCK_PENDING = "pod_stuck_pending" - RUNNER_NEVER_REGISTERED = "runner_never_registered" - RUNNER_IDLE = "runner_idle" - - -@functools.lru_cache(maxsize=1) -def _init_client(): - """Create a Kubernetes API client from a kubeconfig env var.""" - return k8s.config.new_client_from_config_dict(yaml.safe_load(K8S_KUBECONFIG)) - - -def provision_runner(jit_config, runner_name, k8s_image, k8s_pool, entity_id, entity_name): - """Provision a new runner in a Kubernetes pod. - - k8s_pool is the board name (e.g. "scw-em-rv1"). The nodeSelector is - reconstructed internally from it. - """ - node_selector = {"riseproject.dev/board": k8s_pool} - - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - - resources = { - "limits": { - "riseproject.com/runner": "1", - } - } - - if k8s_pool.startswith("scw-em-"): - resources["limits"]["ephemeral-storage"] = "90Gi" - - pod_manifest = { - "apiVersion": "v1", - "kind": "Pod", - "metadata": { - "name": runner_name, - "labels": { - "app": "rise-riscv-runner", - "riseproject.dev/entity_id": str(entity_id), - "riseproject.dev/entity_name": str(entity_name), - "riseproject.dev/board": k8s_pool, - }, - }, - "spec": { - "nodeSelector": node_selector, - # 24h queue limit + 5d execution limit + 2h buffer = 525600s - "activeDeadlineSeconds": 525600, - "restartPolicy": "Never", - "hostNetwork": True, - "containers": [ - { - "name": "runner", - "image": k8s_image, - "imagePullPolicy": "IfNotPresent", - # privileged is required so the in-container dockerd can set up iptables rules and the docker0 bridge. - "securityContext": {"privileged": True}, - "env": [ - {"name": "RUNNER_WAIT_FOR_DOCKER_IN_SECONDS", "value": "60"}, - {"name": "RUNNER_JITCONFIG", "value": jit_config}, - ], - "resources": resources, - "volumeMounts": [ - { - "name": "docker-graph", - "mountPath": "/var/lib/docker", - }, - { - "name": "k0s", - "mountPath": "/var/lib/k0s", - }, - ], - }, - ], - "volumes": [ - { - "name": "docker-graph", - "emptyDir": {}, - }, - { - "name": "k0s", - "emptyDir": {}, - }, - ], - } - } - - api.create_namespaced_pod(body=pod_manifest, namespace="default") - - -def delete_pod(pod): - """Delete a runner pod.""" - assert pod, "Pod must be provided to delete it" - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - try: - api.delete_namespaced_pod(name=pod.metadata.name, namespace="default") - logger.info("Deleted runner pod %s", pod.metadata.name) - return f"Pod {pod.metadata.name} deleted successfully." - except k8s.client.exceptions.ApiException as e: - if e.status == 404: - logger.debug("Pod %s not found, already deleted", pod.metadata.name) - return f"Pod {pod.metadata.name} not found." - raise - - -def kill_pod(pod): - """Force a pod to transition to Failed phase (reason=DeadlineExceeded). - - Patches spec.activeDeadlineSeconds to 1. The kubelet compares this against - now() - pod.status.startTime and, when exceeded, marks phase=Failed and - SIGTERM/SIGKILLs the containers. Unlike delete_pod(), the pod stays in the - cluster so logs/events remain inspectable until the grace window removes it. - """ - assert pod, "Pod must be provided to kill it" - body = {"spec": {"activeDeadlineSeconds": 1}} - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - try: - api.patch_namespaced_pod(name=pod.metadata.name, namespace="default", body=body) - logger.info("Killed runner pod %s (activeDeadlineSeconds=1)", pod.metadata.name) - except k8s.client.exceptions.ApiException as e: - if e.status == 404: - logger.debug("Pod %s not found, already gone", pod.metadata.name) - return - raise - - -def get_available_slots(label_selector): - """Check if there's an available runner slot on nodes matching the selector.""" - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - - nodes = api.list_node(label_selector=label_selector) - total = sum(int(node.status.allocatable.get("riseproject.com/runner", "0")) for node in nodes.items) - - pods = api.list_namespaced_pod(label_selector=f"app=rise-riscv-runner,{label_selector}", namespace="default") - active = sum(1 for p in pods.items if p.status.phase in ("Pending", "Running")) - - available = total - active - logger.debug("Capacity check: label_selector=%s, total=%d, active=%d, available=%d", - label_selector, total, active, available) - return available - - -def get_pod_events(pod_name): - """Get events for a specific pod, sorted by last timestamp.""" - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - events = api.list_namespaced_event(field_selector=f"involvedObject.name={pod_name}", namespace="default") - sorted_events = sorted( - events.items, - key=lambda e: e.last_timestamp or e.event_time or e.metadata.creation_timestamp, - ) - return sorted_events - - -def list_pods(): - """Get all runner pods.""" - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - pods = api.list_namespaced_pod(label_selector="app=rise-riscv-runner", namespace="default") - return pods.items - - -def get_pod_logs(pod_name: str, container: str) -> str | None: - """Get full logs for a container in a pod. Returns log string or None on failure.""" - try: - with _init_client() as client: - api = k8s.client.CoreV1Api(client) - return api.read_namespaced_pod_log( - name=pod_name, - namespace="default", - container=container, - ) - except Exception as e: - logger.debug("Failed to get logs for %s/%s: %s", pod_name, container, e) - return None - - -def get_runner_running_at(pod) -> datetime.datetime | None: - """When the 'runner' container actually began running. Best-effort.""" - for cs in (pod.status.container_statuses or []): - if cs.name == "runner" and cs.state and cs.state.running: - return cs.state.running.started_at - for cond in (pod.status.conditions or []): - if cond.type == "Ready" and cond.status == "True": - return cond.last_transition_time - return None - - -def get_pod_finished_at(pod) -> datetime.datetime | None: - """Latest container termination time for Succeeded/Failed pods.""" - finishes = [] - for cs in (pod.status.container_statuses or []) + (pod.status.init_container_statuses or []): - if cs.state and cs.state.terminated and cs.state.terminated.finished_at: - finishes.append(cs.state.terminated.finished_at) - return max(finishes) if finishes else None - - -def collect_pod_failure_info(pod, reason: FailureReason) -> dict: - """Collect exhaustive diagnostic info from a pod for the workers.failure_info column. - - Gathers container termination/running info, full container logs, and pod events. - Safe to call on Running or Pending pods too (logs are read live). - Callers must pass a FailureReason to describe why the worker is being failed. - """ - assert isinstance(reason, FailureReason), "reason must be a FailureReason enum value" - pod_name = pod.metadata.name - info = { - "version": 2, # bump when the structure changes - "reason": reason.value, - "containers": {}, - "events": [], - "pod_message": pod.status.message, - "pod_reason": pod.status.reason, - } - - # Container termination info + logs (main containers) - for cs in (pod.status.container_statuses or []): - container_info = _extract_container_info(cs) - container_info["logs"] = get_pod_logs(pod_name, cs.name) - info["containers"][cs.name] = container_info - - # Init container termination info + logs (none today, but defensive for future use) - for cs in (pod.status.init_container_statuses or []): - container_info = _extract_container_info(cs) - container_info["logs"] = get_pod_logs(pod_name, cs.name) - info["containers"][cs.name] = container_info - - # Pod events - try: - events = get_pod_events(pod_name) - for ev in events: - ts = ev.last_timestamp or ev.event_time or ev.metadata.creation_timestamp - info["events"].append({ - "type": ev.type, - "reason": ev.reason, - "message": ev.message, - "count": ev.count, - "first_seen": str(ev.first_timestamp) if ev.first_timestamp else None, - "last_seen": str(ts) if ts else None, - }) - except Exception as e: - logger.debug("Failed to get events for %s: %s", pod_name, e) - - return info - - -def _extract_container_info(container_status) -> dict: - """Extract termination info from a V1ContainerStatus.""" - result = { - "exit_code": None, - "reason": None, - "message": None, - } - if container_status.state and container_status.state.terminated: - t = container_status.state.terminated - result["exit_code"] = t.exit_code - result["reason"] = t.reason - result["message"] = t.message - elif container_status.state and container_status.state.waiting: - w = container_status.state.waiting - result["reason"] = w.reason - result["message"] = w.message - return result diff --git a/package-lock.json b/container/package-lock.json similarity index 100% rename from package-lock.json rename to container/package-lock.json diff --git a/package.json b/container/package.json similarity index 100% rename from package.json rename to container/package.json diff --git a/container/requirements.txt b/container/requirements.txt deleted file mode 100644 index 8b822eb..0000000 --- a/container/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -cachetools -Flask==3.1.2 -jwt -kubernetes -psycopg2-binary -requests -waitress diff --git a/container/scheduler.py b/container/scheduler.py deleted file mode 100644 index a310173..0000000 --- a/container/scheduler.py +++ /dev/null @@ -1,903 +0,0 @@ -import datetime -import itertools -import json -import logging -import random -import string -import threading -import traceback -from collections.abc import Callable, Iterable -from typing import Any - -import db -from db import DuplicateRunnerNameException -import k8s -from k8s import FailureReason -import github as gh -from constants import * - -from flask import Flask, request, make_response -from flask.json import dumps as json_dumps - -# Used for /health for now -app = Flask(__name__) - -logger = logging.getLogger(__name__) - -POLL_INTERVAL = 15 - - -def _gh_authenticate_app(installation_id, entity_type, *, job_id=None, - repo_full_name=None, - entity_id=None, entity_name=None): - """Wrap gh.authenticate_app and log every failure to installation_events. - - Successful auths are not logged (gh.authenticate_app's @ttl_cache makes - success a hot path). Only GitHubAPIError failures are recorded — the - cache itself doesn't store exceptions, so transient errors won't poison - subsequent calls. - """ - app_id = GHAPP_PERSONAL_ID if entity_type == EntityType.USER else GHAPP_ORG_ID - try: - return gh.authenticate_app(int(installation_id), app_id) - except gh.GitHubAPIError as e: - outcome = WebhookOutcome.AUTH_404 if e.status_code == 404 else WebhookOutcome.AUTH_OTHER_ERROR - event_str = (f"auth_attempt.{e.status_code}" - if e.status_code == 404 - else "auth_attempt.other_error") - synthetic_payload = { - "installation_id": int(installation_id), - "app_id": app_id, - "entity_type": entity_type.value, - "entity_id": entity_id, - "entity_name": entity_name, - "repository": ({"full_name": repo_full_name} - if repo_full_name else None), - "workflow_job": {"id": job_id} if job_id else None, - "http_status": e.status_code, - "error_message": str(e), - } - try: - db.add_installation_event( - source="scheduler", - event=event_str, - outcome=outcome, - installation_id=int(installation_id), - app_id=app_id, - entity_type=entity_type.value, - entity_id=entity_id, - entity_name=entity_name, - payload=synthetic_payload, - ) - except Exception: - logger.exception("Failed to record auth_attempt event_str=%s installation_id=%s", - event_str, installation_id) - raise - - -def sync_jobs_state(): - """ - Sync job status between GitHub and the jobs table. - - For each active job, check GitHub for its actual status. If GitHub says - completed but database disagrees, mark it completed. If GitHub says in_progress - but database says pending, update to running. - - `gh.authenticate_app` is cached, so we iterate jobs flatly without batching. - """ - jobs = db.get_active_jobs() - if not jobs: - return - - for job in jobs: - assert job['status'] in ['pending', 'running'], f"Job job_id={job['job_id']} is not running or pending, status={job['status']}" - job_id = job["job_id"] - installation_id = job["installation_id"] - entity_type = EntityType(job["entity_type"]) - repo = job.get("repo_full_name") - if not repo: - continue - - try: - token = _gh_authenticate_app( - int(installation_id), entity_type=entity_type, - job_id=job_id, - repo_full_name=repo, - entity_id=job["entity_id"], - entity_name=job["entity_name"], - ) - except gh.GitHubAPIError as e: - if e.status_code == 404: - logger.warning("Installation not found installation_id=%s entity_type=%s, marking job %s failed", - installation_id, entity_type, job_id) - db.mark_job_failed(job_id, { - "version": 1, - "message": f"installation not found for installation_id={installation_id} entity_type={entity_type}", - }) - continue - logger.error("Failed to authenticate for installation installation_id=%s entity_type=%s: %s", - installation_id, entity_type, e) - continue - - try: - gh_job = gh.get_job_info(repo, job_id, token) - except gh.GitHubAPIError as e: - if e.status_code == 404: - logger.warning("Job not found job_id=%s entity=%s entity_id=%s entity_type=%s: marking as failed", - job_id, entity_type, job['entity_id'], job['entity_name']) - db.mark_job_failed(job_id, { - "version": 1, - "message": f"job not found for job_id={job_id} entity={job['entity_name']} entity_id={job['entity_id']} entity_type={entity_type}", - }) - continue - logger.error("Failed to get status for job job_id=%s entity=%s entity_id=%s entity_type=%s: %s", - job_id, entity_type, job['entity_id'], job['entity_name'], e) - continue - - gh_job_status = gh_job.get("status") # queued, in_progress, completed - gh_job_conclusion = gh_job.get("conclusion") # null, success, failure, cancelled, ... - # A non-null conclusion means the job is done, even if status says in_progress - if gh_job_conclusion is not None: - gh_job_status = "completed" - - if gh_job_status == "completed": - logger.info("GH reconcile: job %s is completed on GitHub (was %s in DB)", job_id, job["status"]) - db.mark_job_completed(job_id, gh_job.get("runner_name")) - elif gh_job_status == "in_progress" and job["status"] == "pending": - logger.info("GH reconcile: job %s is in_progress on GitHub (was pending in DB)", job_id) - db.mark_job_running(job_id, gh_job.get("runner_name")) - - -def _gh_runner_key_for_worker(worker): - """Return (installation_id, entity_type, entity_id, gh_runner_target). - - For organizations: gh_runner_target = entity_name. - For users: gh_runner_target = repo_full_name. - """ - entity_type = EntityType(worker["entity_type"]) - target = worker["entity_name"] if entity_type == EntityType.ORGANIZATION else worker["repo_full_name"] - return (int(worker["installation_id"]), entity_type, int(worker["entity_id"]), target) - - -def _get_gh_runners(gh_runner_key, token, gh_runners_by_target): - if gh_runner_key not in gh_runners_by_target: - _, entity_type, _, target = gh_runner_key - try: - if entity_type == EntityType.ORGANIZATION: - group_id = gh.ensure_runner_group(target, token, RUNNER_GROUP_NAME) - raw = gh.list_runners_org_group(token, target, group_id) - else: - raw = [r for r in gh.list_runners_repo(token, target) - if r.get("name", "").startswith(RUNNER_NAME_PREFIX)] - except gh.GitHubAPIError as e: - logger.error("Failed to list GH runners for %s/%s: %s", entity_type, target, e) - gh_runners_by_target[gh_runner_key] = {} - return gh_runners_by_target[gh_runner_key] - gh_runners_by_target[gh_runner_key] = {r["name"]: r for r in raw} - return gh_runners_by_target[gh_runner_key] - - -def _delete_gh_runner(worker_name, token, entity_type, gh_runner_target, runner_id): - """Delete a GH runner by id. Logs on failure, swallows exceptions.""" - try: - if entity_type == EntityType.ORGANIZATION: - gh.delete_runner_org(token, gh_runner_target, runner_id) - else: - gh.delete_runner_repo(token, gh_runner_target, runner_id) - logger.info("Deleted GH runner name=%s id=%s from entity=%s", worker_name, runner_id, gh_runner_target) - return True - except Exception as e: - logger.error("Failed to delete GH runner name=%s id=%s from entity=%s: %s", worker_name, runner_id, gh_runner_target, e) - return False - - -def _fail_and_cleanup(worker, pod, token, entity_type, gh_runner_target, gh_runner: dict | None, reason: FailureReason): - """Mark a worker failed, kill its pod, and remove any stale GH registration. - - If GitHub has a runner for this worker (gh_runner is not None), try to delete - it first. A non-2xx from GitHub (e.g. 422 "runner is busy") is our signal - that GH thinks a job is actually executing — abort cleanup so we don't kill - a worker that is doing useful work we missed signal for. Otherwise proceed: - collect diagnostics, mark the worker failed, and kill the pod so its slot - frees up. Phase 5's grace window later removes the Failed pod. - """ - logger.warning("Health check failed for pod=%s reason=%s", worker["pod_name"], reason.value) - if gh_runner: - if not _delete_gh_runner(worker["pod_name"], token, entity_type, gh_runner_target, gh_runner["id"]): - logger.warning("Aborting cleanup for worker=%s: GitHub refused to delete the runner (may be running a job)", - worker["pod_name"]) - return - try: - failure_info = k8s.collect_pod_failure_info(pod, reason=reason) - except Exception as e: - logger.error("collect_pod_failure_info failed for %s: %s", worker["pod_name"], e) - failure_info = {"version": 2, "reason": reason.value, "collect_error": str(e)} - db.mark_worker_failed(worker["pod_name"], - pod.spec.node_name or worker["k8s_node"], - failure_info, - datetime.datetime.now(datetime.timezone.utc)) - try: - k8s.kill_pod(pod) - except Exception as e: - logger.error("kill_pod failed for %s: %s", pod.metadata.name, e) - - -def _age_seconds(ts): - """Seconds elapsed since ts (a datetime, possibly naive). Returns +inf if ts is None.""" - if ts is None: - return float("inf") - if ts.tzinfo is None: - ts = ts.replace(tzinfo=datetime.timezone.utc) - return (datetime.datetime.now(datetime.timezone.utc) - ts).total_seconds() - - -def _group_by(elements: Iterable, key: Callable[[Any], Any]): - # itertools.groupby requires sorted elements by key before grouping by key - return itertools.groupby(sorted(elements, key=key), key=key) - - -def _sync_workers_state_phase_1_orphan_sweep(pods_by_name, workers_by_name): - """Phase 1 of sync_workers_state: mark workers without any corresponding pods as orphaned.""" - for pod_name, w in workers_by_name.items(): - if pod_name not in pods_by_name and w["status"] in ["pending", "running"]: - # There are no pods for that worker - db.mark_worker_orphaned(pod_name) - - -def _sync_workers_state_phase_2_pod_phase_sync(pods_by_name, workers_by_name): - """Phase 2 of sync_workers_state: synchronize pod phases with worker status in the database.""" - for pod_name, pod in pods_by_name.items(): - if pod.status.phase == "Running" and pod_name in workers_by_name and workers_by_name[pod_name]["status"] in ["pending"]: - db.mark_worker_running(pod_name, pod.spec.node_name, k8s.get_runner_running_at(pod)) - elif pod.status.phase == "Succeeded" and pod_name in workers_by_name and workers_by_name[pod_name]["status"] in ["pending", "running"]: - db.mark_worker_completed(pod_name, pod.spec.node_name, k8s.get_pod_finished_at(pod)) - elif pod.status.phase == "Failed" and pod_name in workers_by_name and workers_by_name[pod_name]["status"] in ["pending", "running"]: - try: - failure_info = k8s.collect_pod_failure_info(pod, reason=FailureReason.POD_FAILED) - except Exception as e: - logger.error("Failed to collect failure info for pod %s: %s", pod.metadata.name, e) - failure_info = { - "version": 2, - "reason": FailureReason.POD_FAILED.value, - "collect_error": str(e), - } - assert failure_info and "version" in failure_info and isinstance(failure_info["version"], int), \ - f"Failed pod {pod_name} requires failure_info with int 'version' field" - db.mark_worker_failed(pod_name, pod.spec.node_name, failure_info, k8s.get_pod_finished_at(pod)) - - -def _sync_workers_state_phase_3_health_checks(pods_by_name, workers_by_name, gh_runners_by_target): - """Phase 3 of sync_workers_state: GitHub runner health checks. - - For pending/running workers grouped by GitHub runner scope: if a Running pod older - than RUNNER_REGISTRATION_TIMEOUT_SECONDS is still missing from GH, or a Pending pod - is older than POD_PENDING_TIMEOUT_SECONDS, kill the pod and mark the worker failed. - Populates ``gh_runners_by_target`` as a side effect for Phase 4 to consume. - """ - # Sort before groupby so workers with the same scope are grouped together - # (groupby only groups adjacent equal keys). - workers_by_gh_runner_key = _group_by((w for w in workers_by_name.values() if w["status"] in ["pending", "running"]), key=_gh_runner_key_for_worker) - for gh_runner_key, workers in workers_by_gh_runner_key: - installation_id, entity_type, entity_id, gh_runner_target = gh_runner_key - try: - token = _gh_authenticate_app( - installation_id, entity_type=entity_type, - entity_id=entity_id, - entity_name=(gh_runner_target if entity_type == EntityType.ORGANIZATION else None), - repo_full_name=(gh_runner_target if entity_type == EntityType.USER else None), - ) - except gh.GitHubAPIError as e: - logger.error("Failed to authenticate for installation_id=%s entity_type=%s gh_runner_target=%s: %s", installation_id, entity_type, gh_runner_target, e) - continue - - # Consume the groupby elements into a list that we can iterate multiple times - workers = list(workers) - - gh_runners = _get_gh_runners(gh_runner_key, token, gh_runners_by_target) - - logger.debug(f"Checking for workers={RUNNER_NAME_PREFIX}%s in runners={RUNNER_NAME_PREFIX}%s for target=%s entity_type=%s", - sorted([w["pod_name"].removeprefix(RUNNER_NAME_PREFIX) for w in workers]), - sorted([r.removeprefix(RUNNER_NAME_PREFIX) for r in gh_runners.keys()]), - gh_runner_target, - entity_type) - - for w in workers: - worker_name = w["pod_name"] - assert worker_name in pods_by_name - pod = pods_by_name[worker_name] - gh_runner = gh_runners.get(worker_name) - - worker_status = w["status"] - runner_status, runner_busy = (gh_runner["status"], gh_runner["busy"]) if gh_runner else (None, None) - - # If the worker is still pending - if (worker_status) == ("pending"): - if _age_seconds(pod.metadata.creation_timestamp) < POD_PENDING_TIMEOUT_SECONDS: - logger.debug("Worker worker=%s worker_status=%s runner_status=%s is still pending", worker_name, worker_status, runner_status) - continue - logger.warning("Worker worker=%s worker_status=%s runner_status=%s is still pending after more than %d seconds, marking as failed", worker_name, worker_status, runner_status, POD_PENDING_TIMEOUT_SECONDS) - _fail_and_cleanup(w, pod, token, entity_type, gh_runner_target, gh_runner, - reason=FailureReason.POD_STUCK_PENDING) - continue - - # If the worker is running but the runner is unknown, it may be that the runner has already self-unregistered after executing a job - elif (worker_status, runner_status) == ("running", None): - if db.job_exists_for_pod(worker_name): - logger.debug("Worker worker=%s worker_status=%s runner_status=%s runner has already run a job and self-unregistered, skipping", worker_name, worker_status, runner_status) - continue - if _age_seconds(w["running_at"]) < RUNNER_REGISTRATION_TIMEOUT_SECONDS: - logger.info("Worker worker=%s worker_status=%s runner_status=%s is not known github runner and may still register", worker_name, worker_status, runner_status) - continue - logger.warning("Worker worker=%s worker_status=%s runner_status=%s is not known github runner and failed to register in %d seconds, marking as failed", worker_name, worker_status, runner_status, RUNNER_REGISTRATION_TIMEOUT_SECONDS) - _fail_and_cleanup(w, pod, token, entity_type, gh_runner_target, gh_runner, - reason=FailureReason.RUNNER_NEVER_REGISTERED) - continue - - # If the worker is running but the runner isn't running yet - elif (worker_status, runner_status) == ("running", "offline"): - if _age_seconds(w["running_at"]) < RUNNER_REGISTRATION_TIMEOUT_SECONDS: - logger.info("Worker worker=%s worker_status=%s runner_status=%s is known github runner and may still register", worker_name, worker_status, runner_status) - continue - logger.warning("Worker worker=%s worker_status=%s runner_status=%s is known github runner and failed to register in %d seconds, marking as failed", worker_name, worker_status, runner_status, RUNNER_REGISTRATION_TIMEOUT_SECONDS) - _fail_and_cleanup(w, pod, token, entity_type, gh_runner_target, gh_runner, - reason=FailureReason.RUNNER_NEVER_REGISTERED) - - # If the worker and runner are running but the runner hasn't picked up a job yet - elif (worker_status, runner_status, runner_busy) == ("running", "online", False): - if _age_seconds(w["running_at"]) < RUNNER_PENDING_TIMEOUT_SECONDS: - logger.info("Worker worker=%s worker_status=%s runner_status=%s is known github runner and may still pick up a job", worker_name, worker_status, runner_status) - continue - logger.warning("Worker worker=%s worker_status=%s runner_status=%s is known github runner and failed to pick up a job in %d seconds, marking as failed", worker_name, worker_status, runner_status, RUNNER_PENDING_TIMEOUT_SECONDS) - _fail_and_cleanup(w, pod, token, entity_type, gh_runner_target, gh_runner, - reason=FailureReason.RUNNER_IDLE) - continue - - # If the worker and runner are running and the runner has picked up a job - elif (worker_status, runner_status, runner_busy) == ("running", "online", True): - # Nothing to do, everything is working! - pass - - # If the worker is running, but the status of the runner is unknown - elif (worker_status) == ("running"): - assert runner_status not in [None, "offline", "online"] - logger.info("Worker worker=%s worker_status=%s runner_status=%s has unkown github status", worker_name, worker_status, runner_status) - if _age_seconds(w["running_at"]) < RUNNER_REGISTRATION_TIMEOUT_SECONDS: - logger.info("Worker worker=%s worker_status=%s runner_status=%s is known github runner and may still register", worker_name, worker_status, runner_status, ) - continue - logger.warning("Worker worker=%s worker_status=%s runner_status=%s is known github runner and in unknown state for after %d seconds, marking as failed", worker_name, worker_status, runner_status, RUNNER_REGISTRATION_TIMEOUT_SECONDS) - _fail_and_cleanup(w, pod, token, entity_type, gh_runner_target, gh_runner, - reason=FailureReason.RUNNER_NEVER_REGISTERED) - continue - - else: # pragma: no cover - assert False, f"unexpected worker status (worker_status={worker_status!r}, runner_status={runner_status!r}, runner_busy={runner_busy!r}) for worker={worker_name}" - - -def _sync_workers_state_phase_4_gh_cleanup(workers_by_name, gh_runners_by_target): - """Phase 4 of sync_workers_state: delete orphan/completed/failed runners on GitHub. - - Any runner matching RUNNER_NAME_PREFIX whose worker row is terminal or missing - gets deleted on GitHub. Reads the cache populated by Phase 3. - """ - for gh_runner_key, gh_runners in gh_runners_by_target.items(): - installation_id, entity_type, entity_id, gh_runner_target = gh_runner_key - try: - token = _gh_authenticate_app( - installation_id, entity_type=entity_type, - entity_id=entity_id, - entity_name=(gh_runner_target if entity_type == EntityType.ORGANIZATION else None), - repo_full_name=(gh_runner_target if entity_type == EntityType.USER else None), - ) - except gh.GitHubAPIError as e: - logger.error("Failed to authenticate for installation_id=%s entity_type=%s gh_runner_target=%s: %s", installation_id, entity_type, gh_runner_target, e) - continue - - for name, gh_runner in gh_runners.items(): - if not name.startswith(RUNNER_NAME_PREFIX): - continue - if name in workers_by_name and workers_by_name[name]["status"] in ("completed", "failed"): - logging.info("Runner runner=%s has matching completed worker=%s", name, name) - _delete_gh_runner(name, token, entity_type, gh_runner_target, gh_runner["id"]) - elif name not in workers_by_name: - logging.info("Runner runner=%s is unknown", name) - _delete_gh_runner(name, token, entity_type, gh_runner_target, gh_runner["id"]) - - -def _sync_workers_state_phase_5_delete_terminal_pods(pods_by_name): - """Phase 5 of sync_workers_state: delete completed|failed pods after the grace period. - - Pods in Succeeded/Failed phase are kept for POD_DELETE_GRACE_SECONDS so operators - can still inspect them via ``kubectl logs``; once the grace window elapses they - are deleted from the cluster. - """ - now = datetime.datetime.now(datetime.timezone.utc) - for pod_name, pod in pods_by_name.items(): - if pod.status.phase not in ("Succeeded", "Failed"): - continue - finished = k8s.get_pod_finished_at(pod) or pod.metadata.creation_timestamp - if finished and finished.tzinfo is None: - finished = finished.replace(tzinfo=datetime.timezone.utc) - elapsed = (now - finished).total_seconds() if finished else float("inf") - if elapsed < POD_DELETE_GRACE_SECONDS: - continue - try: - k8s.delete_pod(pod) - except Exception as e: - logger.error("Failed to delete pod %s: %s", pod.metadata.name, e) - - -def sync_workers_state(): - """Reconcile worker state across Kubernetes, GitHub, and the workers table.""" - pods_by_name = {p.metadata.name: p for p in k8s.list_pods()} - - # GitHub runners registered to a given (installation_id, entity_type, entity_id, target). - # Phase 3 populates this lazily; Phase 4 reads from it. - gh_runners_by_target: dict[tuple, dict] = {} - - workers_by_name = {w["pod_name"]: w for w in db.get_workers_for_reconcile()} - # 1. Orphan sweep — workers in `pending`/`running` with no matching k8s pod - # are marked completed. - _sync_workers_state_phase_1_orphan_sweep(pods_by_name, workers_by_name) - # 2. Pod phase sync — k8s Running/Succeeded/Failed phases propagate to the - # workers table (setting running_at / completed_at / failure_info). - _sync_workers_state_phase_2_pod_phase_sync(pods_by_name, workers_by_name) - - workers_by_name = {w["pod_name"]: w for w in db.get_workers_for_reconcile()} # refresh workers - # 3. Health checks — for pending/running workers grouped by GitHub runner - # scope: if a Running pod older than RUNNER_REGISTRATION_TIMEOUT_SECONDS - # is still missing from GH, or a Pending pod is older than - # POD_PENDING_TIMEOUT_SECONDS, kill the pod (activeDeadlineSeconds=1) - # so it transitions to Failed; the worker is marked failed with - # diagnostics. If GitHub refuses to delete the runner (e.g. 422 "busy"), - # abort cleanup for that worker — it may genuinely be running a job. - _sync_workers_state_phase_3_health_checks(pods_by_name, workers_by_name, gh_runners_by_target) - - workers_by_name = {w["pod_name"]: w for w in db.get_workers_for_reconcile()} # refresh workers - # 4. GitHub-side cleanup — any runner matching RUNNER_NAME_PREFIX whose - # worker row is terminal or missing gets deleted on GitHub. - _sync_workers_state_phase_4_gh_cleanup(workers_by_name, gh_runners_by_target) - # 5. Delete k8s pods in Succeeded/Failed phase after POD_DELETE_GRACE_SECONDS - # have elapsed since container termination, so operators can still - # `kubectl logs` them during the grace window. - _sync_workers_state_phase_5_delete_terminal_pods(pods_by_name) - - - -def demand_match(): - """ - Match demand (pending jobs) with supply (k8s workers). - - Iterates pending jobs in FIFO order. For each job, checks: - 1. Pool demand vs supply — skip if demand already met - 2. Org max_workers cap — skip if org is at capacity - 3. K8s node capacity — skip if no available slot - Then provisions a runner. - """ - pending_jobs = db.get_pending_jobs() - if not pending_jobs: - logger.debug("No pending jobs to process") - return - - logger.info("Processing %d pending jobs: [%s]", len(pending_jobs), ', '.join([str(j["job_id"]) for j in pending_jobs])) - - jobs_by_pool = _group_by(pending_jobs, key=lambda j: j["k8s_pool"]) - - for k8s_pool, jobs in jobs_by_pool: - available_slots = k8s.get_available_slots(label_selector=f"riseproject.dev/board={k8s_pool}") - logger.info("Capacity for k8s_pool=%s available_slots=%s", k8s_pool, available_slots) - if available_slots <= 0: - continue - - for job in jobs: - assert available_slots >= 1 - - job_id = job["job_id"] - if job.get("status") != "pending": - logger.info("Job %s status is %s, not pending, skipping", job_id, job.get("status")) - continue - - k8s_pool = job["k8s_pool"] - k8s_image = job["k8s_image"] - installation_id = job["installation_id"] - entity_name = job["entity_name"] - labels = job["job_labels"] - entity_type = EntityType(job["entity_type"]) - entity_id = job["entity_id"] - repo_full_name = job["repo_full_name"] - provider = job["provider"] - - # Check demand vs supply, matched by entity_id + job_labels - job_count, worker_count = db.get_pool_demand(entity_id, labels) - if job_count <= worker_count: - logger.info("Demand met for entity=%s entity_id=%s entity_type=%s labels=%s, jobs_count=%d workers_count=%d", - entity_name, entity_id, entity_type, labels, job_count, worker_count) - continue - - # Check max_workers cap - entity_config = ENTITY_CONFIG.get(int(entity_id), {"max_workers": 20}) - max_workers = entity_config.get("max_workers") - if max_workers is not None: - entity_worker_count = db.get_total_workers_for_entity(entity_id) - if entity_worker_count >= max_workers: - logger.info("Max workers allocated for entity=%s entity_id=%s entity_type=%s labels=%s workers_count=%d max_workers=%d)", - entity_name, entity_id, entity_type, labels, entity_worker_count, max_workers) - continue - - # Reserve name in DB first — detects collision before creating k8s pod - runner_name = None - for _ in range(5): # max retries for name collision - suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=9)) - candidate = f"{RUNNER_NAME_PREFIX}{suffix}" - try: - db.add_worker(provider, entity_id, entity_name, entity_type.value, - installation_id, - repo_full_name if entity_type == EntityType.USER else None, - k8s_pool, candidate, - job_labels=labels, k8s_image=k8s_image) - runner_name = candidate - break - except DuplicateRunnerNameException: - logger.warning("Runner name %s collision, regenerating", candidate) - continue - - if runner_name is None: - logger.error("Failed to generate unique runner name for entity=%s entity_id=%s entity_type=%s pool=%s after retries", entity_name, entity_id, entity_type, k8s_pool) - continue - - # Name reserved in DB, now safe to provision - try: - token = _gh_authenticate_app( - int(installation_id), entity_type=entity_type, - entity_id=entity_id, - entity_name=entity_name, - repo_full_name=repo_full_name, - ) - - if entity_type == EntityType.ORGANIZATION: - group_id = gh.ensure_runner_group(entity_name, token, RUNNER_GROUP_NAME) - jit_config = gh.create_jit_runner_config_org(token, group_id, labels, entity_name, runner_name) - else: - jit_config = gh.create_jit_runner_config_repo(token, labels, repo_full_name, runner_name) - - k8s.provision_runner(jit_config, runner_name, k8s_image, k8s_pool, entity_id, entity_name) - - logger.info("Provisioned runner %s for entity=%s entity_id=%s entity_type=%s pool=%s", runner_name, entity_name, entity_id, entity_type, k8s_pool) - - except Exception as e: - logger.error("Failed to provision runner %s for entity=%s entity_id=%s entity_type=%s pool=%s, error: %s", runner_name, entity_name, entity_id, entity_type, k8s_pool, str(e)) - failure_info = { - "version": 2, # bump when the structure changes - "reason": FailureReason.POD_ALLOCATION_FAILURE.value, - "containers": {}, - "events": [], - "pod_message": None, - "pod_reason": None, - } - db.mark_worker_failed(runner_name, k8s_node=None, failure_info=failure_info, completed_at=None) - - available_slots -= 1 - if available_slots == 0: - logger.debug("Capacity for k8s_pool=%s is now 0", k8s_pool) - break - - -# --- HTTP Handlers --- - -@app.route("/health", methods=['GET']) -def health(): - return "ok" - - -_STATUS_COLORS = {"pending": "#ccc504", "running": "#2563eb", "completed": "#16a34a", "failed": "#d90606"} - -def _format_status(status): - color = _STATUS_COLORS.get(status, "#666") - return f'[{status:9s}]' - -def _format_labels(job_labels): - """Format job_labels for display. Handles both list and JSON string.""" - if isinstance(job_labels, str): - labels = json.loads(job_labels) - else: - labels = job_labels or [] - return ('[' + ", ".join(labels) + ']') if labels else "" - - -def _format_timestamp(created_at): - """Format a created_at value (datetime or unix float string) for display.""" - if not created_at: - return "?" - if isinstance(created_at, datetime.datetime): - return created_at.strftime("%Y-%m-%d %H:%M:%S UTC") - return datetime.datetime.fromtimestamp(float(created_at), tz=datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC") - - -def render_job(job) -> str: - status = _format_status(job["status"]) - job_id = job["job_id"] - repo = job["repo_full_name"] - html_url = job["html_url"] - labels = _format_labels(job["job_labels"]) - pod = job["k8s_pod"] or "" - created_str = _format_timestamp(job.get("created_at")) - link = f'{repo}#{job_id}' if html_url else f"{repo}#{job_id}" - return f'{status} {created_str} {labels} {pod} {link}' - - -def render_worker(worker) -> list[str]: - status = _format_status(worker["status"]) - created_str = _format_timestamp(worker['created_at']) - labels = _format_labels(worker['job_labels']) - pod = worker['pod_name'] - node = worker['k8s_node'] or '' - lines = [f'{status} {created_str} {labels} {pod} (node: {node})'] - if worker["status"] == "failed" and worker["failure_info"]: - failure_info = worker["failure_info"] - version = failure_info.get("version", 1) - if version == 1: - pass - else: - if failure_info.get("reason"): - lines.append(f" Reason: {failure_info['reason']}") - pod_reason = failure_info.get("pod_reason") - pod_message = failure_info.get("pod_message") - if pod_reason or pod_message: - lines.append(f" Pod: {pod_reason or '?'} {pod_message or ''}".rstrip()) - for name, container in (failure_info.get("containers") or {}).items(): - exit_code = container.get("exit_code") - c_reason = container.get("reason") or "?" - c_message = container.get("message") or "" - lines.append(f" Container {name}: exit={exit_code} {c_reason} {c_message}".rstrip()) - logs = container.get("logs") - if logs: - for log_line in logs.splitlines(): - lines.append(f" | {log_line}") - for ev in failure_info.get("events") or []: - ts = ev.get("last_seen") or ev.get("first_seen") or "unknown" - lines.append(f" {ts} [{ev['type']}] {ev['reason']}: {ev['message']}") - else: - try: - events = k8s.get_pod_events(worker["pod_name"]) - if events: - for ev in events: - ts = ev.last_timestamp or ev.event_time or ev.metadata.creation_timestamp - ts_str = ts.strftime("%Y-%m-%d %H:%M:%S") if ts else "unknown" - lines.append(f" {ts_str} [{ev.type}] {ev.reason}: {ev.message}") - else: - lines.append(f" Events: (none)") - except Exception: - lines.append(f" Events: (error fetching)") - - return lines - -def _wants_json(): - return request.path.endswith('.json') or request.accept_mimetypes.best == 'application/json' - - -def _json_response(data): - return make_response(json_dumps(data, default=str), 200, {"Content-Type": "application/json"}) - - -@app.route("/usage", methods=['GET']) -@app.route("/usage.json", methods=['GET']) -def usage(): - active_jobs, active_workers = db.get_active_jobs_and_workers() - - if _wants_json(): - return _json_response({"jobs": active_jobs, "workers": active_workers}) - - # HTML: group by (entity_name, job_labels) for display - groups = {} - for job in active_jobs: - labels_key = json.dumps(job["job_labels"]) - key = (job["entity_id"], labels_key) - if key not in groups: - groups[key] = {"entity_name": job["entity_name"], "k8s_pool": job["k8s_pool"], "jobs": [], "workers": []} - groups[key]["jobs"].append(job) - - for worker in active_workers: - labels_key = json.dumps(worker["job_labels"]) - key = (worker["entity_id"], labels_key) - if key not in groups: - groups[key] = {"entity_name": worker["entity_name"], "k8s_pool": worker["k8s_pool"], "jobs": [], "workers": []} - groups[key]["workers"].append(worker) - - lines = [] - for (_, labels_key), group in sorted(groups.items()): - labels_display = _format_labels(labels_key) - lines.append(f"=== {group['entity_name']} / {labels_display} ({group['k8s_pool']}) ===") - if group["jobs"]: - lines.append(f" Jobs ({len(group['jobs'])}):") - for job in sorted(group["jobs"], key=lambda j: j["created_at"]): - lines.append(f' - {render_job(job)}') - else: - lines.append(" Jobs: none") - if group["workers"]: - lines.append(f" Workers ({len(group['workers'])}):") - for worker in sorted(group["workers"], key=lambda w: w["created_at"]): - lines.append(f' - {'\n '.join(render_worker(worker))}') - else: - lines.append(" Workers: none") - lines.append("") - if not lines: - lines.append("No active pools.") - return make_response(f"{'Usage - Prod' if PROD else 'Usage - Staging'}
{chr(10).join(lines)}
", 200, {"Content-Type": "text/html"}) - - -def _parse_date_param(value: str | None) -> str | None: - """Parse a date parameter. Supports ISO dates (YYYY-MM-DD) and relative (-Xd for X days ago).""" - if not value: - return None - import re - m = re.match(r'^-(\d+)d$', value) - if m: - days_ago = int(m.group(1)) - return (datetime.date.today() - datetime.timedelta(days=days_ago)).isoformat() - return value - - -def _build_link_header(base_url: str, page: int, per_page: int, total: int, - extra_params: dict[str, str] | None = None) -> str: - """Build a Link header for pagination, matching GitHub API format. - - See: https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api - """ - last_page = max(0, (total - 1) // per_page) - - def _url(p: int) -> str: - params = f"page={p}&per_page={per_page}" - if extra_params: - for k, v in extra_params.items(): - params += f"&{k}={v}" - return f"{base_url}?{params}" - - links = [] - if page > 0: - links.append(f'<{_url(0)}>; rel="first"') - links.append(f'<{_url(page - 1)}>; rel="prev"') - if page < last_page: - links.append(f'<{_url(page + 1)}>; rel="next"') - links.append(f'<{_url(last_page)}>; rel="last"') - return ", ".join(links) - - -@app.route("/history", methods=['GET']) -@app.route("/history.json", methods=['GET']) -@app.route("/jobs", methods=['GET']) -@app.route("/jobs.json", methods=['GET']) -def jobs(): - start = _parse_date_param(request.args.get("start")) - end = _parse_date_param(request.args.get("end")) - page = request.args.get("page", 0, type=int) - per_page = request.args.get("per_page", 100, type=int) - - if start is not None: - try: - datetime.date.fromisoformat(start) - except: - return make_response('invalid parameter start, must be YYYY-MM-DD', 400) - if end is not None: - try: - datetime.date.fromisoformat(end) - except: - return make_response('invalid parameter end, must be YYYY-MM-DD', 400) - if page < 0: - return make_response('invalid parameter page, must be >= 0', 400) - if per_page <= 0: - return make_response('invalid parameter per_page, must be > 0', 400) - - jobs, total = db.get_all_jobs(start=start, end=end, page=page, per_page=per_page) - - if _wants_json(): - resp = _json_response(jobs) - extra = {} - if start: - extra["start"] = start - if end: - extra["end"] = end - link = _build_link_header(request.base_url.split('?')[0], page, per_page, total, extra) - if link: - resp.headers["link"] = link - return resp - - # HTML - lines = [] - for job in jobs: - lines.append(render_job(job)) - if not lines: - lines = ["No jobs found."] - - return make_response(f"{'History - Prod' if PROD else 'History - Staging'}
{chr(10).join(lines)}
", 200, {"Content-Type": "text/html"}) - - -@app.route("/workers", methods=['GET']) -@app.route("/workers.json", methods=['GET']) -def workers(): - start = _parse_date_param(request.args.get("start")) - end = _parse_date_param(request.args.get("end")) - page = request.args.get("page", 0, type=int) - per_page = request.args.get("per_page", 100, type=int) - - if start is not None: - try: - datetime.date.fromisoformat(start) - except: - return make_response('invalid parameter start, must be YYYY-MM-DD', 400) - if end is not None: - try: - datetime.date.fromisoformat(end) - except: - return make_response('invalid parameter end, must be YYYY-MM-DD', 400) - if page < 0: - return make_response('invalid parameter page, must be >= 0', 400) - if per_page <= 0: - return make_response('invalid parameter per_page, must be > 0', 400) - - workers, total = db.get_all_workers(start=start, end=end, page=page, per_page=per_page) - - if _wants_json(): - resp = _json_response(workers) - extra = {} - if start: - extra["start"] = start - if end: - extra["end"] = end - link = _build_link_header(request.base_url.split('?')[0], page, per_page, total, extra) - if link: - resp.headers["link"] = link - return resp - - # HTML - lines = [] - for worker in workers: - lines.extend(render_worker(worker)) - if not lines: - lines = ["No workers found."] - - return make_response(f"{'Workers - Prod' if PROD else 'Workers - Staging'}
{chr(10).join(lines)}
", 200, {"Content-Type": "text/html"}) - - -def _scheduler_iteration(): - # Serialize demand matching across scheduler containers: hold one DB connection - # for the full schduler and lock the workers table exclusively. Thread-local - # caching in db.py ensures all nested db calls reuse this connection - # so they respect the lock without self-deadlocking on the pool. - with db.hold_connection() as conn: - with conn.cursor() as cur: - cur.execute("LOCK TABLE workers IN EXCLUSIVE MODE") - - sync_jobs_state() - sync_workers_state() - demand_match() - - -if __name__ == "__main__": - # Set the logging level for all loggers to INFO - logging.basicConfig( - level=logging.getLevelNamesMapping()[os.environ.get("LOGLEVEL", "INFO")], - format='%(pathname)s:%(lineno)d::%(funcName)s: [%(levelname)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - - # Ensure PostgreSQL schema/tables exist - db.ensure_schema() - - def http_worker(): - from waitress import serve - - HOST = "0.0.0.0" - PORT = 8080 - - print(f"Starting server on http://{HOST}:{PORT}") - serve(app, host=HOST, port=PORT) - - http_thread = threading.Thread(target=http_worker, daemon=True) - http_thread.start() - - while True: - try: - _scheduler_iteration() - except Exception as e: - logger.error("Scheduler error: %s\n%s", e, traceback.format_exc()) - - db.wait_for_job(POLL_INTERVAL) diff --git a/serverless.yml b/container/serverless.yml similarity index 71% rename from serverless.yml rename to container/serverless.yml index 8bc445e..6c31a1b 100644 --- a/serverless.yml +++ b/container/serverless.yml @@ -47,9 +47,6 @@ custom: PROD_URL: ${env:PROD_URL} STAGING_URL: ${env:STAGING_URL} LOGLEVEL: ${self:custom.${self:provider.stage}.loglevel} - # Cutover knobs: populate once the ghfe-go function URL is known and - # you're ready to forward an entity's workflow_job webhooks to Go. - GO_GHFE_URL: ${env:GO_GHFE_URL} # Health check configuration healthCheck: type: http @@ -58,29 +55,6 @@ custom: failureThreshold: 3 # VPC for PostgreSQL access privateNetworkId: "58fa41d0-f6a4-4b6f-8f65-b788563842c1" - ghfe-go: - registryImage: ${env:REGISTRY}/${env:IMAGE}:ghfe-${self:custom.${self:provider.stage}.container-tag}-go - port: 8080 - cpuLimit: 500 - memoryLimit: 512 - secret: - GHAPP_WEBHOOK_SECRET: ${env:GHAPP_WEBHOOK_SECRET} - GHAPP_ORG_PRIVATE_KEY: ${env:GHAPP_ORG_PRIVATE_KEY} - GHAPP_PERSONAL_PRIVATE_KEY: ${env:GHAPP_PERSONAL_PRIVATE_KEY} - K8S_KUBECONFIG: ${env:K8S_KUBECONFIG} - POSTGRES_URL: ${env:POSTGRES_URL} - TRACE_API_SECRET: ${env:TRACE_API_SECRET} - env: - PROD: ${self:custom.${self:provider.stage}.prod} - PROD_URL: ${env:PROD_URL} - STAGING_URL: ${env:STAGING_URL} - LOGLEVEL: ${self:custom.${self:provider.stage}.loglevel} - healthCheck: - type: http - httpPath: /health - interval: 10s - failureThreshold: 3 - privateNetworkId: "58fa41d0-f6a4-4b6f-8f65-b788563842c1" scheduler: registryImage: ${env:REGISTRY}/${env:IMAGE}:scheduler-${self:custom.${self:provider.stage}.container-tag} port: 8080 diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index edebd18..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,15 +0,0 @@ -[tool.coverage.run] -source = ["container"] -branch = true - -[tool.coverage.report] -show_missing = true -skip_covered = false -exclude_lines = [ - "pragma: no cover", - "if __name__ == .__main__.:", - "raise NotImplementedError", -] - -[tool.pytest.ini_options] -addopts = "--cov --cov-report=term-missing --cov-report=xml" diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 82e78d7..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,7 +0,0 @@ --r container/requirements.txt - -requests-mock==1.12.1 -pytest -pytest-cov -pytest-xdist -diff-cover