diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..6dca1da --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,35 @@ +--- +name: Bug report +about: Report incorrect behaviour +title: "" +labels: bug +assignees: "" +--- + +## What happened + +A clear description of the bug. + +## Expected behaviour + +What you expected instead. + +## Reproduction + +A minimal snippet. Mock the client with `snowflake_sql_api.testing` if the issue +does not need a live account (see docs/testing.md): + +```python +# ... +``` + +## Environment + +- `snowflake-sql-api` version: +- Python version: +- OS: + +## Additional context + +Tracebacks (redact secrets), Snowflake `code` / `sqlState` if relevant, anything +else that helps. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..a2bca8f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,24 @@ +--- +name: Feature request +about: Suggest an enhancement +title: "" +labels: enhancement +assignees: "" +--- + +## Problem + +What are you trying to do that the library does not support today? + +## Proposed solution + +What you would like to see. Keep in mind the project's design principles +(pure-Python, small footprint, vendor-neutral, sync/async parity). + +## Alternatives considered + +Other approaches, existing workarounds, or related libraries. + +## Additional context + +Anything else (links to Snowflake SQL API docs, example use case). diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..3d058c2 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,22 @@ +## Summary + +What changed and why. Write for the reviewer and for downstream consumers, not +for yourself: lead with the user-visible behaviour, not the implementation +narrative. + +## Backwards compatibility + +This is a library; downstream code depends on it. State the impact: + +- [ ] No public API change, or +- [ ] Public API changed (describe it, and the migration for callers) + +Do not add compatibility shims for behaviour that never shipped; just change the +code. + +## Test plan + +- [ ] `coverage run -m pytest && coverage report` passes (coverage holds >= 89%) +- [ ] `ruff check`, `black --check`, and `mypy` pass (or `pre-commit run --all-files`) +- [ ] New behaviour is covered by tests; each fixed bug has a `test_regression_*` +- [ ] Public API changes documented in the README / `docs/` diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..af04d92 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,20 @@ +version: 2 +updates: + # Python dependencies declared in pyproject.toml. + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + labels: + - "dependencies" + + # GitHub Actions used by the CI and publish workflows. + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + labels: + - "dependencies" + - "github-actions" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 59eab31..8d50c9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - uses: actions/setup-python@v5 with: python-version: "3.12" @@ -38,6 +40,8 @@ jobs: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -46,5 +50,10 @@ jobs: python -m pip install --upgrade pip pip install -e '.[dev]' - name: pytest - # Coverage gate raised to >= 89% from Phase 2 (--cov-fail-under). - run: pytest --cov=snowflake_sql_api --cov-report=term-missing + # Use `coverage run` (not `pytest --cov`) so tracing starts before the + # package's own pytest11 plugin (snowflake_sql_api.testing) is imported; + # otherwise import-time lines read as uncovered. The 89% gate lives in + # pyproject.toml [tool.coverage.report] fail_under. + run: | + coverage run -m pytest + coverage report diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..db17372 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,82 @@ +name: Publish to PyPI + +# Tag-driven release. Push a `vX.Y.Z` tag and this builds the sdist/wheel +# (version derived from the tag by hatch-vcs) and uploads to PyPI via OIDC +# trusted publishing. See RELEASING.md. +on: + push: + tags: + - "v*" + +jobs: + lint: + name: Lint and type-check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install + run: | + python -m pip install --upgrade pip + pip install -e '.[dev]' + - name: ruff + run: ruff check snowflake_sql_api tests + - name: black + run: black --check snowflake_sql_api tests + - name: mypy + run: mypy snowflake_sql_api + + test: + name: Test (py${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install + run: | + python -m pip install --upgrade pip + pip install -e '.[dev]' + - name: pytest + # See CI: `coverage run` so tracing precedes the pytest11 plugin import. + run: | + coverage run -m pytest + coverage report + + publish: + name: Build and publish + needs: [lint, test] + runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write # OIDC trusted publishing; no API token needed + steps: + - uses: actions/checkout@v4 + with: + # hatch-vcs needs the tag (and history) to derive the version. + fetch-depth: 0 + persist-credentials: false + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install build tooling + run: | + python -m pip install --upgrade pip + pip install build twine + - name: Build sdist and wheel + run: python -m build + - name: Check distribution metadata + run: twine check dist/* + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..200f675 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +# Pre-commit hooks for snowflake-sql-api. Install with `pre-commit install` +# (the hook then runs on every commit); run the full set with +# `pre-commit run --all-files`. Hook versions track the `[dev]` floors in +# pyproject.toml. +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.17 + hooks: + - id: ruff-check + + - repo: https://github.com/psf/black + # Last black that still runs on Python 3.9 (the project's runtime floor); + # 25.12.0+ require 3.10+. Kept in sync with the [dev] cap in pyproject.toml. + rev: 25.9.0 + hooks: + - id: black + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-merge-conflict + - id: check-added-large-files + # Keypair auth: belt-and-braces with the *.p8 / *.pem .gitignore rules. + - id: detect-private-key + + # mypy runs as a local hook (not the mirror) so it resolves project imports + # and stubs from the installed dev environment, matching the CI invocation. + - repo: local + hooks: + - id: mypy + name: mypy + entry: mypy snowflake_sql_api + language: system + types: [python] + pass_filenames: false diff --git a/AGENTS.md b/AGENTS.md index 5e9e399..43dcd80 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -38,6 +38,7 @@ Module layout under `snowflake_sql_api/` (clean separation, `py.typed`): | `aclient.py` | Asynchronous client (same surface, `await`-based) | | `row_mapping.py` | Optional dataclass / Pydantic row mapping | | `cli.py` | Command-line interface (`snowflake-sql-api query ...`) | +| `testing.py` | Shipped test helper: `FakeSnowflake` (httpx.MockTransport), `make_client`/`make_async_client`, pytest fixtures (pytest11) | ### Dependencies @@ -52,21 +53,23 @@ than failing at import time. ## Development Commands ```bash -# Install in editable mode with dev tooling +# Install in editable mode with dev tooling, then wire the git hooks pip install -e '.[dev]' +pre-commit install -# Run tests with coverage -pytest --cov=snowflake_sql_api --cov-report=term-missing +# Run tests with coverage (NOT `pytest --cov`; see Known Quirks) +coverage run -m pytest && coverage report # Run only the known-bug regression tests pytest -k regression -# Lint / format / type-check +# Lint / format / type-check (or run all hooks at once) ruff check snowflake_sql_api tests black --check snowflake_sql_api tests mypy snowflake_sql_api +pre-commit run --all-files -# Build a distribution +# Build a distribution (version comes from the git tag via hatch-vcs) python -m build ``` @@ -89,18 +92,90 @@ private keys, `.gitignore` excludes `*.pem` / `*.p8` / `*private_key*`. - **Correctness over surface.** Type coercion and partition handling are correctness-critical, prefer well-tested core behavior to breadth. +## Known Quirks + +Behaviour that looks wrong but is intentional. Do not "fix" these without reading +the linked regression test first. + +- **Account locator: claim vs host** (`auth.py`). The JWT claim account + (`iss`/`sub`) strips the region/cloud suffix and uppercases + (`xy12345.ap-southeast-2` -> `XY12345`); the API host keeps the full account + (`xy12345.ap-southeast-2.snowflakecomputing.com`). Conflating them breaks JWT + validation. `normalize_account_locator` vs `account_hostname`. Regression: + `test_regression_bug1`. +- **`result(poll=False)` raises on 202** (`client.py` / `aclient.py` + `_collect`). A still-running async statement must raise `ResultNotReady`, never + return its in-progress HTTP 202 body as if it were a result set. Regression: + `test_regression_bug3`. +- **Fetch every partition, in order** (`pagination.py`). `query` returns + partition 0 (inline) plus partitions 1..N (fetched by index). Stopping at + partition 0 silently truncates large results. Regression: + `test_regression_bug4`. +- **`on_query` streaming hook is deferred** to the v0.2.0 toolkit + (`query_stream`, Phase 8). The hook fires for `query`/`execute`/`submit` today; + there is no streaming path yet, so no regression test until the feature lands + (this is spike bug #2, intentionally not yet covered). +- **No PEP 604 unions at runtime** (py3.9 floor). ruff's `UP` (pyupgrade) rule is + omitted on purpose: it would rewrite `Optional[...]` / `Union[...]` to + `X | None`, which raises at import time on 3.9 for typing generics (PEP 604 on + generics is 3.10+). Keep `from __future__ import annotations` plus + `Optional`/`Union`. +- **mypy `python_version = "3.10"` vs the 3.9-3.13 matrix.** 3.10 is the lowest + this mypy accepts; true 3.9 runtime compatibility is enforced by the pytest + matrix, which imports every module under 3.9. +- **Coverage uses `coverage run`, not `pytest --cov`.** The package ships a pytest + plugin via the `pytest11` entry point, so `snowflake_sql_api.testing` (and the + whole package) is imported at plugin-load time, before pytest-cov starts + tracing. `pytest --cov` then reports import-time lines as uncovered (~20 points + lost). `coverage run -m pytest` starts tracing first. The 89% gate lives in + `pyproject.toml` `[tool.coverage.report] fail_under`. + ## Testing - Unit tests mock the HTTP layer; no network access required for the default suite. +- Mock the client in your own tests with the shipped `snowflake_sql_api.testing` + helper (`FakeSnowflake` + `make_client`/`make_async_client`, or the + auto-registered `fake_snowflake` / `snowflake_client` / `async_snowflake_client` + fixtures). No respx. See `docs/testing.md`. - Each fixed bug gets a named regression test (`test_regression_*`) so it cannot silently return. -- Target coverage: >= 89%, enforced in CI across Python 3.9-3.13. +- Target coverage: >= 89%, enforced across Python 3.9-3.13. Run with + `coverage run -m pytest && coverage report` (see Known Quirks). + +## Common Mistakes + +- Hand-editing a version string. The version comes from the git tag (hatch-vcs); + `_version.py` is generated and gitignored. A feature PR must not touch it. See + `RELEASING.md`. +- Running `pytest --cov` and reacting to the false coverage drop. Use + `coverage run -m pytest`. +- Adding a runtime dependency without strong justification. The small + install / fast cold start is the whole point; new optional features go behind + an extra. +- Rewriting `Optional[...]` to `X | None` (breaks the 3.9 runtime). +- Conflating the JWT claim account with the API host (see Known Quirks). +- Forgetting the async counterpart of a sync change (sync/async parity). +- Forgetting a `test_regression_*` for a fixed bug. + +## Before Finishing + +1. `pre-commit run --all-files` is clean (ruff, black, mypy, yaml/toml, private-key). +2. `coverage run -m pytest && coverage report` passes and coverage holds >= 89%. +3. Sync/async parity: any client change has its counterpart, or a stated reason. +4. Fixed bugs have a `test_regression_*`; public API changes are in the + README / `docs/`. + +## Security + +Report vulnerabilities privately, see [SECURITY.md](SECURITY.md). Never commit +private keys (`.gitignore` and a `detect-private-key` pre-commit hook guard +this). ## Contributing Before opening a PR: -- [ ] Tests pass (`pytest --cov`) and coverage holds. +- [ ] Tests pass (`coverage run -m pytest && coverage report`) and coverage holds. - [ ] Formatted (`black`) and linted (`ruff`), type hints on public APIs (`mypy`). - [ ] No hardcoded account/region/role/warehouse values; configuration is generic. - [ ] Public API changes documented in the README. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5e02f87..13c494b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,6 +10,14 @@ git clone https://github.com/hampsterx/snowflake-sql-api cd snowflake-sql-api python -m venv .venv && source .venv/bin/activate pip install -e '.[dev]' +pre-commit install +``` + +`pre-commit install` wires the lint/format/type/secret hooks to run on every +commit. Run the full set against the whole tree at any time: + +```bash +pre-commit run --all-files ``` ## Before opening a PR @@ -18,10 +26,15 @@ pip install -e '.[dev]' ruff check snowflake_sql_api tests black --check snowflake_sql_api tests mypy snowflake_sql_api -pytest --cov=snowflake_sql_api --cov-report=term-missing +coverage run -m pytest && coverage report ``` -All four must pass. CI runs the same checks across Python 3.9-3.13. +All four must pass (and `pre-commit run --all-files` is clean). Use `coverage run +-m pytest`, not `pytest --cov`: the package ships a pytest plugin +(`snowflake_sql_api.testing`), and `coverage run` starts tracing before that +plugin is imported so import-time lines are measured correctly. The coverage gate +is enforced (`fail_under = 89` in `pyproject.toml`); `coverage report` exits +non-zero below it. CI runs the same checks across Python 3.9-3.13. Checklist: diff --git a/RELEASING.md b/RELEASING.md new file mode 100644 index 0000000..6e8a404 --- /dev/null +++ b/RELEASING.md @@ -0,0 +1,70 @@ +# Releasing + +`snowflake-sql-api` is released by pushing a git tag. There is no version to +hand-edit and no manual upload step. + +## Version source + +The version comes from the git tag via [`hatch-vcs`](https://github.com/ofek/hatch-vcs). +At build time the resolved version is written to `snowflake_sql_api/_version.py` +(gitignored). **Do not hand-edit a version anywhere** (not in `__init__.py`, not +in `pyproject.toml`); a feature PR that edits a version string is wrong. + +Tag convention: `vX.Y.Z` (PEP 440 pre-releases are fine too, e.g. `v0.2.0rc1`). + +## Cutting a release + +```bash +# from a clean master that has the commits you want to ship +git fetch origin +git switch master && git pull --ff-only + +git tag v0.1.0 +git push origin v0.1.0 +``` + +Pushing the tag triggers `.github/workflows/publish.yml`, which: + +1. Runs the lint + test matrix (same checks as CI). +2. Builds the sdist and wheel (`python -m build`); the version reflects the tag. +3. Runs `twine check dist/*`. +4. Publishes to PyPI via OIDC trusted publishing (no API token). + +## One-time prerequisites + +Trusted publishing must be configured before the first real release: + +1. **PyPI trusted publisher**: on the PyPI project, add a GitHub publisher for + `hampsterx/snowflake-sql-api`, workflow `publish.yml`, environment `pypi`. + (For the very first upload, create the project via a TestPyPI dry-run or a + pending publisher.) +2. **GitHub environment**: create an environment named `pypi` in the repo + settings. It can be empty; it exists to gate the publish job and bind the + trusted-publisher claim. + +## TestPyPI dry-run (optional) + +To rehearse without touching the real index, build locally and upload to +TestPyPI with a token: + +```bash +python -m build +twine check dist/* +twine upload --repository testpypi dist/* +``` + +## At v0.1.0 + +When cutting the first `v0.1.0`, flip the `Development Status` classifier in +`pyproject.toml` from `3 - Alpha` to `4 - Beta` in the same commit that precedes +the tag. + +## Fixing a bad tag + +```bash +git tag -d v0.1.0 +git push origin :refs/tags/v0.1.0 +``` + +Re-tag only if nothing was published yet; PyPI does not allow re-uploading a +version that already exists. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..60eb34f --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,31 @@ +# Security Policy + +## Reporting a vulnerability + +Please report security issues privately. Do **not** open a public issue for a +suspected vulnerability. + +- Preferred: open a private advisory via GitHub Security Advisories on this + repository ("Security" tab -> "Report a vulnerability"). +- Alternative: email the maintainer at tim.vdh@gmail.com with details and, if + possible, a minimal reproduction. + +Please include the affected version, the impact, and steps to reproduce. You can +expect an acknowledgement within a few days. Once a fix is available it will be +released to PyPI and the advisory published. + +## Scope + +This library performs keypair (JWT) authentication and sends SQL over HTTPS to +Snowflake's SQL API. Of particular interest: + +- handling of private keys and generated JWTs; +- parameter binding / SQL injection surface; +- TLS and request construction in the transport layer. + +## Handling secrets + +Never commit private keys. The repository's `.gitignore` excludes `*.pem`, +`*.p8`, `*.key`, and `*private_key*`, and a pre-commit `detect-private-key` hook +provides a second check. Provide keys at runtime via a file path, in-memory +bytes, or environment variables (see [docs/authentication.md](docs/authentication.md)). diff --git a/docs/authentication.md b/docs/authentication.md new file mode 100644 index 0000000..f01ac7b --- /dev/null +++ b/docs/authentication.md @@ -0,0 +1,115 @@ +# Authentication + +`snowflake-sql-api` uses Snowflake's keypair (JWT) authentication. Each request +carries a short-lived RS256 JWT signed by your RSA private key; Snowflake +verifies it against the public key registered on your user. There is no password +and no browser SSO flow. + +## 1. Generate an RSA key pair + +Unencrypted private key (simplest): + +```bash +openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out rsa_key.p8 -nocrypt +openssl rsa -in rsa_key.p8 -pubout -out rsa_key.pub +``` + +Passphrase-encrypted private key (recommended for shared environments): + +```bash +openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out rsa_key.p8 +openssl rsa -in rsa_key.p8 -pubout -out rsa_key.pub +``` + +Keep `rsa_key.p8` secret. Never commit it: the project's `.gitignore` already +excludes `*.p8`, `*.pem`, and `*private_key*`, and a pre-commit `detect-private-key` +hook is a second line of defence. + +## 2. Register the public key on your Snowflake user + +Take the body of `rsa_key.pub` (everything between the `BEGIN`/`END` lines, with +the newlines removed) and set it on the user. Run this as a role with the +privilege to alter the user (e.g. `SECURITYADMIN`): + +```sql +ALTER USER my_user SET RSA_PUBLIC_KEY='MIIBIjANBgkqh...rest of the public key...'; +``` + +Verify it took: + +```sql +DESCRIBE USER my_user; -- look for RSA_PUBLIC_KEY_FP (the fingerprint) +``` + +The client computes the same `SHA256:` fingerprint from your private key +and puts it in the JWT issuer claim, so this fingerprint must match. + +## 3. Construct the client + +From a key file: + +```python +from snowflake_sql_api import SnowflakeClient + +client = SnowflakeClient( + account="myorg-myaccount", + user="MY_USER", + private_key_path="/path/to/rsa_key.p8", +) +``` + +From in-memory PEM bytes (e.g. fetched from a secrets manager): + +```python +client = SnowflakeClient( + account="myorg-myaccount", + user="MY_USER", + private_key=pem_bytes, # bytes + private_key_passphrase="my-passphrase", # only if the key is encrypted +) +``` + +Or entirely from the environment (see [getting-started.md](getting-started.md) +for the variable list): + +```python +client = SnowflakeClient.from_env() +``` + +## The account-locator region gotcha + +This is the single most common keypair failure. Two account forms are derived +**differently**, and the client handles both for you: + +- The **JWT claim** account (issuer/subject) must drop any region/cloud suffix + and be uppercased: `xy12345.ap-southeast-2` becomes `XY12345`. Leaving the + region in makes the JWT invalid. +- The **API host** keeps the full account (the region routes the request): + `xy12345.ap-southeast-2` becomes + `xy12345.ap-southeast-2.snowflakecomputing.com`. + +You pass the full account locator (whatever Snowflake shows you, region included) +and the library splits it correctly. The org-account dash form +(`myorg-myaccount`) has no dot, so it is preserved as-is. + +If your account uses PrivateLink or a non-standard host, pass `host=` explicitly +to bypass host derivation (the claim account is still derived from `account`). + +## Token lifetime + +Tokens are signed with a one-hour lifetime (Snowflake's cap) and cached, with a +small renewal margin so an in-flight request never races expiry. You do not need +to manage tokens yourself. + +## Troubleshooting auth + +A `SnowflakeAuthError` (HTTP 401) almost always means one of: + +- the public key is not registered, or does not match the private key; +- the account locator carried a region into the claim (not possible via this + client unless you bypass it); +- significant clock skew between your host and Snowflake; +- an encrypted key supplied without (or with the wrong) passphrase, which raises + a `SnowflakeConfigError` before any request. + +See [troubleshooting.md](troubleshooting.md) for more. diff --git a/docs/cli.md b/docs/cli.md new file mode 100644 index 0000000..eb40e39 --- /dev/null +++ b/docs/cli.md @@ -0,0 +1,58 @@ +# CLI + +Installing the package puts a `snowflake-sql-api` command on your PATH for +ad-hoc queries. + +## Configuration + +The CLI reads connection settings from the environment (the same `SNOWFLAKE_*` +variables as `SnowflakeClient.from_env()`): + +```bash +export SNOWFLAKE_ACCOUNT="myorg-myaccount" +export SNOWFLAKE_USER="MY_USER" +export SNOWFLAKE_PRIVATE_KEY_PATH="/path/to/rsa_key.p8" +# optional: SNOWFLAKE_ROLE / SNOWFLAKE_WAREHOUSE / SNOWFLAKE_DATABASE / SNOWFLAKE_SCHEMA +``` + +See [getting-started.md](getting-started.md) for the full variable list and +[authentication.md](authentication.md) for key setup. + +## Running a query + +```bash +snowflake-sql-api query "SELECT current_version()" +``` + +Output is JSON (a list of row objects), with type-coerced values rendered +JSON-safely: `Decimal` as a string, dates/timestamps as ISO 8601, binary as hex. + +```bash +$ snowflake-sql-api query "SELECT 1 AS n, 'hi' AS greeting" +[ + { + "N": 1, + "GREETING": "hi" + } +] +``` + +## Version + +```bash +snowflake-sql-api --version +``` + +## Exit codes + +| Code | Meaning | +|------|---------| +| `0` | query ran, result printed | +| `1` | a `SnowflakeError` occurred (message on stderr) | +| `2` | no subcommand given (help printed) | + +## Scope + +This is the minimal `query` command (JSON output). Richer output formats +(`--format table|csv|json|jsonl`), reading SQL from a file, and a progress spinner +are planned for a later release. diff --git a/docs/getting-started.md b/docs/getting-started.md new file mode 100644 index 0000000..2000980 --- /dev/null +++ b/docs/getting-started.md @@ -0,0 +1,128 @@ +# Getting Started + +`snowflake-sql-api` is a small, pure-Python client for Snowflake's +[SQL API v2](https://docs.snowflake.com/en/developer-guide/sql-api/index). This +page gets you from install to a first query. + +## Install + +```bash +pip install snowflake-sql-api +``` + +Optional extras (kept out of the default install to stay small): + +```bash +pip install "snowflake-sql-api[pandas]" # DataFrame output helpers +pip install "snowflake-sql-api[pydantic]" # typed-row mapping +``` + +Requires Python 3.9 or newer. Core dependencies are `httpx`, `PyJWT`, and +`cryptography`. + +## Prerequisites + +You need keypair (JWT) authentication set up: an RSA key pair, with the public +key registered on your Snowflake user. See [authentication.md](authentication.md) +for the full walkthrough. The short version: + +```bash +openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out rsa_key.p8 -nocrypt +openssl rsa -in rsa_key.p8 -pubout -out rsa_key.pub +``` + +Then register the public key on the user (run as a role that can alter the user): + +```sql +ALTER USER my_user SET RSA_PUBLIC_KEY=''; +``` + +## Your first query + +```python +from snowflake_sql_api import SnowflakeClient + +client = SnowflakeClient( + account="myorg-myaccount", # or a region locator like "xy12345.ap-southeast-2" + user="MY_USER", + private_key_path="/path/to/rsa_key.p8", + warehouse="MY_WH", # optional session context + database="MY_DB", + schema="PUBLIC", +) + +rows = client.query("SELECT id, name FROM users WHERE active = ?", [True]) +for row in rows: + print(row["ID"], row["NAME"]) + +client.close() +``` + +`query` returns a list of dicts keyed by column name, with values coerced to +native Python types (numbers to `int`/`Decimal`, dates/timestamps to +`datetime`/`date`/`time`, VARIANT to `dict`/`list`, and so on). + +Use it as a context manager to close the underlying HTTP client automatically: + +```python +with SnowflakeClient(account="myorg-myaccount", user="MY_USER", + private_key_path="/path/to/rsa_key.p8") as client: + version = client.query_scalar("SELECT current_version()") +``` + +## Query helpers + +| Method | Returns | +|--------|---------| +| `query(sql, params)` | all rows (list of dicts) | +| `query_one(sql, params)` | first row, or `None` | +| `query_scalar(sql, params)` | first column of the first row, or `None` | +| `query_column(sql, params)` | first column across all rows | +| `execute(sql, params)` | rows affected (DML/DDL) | +| `insert_many(table, columns, rows)` | rows inserted (batched, bound) | +| `submit(sql, params)` | a `QueryHandle` for a long-running statement | + +Bind parameters are positional (`?`) and always sent as server-side bindings, +never string-interpolated. + +## Configuration from the environment + +`from_env()` reads `SNOWFLAKE_*` variables, which keeps credentials out of code: + +```python +client = SnowflakeClient.from_env() +``` + +| Variable | Purpose | +|----------|---------| +| `SNOWFLAKE_ACCOUNT` | account locator (required) | +| `SNOWFLAKE_USER` | user name (required) | +| `SNOWFLAKE_PRIVATE_KEY` | PEM key contents, or | +| `SNOWFLAKE_PRIVATE_KEY_PATH` | path to a PEM/DER key file | +| `SNOWFLAKE_PRIVATE_KEY_PASSPHRASE` | passphrase for an encrypted key | +| `SNOWFLAKE_ROLE` / `SNOWFLAKE_WAREHOUSE` | session role / warehouse | +| `SNOWFLAKE_DATABASE` / `SNOWFLAKE_SCHEMA` | session database / schema | +| `SNOWFLAKE_HOST` | override the derived API hostname (PrivateLink, etc.) | + +## Async + +The async client mirrors the sync surface with `await` and an async context +manager: + +```python +from snowflake_sql_api import AsyncSnowflakeClient + +async def main(): + async with AsyncSnowflakeClient.from_env() as client: + rows = await client.query("SELECT current_timestamp()") + print(rows) +``` + +## Next steps + +- [authentication.md](authentication.md): keypair setup, encrypted keys, the + account-locator region gotcha. +- [cli.md](cli.md): the `snowflake-sql-api` command. +- [testing.md](testing.md): mock the client in your own tests, no Snowflake + account required. +- [troubleshooting.md](troubleshooting.md): auth failures, polling, partitions. diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..01753ac --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,159 @@ +# Testing + +`snowflake_sql_api.testing` lets you drive `SnowflakeClient` / +`AsyncSnowflakeClient` against canned results with **no network and no Snowflake +account**. It ships with the package (no extra dependency: it is built on +`httpx.MockTransport`, and `httpx` is already a core dependency). You do not need +`respx` or any other HTTP mock. + +## Quick start + +```python +from snowflake_sql_api.testing import FakeSnowflake, make_client + +fake = FakeSnowflake() +fake.register("SELECT id, name FROM users", [ + {"ID": 1, "NAME": "alice"}, + {"ID": 2, "NAME": "bob"}, +]) + +client = make_client(fake) +assert client.query("SELECT id, name FROM users") == [ + {"ID": 1, "NAME": "alice"}, + {"ID": 2, "NAME": "bob"}, +] +``` + +`make_client(fake)` returns a real `SnowflakeClient` (generated throwaway key, an +`httpx.MockTransport` wired to `fake`, instant polling). Every code path runs for +real except the wire: auth signs a JWT, the transport builds requests, results +are coerced. Only Snowflake is faked. + +## Registering results + +Rows can be dicts (column names and types inferred) or positional lists with an +explicit `columns` spec: + +```python +fake.register("SELECT id, name FROM t", [[1, "alice"], [2, "bob"]], + columns=["ID", "NAME"]) +``` + +Give explicit column types when inference is not enough (e.g. a fixed-point +scale, or a timestamp variant): + +```python +from decimal import Decimal + +fake.register("SELECT amount FROM t", [{"AMOUNT": Decimal("10.50")}], + columns=[{"name": "AMOUNT", "type": "fixed", "scale": 2}]) +``` + +Values are native Python objects. The fake encodes them to the SQL API wire form +and the client coerces them straight back, so the round trip is lossless for: +integers and `Decimal`, floats, text, booleans, `bytes`, `date` / `time` / +`datetime` (naive and tz-aware), and VARIANT (`dict` / `list`). `None` becomes a +SQL NULL. + +### Predicate matching + +When you cannot pin an exact SQL string (generated SQL, bound inserts), match on +a predicate: + +```python +fake.register_match(lambda sql: sql.startswith("SELECT count"), [{"N": 99}]) +``` + +Lookups try exact matches first, then predicates in registration order. + +### DML and errors + +```python +fake.register_dml("DELETE FROM t WHERE id = 1", rowcount=1) +assert client.execute("DELETE FROM t WHERE id = 1") == 1 + +fake.register_error("SELECT bad syntax", "SQL compilation error", code="000904") +# client.query("SELECT bad syntax") now raises SnowflakeProgrammingError +``` + +## Multi-partition results + +Split rows across partitions to exercise the partition-fetch path: + +```python +rows = [{"N": n} for n in range(1000)] +fake.register("SELECT n FROM big", rows, partitions=4) +assert client.query("SELECT n FROM big") == rows # all partitions, in order +``` + +## Long-running / async-submit + +`polls_before_ready` makes a statement report RUNNING (HTTP 202) for that many +status polls before completing: + +```python +fake.register("CALL slow()", [{"DONE": True}], polls_before_ready=2) + +handle = client.submit("CALL slow()") +# handle.result(poll=False) would raise ResultNotReady here +rows = handle.result() # polls until ready, then returns rows +``` + +## Async client + +```python +from snowflake_sql_api.testing import FakeSnowflake, make_async_client + +async def test_async(): + fake = FakeSnowflake() + fake.register("SELECT 1", [{"N": 1}]) + async with make_async_client(fake) as client: + assert await client.query_scalar("SELECT 1") == 1 +``` + +## Pytest fixtures + +Installing the package registers three fixtures via a `pytest11` entry point. No +`conftest.py` wiring is needed: + +| Fixture | Provides | +|---------|----------| +| `fake_snowflake` | a fresh `FakeSnowflake` | +| `snowflake_client` | a `SnowflakeClient` wired to `fake_snowflake` | +| `async_snowflake_client` | an `AsyncSnowflakeClient` wired to `fake_snowflake` (registered only when `pytest-asyncio` is installed) | + +```python +def test_users(fake_snowflake, snowflake_client): + fake_snowflake.register("SELECT name FROM users", [{"NAME": "alice"}]) + assert snowflake_client.query_column("SELECT name FROM users") == ["alice"] +``` + +> **Coverage note:** because the package ships a pytest plugin, measure coverage +> with `coverage run -m pytest` rather than `pytest --cov`. The former starts +> tracing before the plugin is imported; the latter records the plugin's +> import-time lines as uncovered. + +## Drop-in for application code + +If your code constructs its own client, build it against the fake by injecting +the mock transport through `http_client`, or patch your factory to return +`make_client(fake)`: + +```python +import httpx +from snowflake_sql_api import SnowflakeClient + +client = SnowflakeClient( + account="myorg-myaccount", + user="MY_USER", + private_key=test_key_bytes, + http_client=httpx.Client(transport=fake.transport), +) +``` + +## Asserting on what ran + +```python +fake.submitted_statements # list of SQL strings submitted, in order +fake.requests # every httpx.Request the fake handled +``` diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 0000000..34af05a --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,102 @@ +# Troubleshooting + +The client maps SQL API responses onto a typed exception hierarchy rooted at +`SnowflakeError`. Catch that to handle everything, or branch on the specific +types below. + +| Exception | Raised when | +|-----------|-------------| +| `SnowflakeConfigError` | bad/missing config, before any request (no key, unreadable key, missing extra) | +| `SnowflakeAuthError` | JWT generation or Snowflake auth failed (HTTP 401) | +| `SnowflakeProgrammingError` | the SQL failed to compile or execute (HTTP 422) | +| `SnowflakeRequestError` | an HTTP/protocol-level failure (carries `code`, `sql_state`, `status_code`) | +| `SnowflakeTimeoutError` | an HTTP request or a statement exceeded its timeout | +| `SnowflakeRetryError` | retries exhausted (original error on `__cause__`) | +| `ResultNotReady` | `QueryHandle.result(poll=False)` while the statement is still running | + +## Authentication failures (401 / `SnowflakeAuthError`) + +Work through these in order: + +1. **Public key not registered or mismatched.** `DESCRIBE USER my_user` and + compare `RSA_PUBLIC_KEY_FP` against the key you are signing with. Regenerating + the key without re-running `ALTER USER ... SET RSA_PUBLIC_KEY` is the usual + cause. +2. **Account locator.** The JWT claim account must have the region/cloud suffix + stripped (`xy12345.ap-southeast-2` to `XY12345`); the client does this for + you. If you hand-built a host or claim, recheck it. See + [authentication.md](authentication.md#the-account-locator-region-gotcha). +3. **Clock skew.** JWTs are time-bound; a host clock off by minutes will be + rejected. Sync the clock (NTP). +4. **Encrypted key.** A passphrase-protected key without the right + `private_key_passphrase` raises `SnowflakeConfigError` at construction, not a + 401. + +## SQL errors (422 / `SnowflakeProgrammingError`) + +The statement reached Snowflake and was rejected (syntax error, missing object, +constraint violation). The exception carries Snowflake's `code` and `sql_state`: + +```python +from snowflake_sql_api.exceptions import SnowflakeProgrammingError + +try: + client.query("SELECT * FROM does_not_exist") +except SnowflakeProgrammingError as exc: + print(exc.code, exc.sql_state, exc) +``` + +## Long-running statements (202 / polling) + +Large or slow statements come back as HTTP 202 ("still running"). The standard +`query`/`execute` calls poll automatically until completion (bounded by +`statement_timeout` if set, else a default wait). + +For fire-and-forget, submit asynchronously and poll yourself: + +```python +handle = client.submit("CALL long_running_proc()") + +handle.status() # "RUNNING" or "SUCCESS" + +# Non-blocking fetch: raises ResultNotReady if still running, instead of +# returning a misleading in-progress payload. +from snowflake_sql_api.exceptions import ResultNotReady +try: + rows = handle.result(poll=False) +except ResultNotReady: + ... # check again later + +rows = handle.result() # blocking: polls until done +``` + +If a statement never finishes within the wait, you get `SnowflakeTimeoutError`. + +## Large result sets (partitions) + +The SQL API splits large results into partitions: partition 0 arrives inline, +the rest are fetched by index. `query` fetches **every** partition and returns +the rows in order, so you never silently get a truncated result. No action +needed; just be aware a single `query` may issue several GETs for a big result. + +## Transient failures and retries + +Connect/read timeouts and HTTP 429/5xx are retried with exponential backoff and +full jitter. DML submits reuse their `requestId` with `retry=true`, so a retried +insert/update cannot double-apply. When retries are exhausted you get +`SnowflakeRetryError` with the last underlying error attached as `__cause__`. + +Tune via `retry_policy=RetryPolicy(...)` on the client constructor. + +## Connection / network + +`httpx` transport errors (DNS, TLS, connection refused) surface after retries as +`SnowflakeRetryError`. Check the derived host +(`.snowflakecomputing.com`) is reachable from your network; for +PrivateLink or custom endpoints pass `host=` explicitly. + +## Testing without a real account + +To unit-test code that uses this client, mock it with the shipped +`snowflake_sql_api.testing` helper (no network, no Snowflake). See +[testing.md](testing.md). diff --git a/pyproject.toml b/pyproject.toml index 96e14df..0cf91cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling>=1.18"] +requires = ["hatchling>=1.18", "hatch-vcs>=0.4"] build-backend = "hatchling.build" [project] @@ -37,13 +37,16 @@ pandas = ["pandas>=1.3"] pydantic = ["pydantic>=2.0"] dev = [ "pytest>=7.0", - "pytest-cov>=4.0", + "coverage[toml]>=7.0", "pytest-asyncio>=0.21", "respx>=0.20", "ruff>=0.4", - "black>=24.0", + # Capped below 25.10: 25.12.0+ drop Python 3.9, which the package still + # supports. Keeps CI and the pinned pre-commit hook on the same black. + "black>=24.0,<25.10", "mypy>=1.8", "build>=1.0", + "pre-commit>=3.5", ] [project.urls] @@ -54,8 +57,19 @@ Issues = "https://github.com/hampsterx/snowflake-sql-api/issues" [project.scripts] snowflake-sql-api = "snowflake_sql_api.cli:main" +# Auto-registers the snowflake_sql_api.testing pytest fixtures (fake_snowflake, +# snowflake_client, async_snowflake_client) for any project with the package +# installed. The fixtures are no-ops if pytest is absent. +[project.entry-points.pytest11] +snowflake_sql_api = "snowflake_sql_api.testing" + +# Version is derived from the git tag (hatch-vcs); never hand-edit it. The build +# writes the resolved version into snowflake_sql_api/_version.py (gitignored). [tool.hatch.version] -path = "snowflake_sql_api/__init__.py" +source = "vcs" + +[tool.hatch.build.hooks.vcs] +version-file = "snowflake_sql_api/_version.py" [tool.hatch.build.targets.wheel] packages = ["snowflake_sql_api"] @@ -126,6 +140,9 @@ omit = ["snowflake_sql_api/cli.py", "snowflake_sql_api/row_mapping.py"] [tool.coverage.report] show_missing = true +# Enforced locally too, not just in CI: `coverage report` exits non-zero below +# this. (Run coverage via `coverage run -m pytest`, not `pytest --cov`; see AGENTS.md.) +fail_under = 89 exclude_lines = [ "pragma: no cover", "raise NotImplementedError", diff --git a/snowflake_sql_api/__init__.py b/snowflake_sql_api/__init__.py index aa08398..94dc923 100644 --- a/snowflake_sql_api/__init__.py +++ b/snowflake_sql_api/__init__.py @@ -12,7 +12,19 @@ from __future__ import annotations -__version__ = "0.1.0.dev0" +try: + # Written at build time by hatch-vcs from the git tag (see pyproject.toml). + from ._version import __version__ +except ImportError: # pragma: no cover - plain source checkout, never built + # No build-generated version file (e.g. running from a raw clone). Fall back + # to installed package metadata, then to a dev placeholder. + from importlib.metadata import PackageNotFoundError + from importlib.metadata import version as _pkg_version + + try: + __version__ = _pkg_version("snowflake-sql-api") + except PackageNotFoundError: + __version__ = "0.0.0.dev0" from .aclient import AsyncQueryHandle, AsyncSnowflakeClient from .client import QueryHandle, SnowflakeClient diff --git a/snowflake_sql_api/aclient.py b/snowflake_sql_api/aclient.py index 76430cc..f7a2f5c 100644 --- a/snowflake_sql_api/aclient.py +++ b/snowflake_sql_api/aclient.py @@ -14,6 +14,8 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +import httpx + from .auth import KeypairAuthenticator, account_hostname from .bindings import to_bindings from .client import DEFAULT_POLL_TIMEOUT, _columns_from, _rows_affected @@ -93,6 +95,7 @@ def __init__( retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, user_agent: Optional[str] = None, on_query: Optional[OnQuery] = None, + http_client: Optional[httpx.AsyncClient] = None, ) -> None: if private_key is None and private_key_path is None: raise SnowflakeConfigError("a private_key or private_key_path is required") @@ -122,6 +125,7 @@ def __init__( timeout=timeout, retry_policy=retry_policy, user_agent=user_agent, + client=http_client, ) @classmethod diff --git a/snowflake_sql_api/client.py b/snowflake_sql_api/client.py index 7fd3b2f..50ee49b 100644 --- a/snowflake_sql_api/client.py +++ b/snowflake_sql_api/client.py @@ -14,6 +14,8 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +import httpx + from .auth import KeypairAuthenticator, account_hostname from .bindings import to_bindings from .escaping import quote_identifier, quote_name @@ -131,6 +133,7 @@ def __init__( retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, user_agent: Optional[str] = None, on_query: Optional[OnQuery] = None, + http_client: Optional[httpx.Client] = None, ) -> None: if private_key is None and private_key_path is None: raise SnowflakeConfigError("a private_key or private_key_path is required") @@ -160,6 +163,7 @@ def __init__( timeout=timeout, retry_policy=retry_policy, user_agent=user_agent, + client=http_client, ) @classmethod diff --git a/snowflake_sql_api/testing.py b/snowflake_sql_api/testing.py new file mode 100644 index 0000000..381316c --- /dev/null +++ b/snowflake_sql_api/testing.py @@ -0,0 +1,673 @@ +"""In-process testing utilities for ``snowflake-sql-api``. + +Drive :class:`~snowflake_sql_api.client.SnowflakeClient` and +:class:`~snowflake_sql_api.aclient.AsyncSnowflakeClient` against canned results +with **no network and no real Snowflake account**. A :class:`FakeSnowflake` +registry maps SQL statements to rows and exposes itself as an +``httpx.MockTransport`` handler, which plugs into the client's existing transport +seam (``http_client=``). No ``respx`` (or any other test dependency) is needed: +``httpx`` is already a core dependency. + +Quick start:: + + from snowflake_sql_api.testing import FakeSnowflake, make_client + + fake = FakeSnowflake() + fake.register("SELECT id, name FROM users", [ + {"ID": 1, "NAME": "alice"}, + {"ID": 2, "NAME": "bob"}, + ]) + + client = make_client(fake) + assert client.query("SELECT id, name FROM users") == [ + {"ID": 1, "NAME": "alice"}, + {"ID": 2, "NAME": "bob"}, + ] + +Rows can be given as dicts (column order and types inferred) or as positional +lists with an explicit ``columns`` spec. Values are native Python objects; the +fake encodes them to the SQL API wire form so the client coerces them straight +back (a clean round trip), covering numbers, Decimals, booleans, text, binary, +dates/times/timestamps, and VARIANT (dict/list). + +Pytest fixtures (``fake_snowflake``, ``snowflake_client``, +``async_snowflake_client``) are auto-registered via the ``pytest11`` entry point +when the package is installed; no ``conftest.py`` wiring is required. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, +) + +import httpx + +if TYPE_CHECKING: + from .aclient import AsyncSnowflakeClient + from .client import SnowflakeClient + +__all__ = [ + "FakeSnowflake", + "FakeSnowflakeError", + "make_client", + "make_async_client", + "ok_body", + "running_body", +] + +STATEMENTS_PATH = "/api/v2/statements" + +#: A row given to :meth:`FakeSnowflake.register`: a name->value dict or a +#: positional sequence of cell values (the latter needs an explicit ``columns``). +RowInput = Union[Dict[str, Any], Sequence[Any]] + +#: A column spec entry: a bare name (type inferred) or an explicit +#: ``{"name": ..., "type": ..., "scale": ...}`` dict. +ColumnSpec = Union[str, Dict[str, Any]] + + +class FakeSnowflakeError(AssertionError): + """Raised for a test-setup mistake (unregistered statement, bad request). + + Subclasses :class:`AssertionError` so pytest surfaces it prominently: it + signals a gap in the fake's wiring, not a simulated Snowflake error. Use + :meth:`FakeSnowflake.register_error` to simulate real Snowflake failures. + """ + + +# --------------------------------------------------------------------------- +# SQL API envelope builders (single source of truth, shared with the test suite) +# --------------------------------------------------------------------------- + + +def ok_body( + row_type: List[Dict[str, Any]], + data: List[List[Any]], + *, + partitions: int = 1, + handle: str = "stmt-1", + stats: Optional[Dict[str, int]] = None, +) -> Dict[str, Any]: + """Build a 200 success body. ``partitions`` sets the partitionInfo length. + + Partition 0's ``rowCount`` reflects ``data``; the remaining entries are + placeholders (their rows are served by the partition GET routes). + """ + if partitions < 1: + raise ValueError("partitions must be >= 1") + partition_info: List[Dict[str, int]] = [{"rowCount": len(data)}] + partition_info += [{"rowCount": 0} for _ in range(partitions - 1)] + body: Dict[str, Any] = { + "resultSetMetaData": { + "numRows": len(data), + "format": "jsonv2", + "partitionInfo": partition_info, + "rowType": row_type, + }, + "data": data, + "code": "090001", + "sqlState": "00000", + "statementHandle": handle, + "statementStatusUrl": f"{STATEMENTS_PATH}/{handle}", + } + if stats is not None: + body["stats"] = stats + return body + + +def running_body(handle: str = "stmt-1", code: str = "333334") -> Dict[str, Any]: + """Build a 202 'still running / submitted async' body.""" + return { + "code": code, + "message": "Asynchronous execution in progress.", + "statementHandle": handle, + "statementStatusUrl": f"{STATEMENTS_PATH}/{handle}", + } + + +# --------------------------------------------------------------------------- +# Wire encoding (the exact inverse of snowflake_sql_api.types.coerce_value) +# --------------------------------------------------------------------------- + +_EPOCH_DATE = date(1970, 1, 1) +_EPOCH_NAIVE = datetime(1970, 1, 1) + + +def _fraction(micros: int) -> str: + return "" if micros == 0 else f".{micros:06d}" + + +def _epoch_delta(value: datetime) -> timedelta: + """Signed timedelta from the epoch for an absolute instant (as UTC).""" + if value.tzinfo is not None: + value = value.astimezone(timezone.utc).replace(tzinfo=None) + return value - _EPOCH_NAIVE + + +def _epoch_string(delta: timedelta) -> str: + """Format an epoch ``timedelta`` as the SQL API numeric wire string. + + Uses integer math (no float division, so large timestamps keep full + precision) and carries the sign into the fractional part so the value + decodes back exactly, including pre-1970 sub-second instants: the production + decoder in ``types.py`` reads the fraction as signed, so ``-0.5s`` must be + ``"-0.500000"``, not ``"-1.500000"``. + """ + total_us = (delta.days * 86400 + delta.seconds) * 1_000_000 + delta.microseconds + negative = total_us < 0 + secs, frac_us = divmod(abs(total_us), 1_000_000) + if frac_us == 0: + return f"-{secs}" if negative and secs else str(secs) + body = f"{secs}.{frac_us:06d}" + return f"-{body}" if negative else body + + +def _infer_type(value: Any) -> Tuple[str, Optional[int]]: + """Infer ``(snowflake_type, scale)`` from a native Python value.""" + if isinstance(value, bool): # before int: bool is a subclass of int + return "boolean", None + if isinstance(value, int): + return "fixed", 0 + if isinstance(value, Decimal): + exponent = value.as_tuple().exponent + scale = -exponent if isinstance(exponent, int) and exponent < 0 else 0 + return "fixed", scale + if isinstance(value, float): + return "real", None + if isinstance(value, (bytes, bytearray)): + return "binary", None + if isinstance(value, datetime): # before date: datetime is a subclass of date + return ("timestamp_tz" if value.tzinfo is not None else "timestamp_ntz"), None + if isinstance(value, date): + return "date", None + if isinstance(value, time): + return "time", None + if isinstance(value, dict): + return "object", None + if isinstance(value, (list, tuple)): + return "array", None + return "text", None + + +def _encode_cell(value: Any, col_type: str) -> Optional[str]: + """Encode one native value to its SQL API wire string (or ``None``).""" + if value is None: + return None + if col_type == "boolean": + return "true" if value else "false" + if col_type in ("variant", "object", "array"): + return json.dumps(value) + if col_type == "binary": + return bytes(value).hex() + if col_type == "date": + return str((value - _EPOCH_DATE).days) + if col_type == "time": + secs = value.hour * 3600 + value.minute * 60 + value.second + return f"{secs}{_fraction(value.microsecond)}" + if col_type in ("timestamp_ntz", "timestamp_ltz", "timestamp_tz"): + if col_type == "timestamp_ntz": + naive = value.replace(tzinfo=None) if value.tzinfo else value + return _epoch_string(naive - _EPOCH_NAIVE) + instant = ( + value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc) + ) + if col_type == "timestamp_ltz": + return _epoch_string(_epoch_delta(instant)) + offset = instant.utcoffset() + offset_min = int(offset.total_seconds() // 60) if offset is not None else 0 + return f"{_epoch_string(_epoch_delta(instant))} {offset_min + 1440}" + if col_type == "fixed": + # Plain decimal text, never scientific notation: str(Decimal("1E+2")) is + # "1E+2", which the decoder's int()/Decimal() parse would reject. A + # whole-number Decimal still decodes to int (scale 0), matching Snowflake. + return format(value, "f") if isinstance(value, Decimal) else str(value) + if col_type in ("real", "float", "double"): + return repr(float(value)) + # text / char / string / varchar / unknown: stringify as-is. + return str(value) + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + + +@dataclass +class _Error: + status: int + message: str + code: Optional[str] = None + sql_state: Optional[str] = None + + +@dataclass +class _Reg: + """A registered result: wire rows split across partitions, plus behavior.""" + + row_type: List[Dict[str, Any]] + chunks: List[List[List[Any]]] + polls_before_ready: int = 0 + stats: Optional[Dict[str, int]] = None + error: Optional[_Error] = None + + +@dataclass +class _State: + """Per-submission handle state (poll progress).""" + + reg: _Reg + remaining_polls: int + + +def _normalize_columns( + rows: List[RowInput], columns: Optional[Sequence[ColumnSpec]] +) -> List[Tuple[str, str, Optional[int]]]: + """Resolve column ``(name, type, scale)`` triples from rows and/or specs.""" + if columns is not None: + resolved: List[Tuple[str, str, Optional[int]]] = [] + for index, spec in enumerate(columns): + if isinstance(spec, str): + name, ctype, scale = spec, None, None + else: + name = spec["name"] + ctype = spec.get("type") + scale = spec.get("scale") + if ctype is None: + ctype, scale = _infer_column(rows, name, index) + resolved.append((name, ctype, scale)) + return resolved + + if not rows: + return [] + first = rows[0] + if not isinstance(first, dict): + raise FakeSnowflakeError( + "positional rows require an explicit `columns=` spec; " + "pass dict rows to infer column names" + ) + return [ + _named(name, *_infer_column(rows, name, idx)) for idx, name in enumerate(first) + ] + + +def _named( + name: str, ctype: str, scale: Optional[int] +) -> Tuple[str, str, Optional[int]]: + return name, ctype, scale + + +def _infer_column( + rows: List[RowInput], name: str, index: int +) -> Tuple[str, Optional[int]]: + """Infer a column's type/scale from the first non-null value across rows.""" + for row in rows: + value = row[name] if isinstance(row, dict) else row[index] + if value is not None: + return _infer_type(value) + return "text", None + + +def _cell(row: RowInput, name: str, index: int) -> Any: + return row[name] if isinstance(row, dict) else row[index] + + +def _split(data: List[List[Any]], partitions: int) -> List[List[List[Any]]]: + """Split wire rows into ``partitions`` chunks, preserving order.""" + if partitions < 1: + raise ValueError("partitions must be >= 1") + if partitions == 1: + return [data] + size = -(-len(data) // partitions) or 1 # ceil division, at least 1 + chunks = [data[i : i + size] for i in range(0, len(data), size)] + while len(chunks) < partitions: + chunks.append([]) + return chunks + + +class FakeSnowflake: + """In-memory Snowflake SQL API stand-in backed by ``httpx.MockTransport``. + + Register results for SQL statements, then build a client with + :func:`make_client` / :func:`make_async_client` (or pass ``.transport`` to an + ``httpx.Client``). Lookups try exact-string matches first, then predicate + matches in registration order. + """ + + def __init__(self) -> None: + self._exact: Dict[str, _Reg] = {} + self._predicates: List[Tuple[Callable[[str], bool], _Reg]] = [] + self._state: Dict[str, _State] = {} + self._counter = 0 + #: Every request the fake handled, in order (for assertions). + self.requests: List[httpx.Request] = [] + + # -- registration ----------------------------------------------------- + + def register( + self, + sql: str, + rows: Sequence[RowInput], + *, + columns: Optional[Sequence[ColumnSpec]] = None, + partitions: int = 1, + polls_before_ready: int = 0, + stats: Optional[Dict[str, int]] = None, + ) -> "FakeSnowflake": + """Register the result rows returned for an exact ``sql`` string. + + ``rows`` are dicts (names/types inferred) or positional sequences (need + ``columns``). ``partitions`` splits the rows across result partitions to + exercise multi-partition fetching. ``polls_before_ready`` makes the + statement report RUNNING (HTTP 202) for that many status polls first. + """ + self._exact[sql] = self._build( + list(rows), columns, partitions, polls_before_ready, stats + ) + return self + + def register_match( + self, + predicate: Callable[[str], bool], + rows: Sequence[RowInput], + *, + columns: Optional[Sequence[ColumnSpec]] = None, + partitions: int = 1, + polls_before_ready: int = 0, + stats: Optional[Dict[str, int]] = None, + ) -> "FakeSnowflake": + """Register a result for any statement where ``predicate(sql)`` is true.""" + reg = self._build(list(rows), columns, partitions, polls_before_ready, stats) + self._predicates.append((predicate, reg)) + return self + + def register_dml(self, sql: str, rowcount: int) -> "FakeSnowflake": + """Register a DML/DDL statement, returning ``rowcount`` rows affected.""" + self._exact[sql] = _Reg(row_type=[], chunks=[[[str(rowcount)]]]) + return self + + def register_error( + self, + sql: str, + message: str, + *, + status: int = 422, + code: Optional[str] = None, + sql_state: Optional[str] = None, + ) -> "FakeSnowflake": + """Register a statement that fails (default HTTP 422 -> programming error).""" + self._exact[sql] = _Reg( + row_type=[], + chunks=[[]], + error=_Error( + status=status, message=message, code=code, sql_state=sql_state + ), + ) + return self + + def _build( + self, + rows: List[RowInput], + columns: Optional[Sequence[ColumnSpec]], + partitions: int, + polls_before_ready: int, + stats: Optional[Dict[str, int]], + ) -> _Reg: + specs = _normalize_columns(rows, columns) + row_type: List[Dict[str, Any]] = [] + for name, ctype, scale in specs: + entry: Dict[str, Any] = {"name": name, "type": ctype, "nullable": True} + if scale is not None: + entry["scale"] = scale + row_type.append(entry) + wire: List[List[Any]] = [ + [ + _encode_cell(_cell(row, name, index), ctype) + for index, (name, ctype, _scale) in enumerate(specs) + ] + for row in rows + ] + return _Reg( + row_type=row_type, + chunks=_split(wire, partitions), + polls_before_ready=polls_before_ready, + stats=stats, + ) + + # -- transport seam --------------------------------------------------- + + @property + def transport(self) -> httpx.MockTransport: + """An ``httpx.MockTransport`` wired to this fake (for sync or async).""" + return httpx.MockTransport(self.handle) + + # -- introspection ---------------------------------------------------- + + @property + def submitted_statements(self) -> List[str]: + """Every SQL statement submitted via POST, in order.""" + out: List[str] = [] + for request in self.requests: + if request.method == "POST" and request.url.path == STATEMENTS_PATH: + out.append(json.loads(request.content).get("statement", "")) + return out + + # -- request handling ------------------------------------------------- + + def handle(self, request: httpx.Request) -> httpx.Response: + """Route a request to the right canned response (the MockTransport hook).""" + self.requests.append(request) + path = request.url.path + if request.method == "POST" and path == STATEMENTS_PATH: + return self._on_submit(request) + if request.method == "POST" and path.endswith("/cancel"): + return httpx.Response(200, json={"code": "090001"}) + if request.method == "GET" and path.startswith(STATEMENTS_PATH + "/"): + return self._on_get(request) + raise FakeSnowflakeError(f"unexpected request: {request.method} {path}") + + def _lookup(self, sql: str) -> Optional[_Reg]: + if sql in self._exact: + return self._exact[sql] + for predicate, reg in self._predicates: + if predicate(sql): + return reg + return None + + def _next_handle(self) -> str: + self._counter += 1 + return f"stmt-{self._counter}" + + def _on_submit(self, request: httpx.Request) -> httpx.Response: + sql = json.loads(request.content).get("statement", "") + reg = self._lookup(sql) + if reg is None: + raise FakeSnowflakeError(f"no result registered for statement: {sql!r}") + handle = self._next_handle() + if reg.error is not None: + err = reg.error + return httpx.Response( + err.status, + json={ + "message": err.message, + "code": err.code, + "sqlState": err.sql_state, + "statementHandle": handle, + }, + ) + self._state[handle] = _State(reg=reg, remaining_polls=reg.polls_before_ready) + async_exec = request.url.params.get("async") == "true" + if async_exec or reg.polls_before_ready > 0: + return httpx.Response(202, json=running_body(handle)) + return self._result_response(handle) + + def _on_get(self, request: httpx.Request) -> httpx.Response: + handle = request.url.path.rsplit("/", 1)[-1] + state = self._state.get(handle) + if state is None: + raise FakeSnowflakeError(f"unknown statement handle: {handle}") + partition = request.url.params.get("partition") + if partition is not None: + index = int(partition) + chunks = state.reg.chunks + if not 0 <= index < len(chunks): + # Surface partition bugs instead of masking them as an empty page. + raise FakeSnowflakeError( + f"partition {index} out of range (have {len(chunks)})" + ) + return httpx.Response(200, json={"data": chunks[index]}) + if state.remaining_polls > 0: + state.remaining_polls -= 1 + return httpx.Response(202, json=running_body(handle)) + return self._result_response(handle) + + def _result_response(self, handle: str) -> httpx.Response: + reg = self._state[handle].reg + body = ok_body( + reg.row_type, + reg.chunks[0], + partitions=len(reg.chunks), + handle=handle, + stats=reg.stats, + ) + return httpx.Response(200, json=body) + + +# --------------------------------------------------------------------------- +# Client factories +# --------------------------------------------------------------------------- + +_THROWAWAY_KEY_PEM: Optional[bytes] = None + + +def _throwaway_key() -> bytes: + """A cached throwaway RSA private key (PEM). The fake never validates it.""" + global _THROWAWAY_KEY_PEM + if _THROWAWAY_KEY_PEM is None: + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + _THROWAWAY_KEY_PEM = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return _THROWAWAY_KEY_PEM + + +#: Constructor args the factories set themselves; passing them via **kwargs would +#: raise an opaque duplicate-keyword TypeError, so reject them with a clear error. +_HELPER_MANAGED = ("private_key", "http_client") + + +def _reject_managed_kwargs(name: str, kwargs: Dict[str, Any]) -> None: + overlap = sorted(set(_HELPER_MANAGED) & kwargs.keys()) + if overlap: + raise FakeSnowflakeError( + f"{name} manages {overlap} internally; pass session/client options only" + ) + + +def make_client( + fake: FakeSnowflake, + *, + account: str = "testorg-testaccount", + user: str = "TEST_USER", + poll_interval: float = 0.0, + **kwargs: Any, +) -> "SnowflakeClient": + """Build a :class:`SnowflakeClient` wired to ``fake`` (no network). + + Extra keyword args pass straight through to the client constructor (but not + the args this factory sets itself; see ``_HELPER_MANAGED``). + """ + from .client import SnowflakeClient + + _reject_managed_kwargs("make_client", kwargs) + return SnowflakeClient( + account, + user, + private_key=_throwaway_key(), + http_client=httpx.Client(transport=fake.transport), + poll_interval=poll_interval, + **kwargs, + ) + + +def make_async_client( + fake: FakeSnowflake, + *, + account: str = "testorg-testaccount", + user: str = "TEST_USER", + poll_interval: float = 0.0, + **kwargs: Any, +) -> "AsyncSnowflakeClient": + """Build an :class:`AsyncSnowflakeClient` wired to ``fake`` (no network).""" + from .aclient import AsyncSnowflakeClient + + _reject_managed_kwargs("make_async_client", kwargs) + return AsyncSnowflakeClient( + account, + user, + private_key=_throwaway_key(), + http_client=httpx.AsyncClient(transport=fake.transport), + poll_interval=poll_interval, + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Pytest fixtures (auto-discovered via the pytest11 entry point in pyproject.toml) +# --------------------------------------------------------------------------- + +try: + import pytest + + @pytest.fixture + def fake_snowflake() -> FakeSnowflake: + """A fresh :class:`FakeSnowflake` registry.""" + return FakeSnowflake() + + @pytest.fixture + def snowflake_client(fake_snowflake: FakeSnowflake) -> "Iterator[SnowflakeClient]": + """A :class:`SnowflakeClient` wired to the ``fake_snowflake`` fixture.""" + client = make_client(fake_snowflake) + try: + yield client + finally: + client.close() + + try: + import pytest_asyncio + + @pytest_asyncio.fixture + async def async_snowflake_client( + fake_snowflake: FakeSnowflake, + ) -> "AsyncIterator[AsyncSnowflakeClient]": + """An :class:`AsyncSnowflakeClient` wired to ``fake_snowflake``.""" + client = make_async_client(fake_snowflake) + try: + yield client + finally: + await client.aclose() + + except ImportError: # pragma: no cover - async fixture is optional + pass + +except ImportError: # pragma: no cover - pytest is a test-only dependency + pass diff --git a/tests/support.py b/tests/support.py index 708bc97..84b52e3 100644 --- a/tests/support.py +++ b/tests/support.py @@ -1,8 +1,25 @@ -"""Shared constants and helpers for the HTTP-mocked test suites.""" +"""Shared constants and helpers for the HTTP-mocked test suites. + +The SQL API envelope builders (``ok_body`` / ``running_body``) live in the +shipped ``snowflake_sql_api.testing`` module so the test suite and the public +testing helper share one source of truth; they are re-exported here for the +existing respx-based tests. +""" from __future__ import annotations -from typing import Any, Dict, List, Optional +from snowflake_sql_api.testing import ok_body, running_body + +__all__ = [ + "ACCOUNT", + "USER", + "HOST", + "BASE_URL", + "STATEMENTS_URL", + "statement_url", + "ok_body", + "running_body", +] ACCOUNT = "xy12345.ap-southeast-2" USER = "test_user" @@ -13,48 +30,3 @@ def statement_url(handle: str) -> str: return f"{STATEMENTS_URL}/{handle}" - - -def ok_body( - row_type: List[Dict[str, Any]], - data: List[List[Any]], - *, - partitions: int = 1, - handle: str = "stmt-1", - stats: Optional[Dict[str, int]] = None, -) -> Dict[str, Any]: - """Build a 200 success body. ``partitions`` sets the partitionInfo length. - - Partition 0's ``rowCount`` reflects ``data``; the remaining entries are - placeholders (their rows are served by the partition GET routes). - """ - if partitions < 1: - raise ValueError("partitions must be >= 1") - partition_info = [{"rowCount": len(data)}] - partition_info += [{"rowCount": 0} for _ in range(partitions - 1)] - body: Dict[str, Any] = { - "resultSetMetaData": { - "numRows": len(data), - "format": "jsonv2", - "partitionInfo": partition_info, - "rowType": row_type, - }, - "data": data, - "code": "090001", - "sqlState": "00000", - "statementHandle": handle, - "statementStatusUrl": f"/api/v2/statements/{handle}", - } - if stats is not None: - body["stats"] = stats - return body - - -def running_body(handle: str = "stmt-1", code: str = "333334") -> Dict[str, Any]: - """Build a 202 'still running / submitted async' body.""" - return { - "code": code, - "message": "Asynchronous execution in progress.", - "statementHandle": handle, - "statementStatusUrl": f"/api/v2/statements/{handle}", - } diff --git a/tests/test_testing.py b/tests/test_testing.py new file mode 100644 index 0000000..aa7f047 --- /dev/null +++ b/tests/test_testing.py @@ -0,0 +1,341 @@ +"""Tests for the shipped testing helper (snowflake_sql_api.testing). + +Drives both the sync and async clients through ``FakeSnowflake`` with no network +and no respx, exercising the query/DML/async-submit/partition paths and the +native<->wire round trip for every coerced type. +""" + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal + +import pytest + +from snowflake_sql_api import AsyncSnowflakeClient, SnowflakeClient +from snowflake_sql_api.exceptions import ResultNotReady, SnowflakeProgrammingError +from snowflake_sql_api.testing import ( + FakeSnowflake, + FakeSnowflakeError, + make_async_client, + make_client, +) + +# --------------------------------------------------------------------------- +# Sync client through FakeSnowflake +# --------------------------------------------------------------------------- + + +def test_query_returns_registered_rows() -> None: + fake = FakeSnowflake() + fake.register( + "SELECT id, name FROM users", + [{"ID": 1, "NAME": "alice"}, {"ID": 2, "NAME": "bob"}], + ) + with make_client(fake) as client: + assert client.query("SELECT id, name FROM users") == [ + {"ID": 1, "NAME": "alice"}, + {"ID": 2, "NAME": "bob"}, + ] + + +def test_make_client_returns_real_client() -> None: + fake = FakeSnowflake() + fake.register("SELECT 1", [{"N": 1}]) + client = make_client(fake) + assert isinstance(client, SnowflakeClient) + assert client.query_scalar("SELECT 1") == 1 + client.close() + + +def test_make_client_rejects_managed_kwargs() -> None: + fake = FakeSnowflake() + with pytest.raises(FakeSnowflakeError, match="manages"): + make_client(fake, private_key=b"x") + + +def test_query_with_positional_rows_and_columns() -> None: + fake = FakeSnowflake() + fake.register( + "SELECT id, name FROM t", + [[1, "alice"], [2, "bob"]], + columns=["ID", "NAME"], + ) + with make_client(fake) as client: + assert client.query("SELECT id, name FROM t") == [ + {"ID": 1, "NAME": "alice"}, + {"ID": 2, "NAME": "bob"}, + ] + + +def test_explicit_column_types() -> None: + fake = FakeSnowflake() + fake.register( + "SELECT amount FROM t", + [{"AMOUNT": Decimal("10.50")}], + columns=[{"name": "AMOUNT", "type": "fixed", "scale": 2}], + ) + with make_client(fake) as client: + assert client.query("SELECT amount FROM t") == [{"AMOUNT": Decimal("10.50")}] + + +def test_empty_result() -> None: + fake = FakeSnowflake() + fake.register("SELECT 1 WHERE 1=0", []) + with make_client(fake) as client: + assert client.query("SELECT 1 WHERE 1=0") == [] + assert client.query_one("SELECT 1 WHERE 1=0") is None + + +def test_multi_partition_preserves_order() -> None: + fake = FakeSnowflake() + rows = [{"N": n} for n in range(4)] + fake.register("SELECT n FROM big", rows, partitions=3) + with make_client(fake) as client: + assert client.query("SELECT n FROM big") == rows + + +def test_out_of_range_partition_raises() -> None: + fake = FakeSnowflake() + fake.register("SELECT n", [{"N": 1}]) + with make_client(fake) as client: + client.query("SELECT n") # creates the single-partition handle stmt-1 + with pytest.raises(FakeSnowflakeError, match="out of range"): + client._transport.get_statement("stmt-1", partition=99) + + +def test_register_dml_returns_rowcount() -> None: + fake = FakeSnowflake() + fake.register_dml("DELETE FROM t", 7) + with make_client(fake) as client: + assert client.execute("DELETE FROM t") == 7 + + +def test_insert_many_through_fake() -> None: + fake = FakeSnowflake() + sql = 'INSERT INTO "t" ("a", "b") VALUES (?, ?), (?, ?)' + fake.register_dml(sql, 2) + with make_client(fake) as client: + assert client.insert_many("t", ["a", "b"], [[1, "x"], [2, "y"]]) == 2 + assert fake.submitted_statements == [sql] + + +def test_register_error_raises_programming_error() -> None: + fake = FakeSnowflake() + fake.register_error("SELECT bad", "SQL compilation error", code="000904") + with make_client(fake) as client, pytest.raises(SnowflakeProgrammingError) as info: + client.query("SELECT bad") + assert info.value.code == "000904" + + +def test_register_match_predicate() -> None: + fake = FakeSnowflake() + fake.register_match(lambda sql: sql.startswith("SELECT count"), [{"N": 99}]) + with make_client(fake) as client: + assert client.query_scalar("SELECT count(*) FROM t") == 99 + + +def test_unregistered_statement_raises() -> None: + # Single `with` per statement here: a parenthesized multi-context `with` + # is 3.10+ syntax and breaks the 3.9 leg of the test matrix. + fake = FakeSnowflake() + client = make_client(fake) + with pytest.raises(FakeSnowflakeError, match="no result registered"): + client.query("SELECT nope") + client.close() + + +def test_submitted_statements_records_order() -> None: + fake = FakeSnowflake() + fake.register("SELECT 1", [{"N": 1}]) + fake.register("SELECT 2", [{"N": 2}]) + with make_client(fake) as client: + client.query("SELECT 1") + client.query("SELECT 2") + assert fake.submitted_statements == ["SELECT 1", "SELECT 2"] + + +# --------------------------------------------------------------------------- +# Async-submit / polling behavior +# --------------------------------------------------------------------------- + + +def test_submit_then_result() -> None: + fake = FakeSnowflake() + fake.register("CALL proc()", [{"N": 1}]) + with make_client(fake) as client: + handle = client.submit("CALL proc()") + assert handle.status() == "SUCCESS" + assert handle.result() == [{"N": 1}] + + +def test_result_poll_false_raises_on_running() -> None: + fake = FakeSnowflake() + fake.register("CALL slow()", [{"N": 1}], polls_before_ready=1) + with make_client(fake) as client: + handle = client.submit("CALL slow()") + with pytest.raises(ResultNotReady): + handle.result(poll=False) + + +def test_query_polls_until_ready() -> None: + fake = FakeSnowflake() + fake.register("SELECT slow", [{"N": 42}], polls_before_ready=2) + with make_client(fake) as client: + assert client.query("SELECT slow") == [{"N": 42}] + + +def test_handle_cancel() -> None: + fake = FakeSnowflake() + fake.register("CALL slow()", [{"N": 1}], polls_before_ready=5) + with make_client(fake) as client: + handle = client.submit("CALL slow()") + handle.cancel() # should not raise + + +# --------------------------------------------------------------------------- +# Type round-trips (native -> wire -> coerced native) +# --------------------------------------------------------------------------- + +TZ = timezone(timedelta(hours=5, minutes=30)) + +ROUND_TRIP = { + "an_int": 42, + "a_float": 1.5, + "a_str": "hello", + "a_bool": True, + "a_false": False, + "a_decimal": Decimal("12.34"), + "a_date": date(2023, 1, 1), + "a_time": time(23, 1, 59, 456789), + "a_ntz": datetime(2023, 1, 1, 12, 30, 45), + "a_tz": datetime(2023, 1, 1, 12, 30, 45, tzinfo=TZ), + "a_variant_obj": {"k": "v", "n": 1}, + "a_variant_arr": [1, 2, 3], + "a_binary": b"\x01\x02\xff", +} + + +@pytest.mark.parametrize("name, value", list(ROUND_TRIP.items())) +def test_value_round_trips(name: str, value: object) -> None: + fake = FakeSnowflake() + fake.register("SELECT v", [{"V": value}]) + with make_client(fake) as client: + assert client.query_scalar("SELECT v") == value + + +@pytest.mark.parametrize( + "value", + [ + Decimal("12.34"), + Decimal("100"), + Decimal("100.0"), + Decimal("1E+2"), + Decimal("-0.5"), + ], +) +def test_decimal_values_round_trip(value: Decimal) -> None: + # Scientific-notation Decimals (Decimal("1E+2")) must encode as plain digits, + # not "1E+2", or the decoder's int()/Decimal() parse blows up. + fake = FakeSnowflake() + fake.register("SELECT d", [{"D": value}]) + with make_client(fake) as client: + assert client.query_scalar("SELECT d") == value + + +def test_null_in_typed_column() -> None: + fake = FakeSnowflake() + fake.register("SELECT n FROM t", [{"N": 1}, {"N": None}]) + with make_client(fake) as client: + assert client.query("SELECT n FROM t") == [{"N": 1}, {"N": None}] + + +def test_timestamp_ltz_round_trips() -> None: + fake = FakeSnowflake() + aware = datetime(2023, 6, 1, 8, 0, 0, tzinfo=timezone.utc) + fake.register( + "SELECT ts", + [{"TS": aware}], + columns=[{"name": "TS", "type": "timestamp_ltz"}], + ) + with make_client(fake) as client: + assert client.query_scalar("SELECT ts") == aware + + +@pytest.mark.regression +def test_regression_pre_epoch_subsecond_timestamps() -> None: + """Pre-1970 sub-second instants must round-trip. + + The encoder derives wire seconds/micros from a timedelta, whose fraction is + always positive; the decoder reads the fraction as signed. A naive split + produced ``-1.500000`` for -0.5s, decoding one second early. The fix carries + the sign into the fraction (``-0.500000``). + """ + ntz = datetime(1969, 12, 31, 23, 59, 59, 500000) + ltz = datetime(1969, 12, 31, 23, 59, 59, 500000, tzinfo=timezone.utc) + tz = datetime(1969, 12, 31, 23, 59, 59, 500000, tzinfo=TZ) + fake = FakeSnowflake() + fake.register( + "SELECT ts_ntz, ts_ltz, ts_tz", + [{"NTZ": ntz, "LTZ": ltz, "TZ": tz}], + columns=[ + {"name": "NTZ", "type": "timestamp_ntz"}, + {"name": "LTZ", "type": "timestamp_ltz"}, + {"name": "TZ", "type": "timestamp_tz"}, + ], + ) + with make_client(fake) as client: + assert client.query_one("SELECT ts_ntz, ts_ltz, ts_tz") == { + "NTZ": ntz, + "LTZ": ltz, + "TZ": tz, + } + + +# --------------------------------------------------------------------------- +# Async client through FakeSnowflake +# --------------------------------------------------------------------------- + + +async def test_async_query() -> None: + fake = FakeSnowflake() + fake.register("SELECT id FROM t", [{"ID": 1}, {"ID": 2}]) + client = make_async_client(fake) + assert isinstance(client, AsyncSnowflakeClient) + assert await client.query("SELECT id FROM t") == [{"ID": 1}, {"ID": 2}] + await client.aclose() + + +async def test_async_submit_and_result() -> None: + fake = FakeSnowflake() + fake.register("CALL proc()", [{"N": 5}], polls_before_ready=1) + async with make_async_client(fake) as client: + handle = await client.submit("CALL proc()") + assert await handle.result() == [{"N": 5}] + + +async def test_async_multi_partition() -> None: + fake = FakeSnowflake() + rows = [{"N": n} for n in range(5)] + fake.register("SELECT n FROM big", rows, partitions=2) + async with make_async_client(fake) as client: + assert await client.query("SELECT n FROM big") == rows + + +# --------------------------------------------------------------------------- +# Auto-registered pytest fixtures +# --------------------------------------------------------------------------- + + +def test_fixtures_sync( + fake_snowflake: FakeSnowflake, snowflake_client: SnowflakeClient +) -> None: + fake_snowflake.register("SELECT 1", [{"N": 1}]) + assert snowflake_client.query_scalar("SELECT 1") == 1 + + +async def test_fixtures_async( + fake_snowflake: FakeSnowflake, async_snowflake_client: AsyncSnowflakeClient +) -> None: + fake_snowflake.register("SELECT 2", [{"N": 2}]) + assert await async_snowflake_client.query_scalar("SELECT 2") == 2