Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ type s3Downloader interface {

// IndexDownloader downloads and validates the hansard SQLite search index from S3.
type IndexDownloader struct {
s3 s3Downloader
bucket string
s3 s3Downloader
bucket string
localPath string
}

// NewIndexDownloaderFromEnv constructs an IndexDownloader from environment variables.
Expand All @@ -51,7 +52,7 @@ func NewIndexDownloaderFromEnv(ctx context.Context) (*IndexDownloader, error) {

// NewIndexDownloader constructs an IndexDownloader from explicit dependencies (for testing).
func NewIndexDownloader(s3Client s3Downloader, bucket string) *IndexDownloader {
return &IndexDownloader{s3: s3Client, bucket: bucket}
return &IndexDownloader{s3: s3Client, bucket: bucket, localPath: localIndexPath}
}

// Download fetches s3://bucket/sqliteKey to /tmp/index.sqlite using streaming I/O,
Expand All @@ -67,7 +68,7 @@ func (d *IndexDownloader) Download(ctx context.Context, sqliteKey, expectedSHA25
}
defer out.Body.Close()

f, err := os.Create(localIndexPath)
f, err := os.Create(d.localPath)
if err != nil {
return "", fmt.Errorf("create local index file: %w", err)
}
Expand All @@ -86,11 +87,11 @@ func (d *IndexDownloader) Download(ctx context.Context, sqliteKey, expectedSHA25
return "", usecase.ErrChecksumMismatch
}

if err := verifySchemaVersion(localIndexPath); err != nil {
if err := verifySchemaVersion(d.localPath); err != nil {
return "", err
}

return localIndexPath, nil
return d.localPath, nil
}

func verifySchemaVersion(path string) error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,26 @@ func sha256hex(data []byte) string {

func TestIndexDownloader_HappyPath(t *testing.T) {
_, content := createTestDB(t, "v1")
localPath := filepath.Join(t.TempDir(), "index.sqlite")

expectedHash := sha256hex(content)
mock := &mockS3Downloader{body: content}

// Override localIndexPath to write to a temp location.
// The downloader writes to /tmp/index.sqlite; in tests we accept this.
d := &IndexDownloader{s3: mock, bucket: "test-bucket"}
d := &IndexDownloader{s3: mock, bucket: "test-bucket", localPath: localPath}
got, err := d.Download(context.Background(), "some/key", expectedHash)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != localIndexPath {
t.Errorf("want %q, got %q", localIndexPath, got)
if got != localPath {
t.Errorf("want %q, got %q", localPath, got)
}
}

func TestIndexDownloader_ChecksumMismatch(t *testing.T) {
_, content := createTestDB(t, "v1")

mock := &mockS3Downloader{body: content}
d := &IndexDownloader{s3: mock, bucket: "test-bucket"}
d := &IndexDownloader{s3: mock, bucket: "test-bucket", localPath: filepath.Join(t.TempDir(), "index.sqlite")}

_, err := d.Download(context.Background(), "some/key", "deadbeef")
if !errors.Is(err, usecase.ErrChecksumMismatch) {
Expand All @@ -104,7 +103,7 @@ func TestIndexDownloader_SchemaMismatch(t *testing.T) {

expectedHash := sha256hex(content)
mock := &mockS3Downloader{body: content}
d := &IndexDownloader{s3: mock, bucket: "test-bucket"}
d := &IndexDownloader{s3: mock, bucket: "test-bucket", localPath: filepath.Join(t.TempDir(), "index.sqlite")}

_, err := d.Download(context.Background(), "some/key", expectedHash)
if !errors.Is(err, usecase.ErrSchemaMismatch) {
Expand Down
Loading