diff --git a/extensions/tn_attestation/broadcast.go b/extensions/tn_attestation/broadcast.go new file mode 100644 index 000000000..31486adc4 --- /dev/null +++ b/extensions/tn_attestation/broadcast.go @@ -0,0 +1,132 @@ +package tn_attestation + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + + "github.com/trufnetwork/kwil-db/common" + rpcclient "github.com/trufnetwork/kwil-db/core/rpc/client" + userjsonrpc "github.com/trufnetwork/kwil-db/core/rpc/client/user/jsonrpc" + "github.com/trufnetwork/kwil-db/core/types" +) + +// ensureBroadcaster initializes the RPC client for transaction submission if not already set. +// Called during startup and leader acquisition to ensure the leader can broadcast sign_attestation +// transactions. Prefers extension-specific rpc_url config, falling back to node's RPC endpoint. +func (e *signerExtension) ensureBroadcaster(service *common.Service) { + if service == nil || service.LocalConfig == nil { + return + } + if e.Broadcaster() != nil && e.TxQueryClient() != nil { + return + } + + endpoint := "" + if cfg, ok := service.LocalConfig.Extensions[ExtensionName]; ok { + if v := strings.TrimSpace(cfg["rpc_url"]); v != "" { + endpoint = v + } + } + + if endpoint == "" && service.LocalConfig.RPC.ListenAddress != "" { + endpoint = service.LocalConfig.RPC.ListenAddress + } + if endpoint == "" { + e.Logger().Warn("tn_attestation: cannot build broadcaster (no rpc endpoint configured)") + return + } + + u, err := normalizeListenAddressForClient(endpoint) + if err != nil { + e.Logger().Warn("tn_attestation: invalid rpc endpoint", "endpoint", endpoint, "error", err) + return + } + + broadcaster, queryClient := makeBroadcasterFromURL(u) + e.setBroadcaster(broadcaster) + e.setTxQueryClient(queryClient) + e.startStatusWorker() +} + +// normalizeListenAddressForClient converts a server bind address (e.g., "0.0.0.0:8080") +// to a client-usable localhost URL. Needed because the extension runs on the same node +// but cannot connect to wildcard addresses like 0.0.0.0 or [::]. +func normalizeListenAddressForClient(listen string) (*url.URL, error) { + if listen == "" { + return nil, fmt.Errorf("empty listen address") + } + endpoint := listen + if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { + endpoint = "http://" + endpoint + } + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + host, port, err := net.SplitHostPort(u.Host) + if err == nil { + clean := strings.Trim(host, "[]") + if clean == "" { + u.Host = net.JoinHostPort("127.0.0.1", port) + } else if ip := net.ParseIP(clean); ip != nil && ip.IsUnspecified() { + u.Host = net.JoinHostPort("127.0.0.1", port) + } + } + return u, nil +} + +// makeBroadcasterFromURL creates broadcaster and query client from the normalized RPC endpoint. +// Returns the same instance for both interfaces to share a single connection pool. +func makeBroadcasterFromURL(u *url.URL) (TxBroadcaster, TxQueryClient) { + client := userjsonrpc.NewClient(u) + br := &jsonRPCBroadcaster{client: client} + return br, br +} + +type jsonRPCBroadcaster struct { + client *userjsonrpc.Client +} + +func (b *jsonRPCBroadcaster) BroadcastTx(ctx context.Context, tx *types.Transaction, sync uint8) (types.Hash, *types.TxResult, error) { + mode := rpcclient.BroadcastWaitAccept + if sync == uint8(rpcclient.BroadcastWaitCommit) || sync == 1 { + mode = rpcclient.BroadcastWaitCommit + } + + hash, err := b.client.Broadcast(ctx, tx, mode) + if err != nil { + return types.Hash{}, nil, err + } + + if mode == rpcclient.BroadcastWaitAccept { + return hash, nil, nil + } + + resp, err := b.client.TxQuery(ctx, hash) + if err != nil { + return hash, nil, fmt.Errorf("tx query failed: %w", err) + } + if resp == nil || resp.Result == nil { + return hash, nil, fmt.Errorf("transaction result missing") + } + + return hash, resp.Result, nil +} + +func (b *jsonRPCBroadcaster) TxQuery(ctx context.Context, txHash types.Hash) (*types.TxQueryResponse, error) { + return b.client.TxQuery(ctx, txHash) +} + +// TxBroadcaster matches the subset of the JSON-RPC client used by the signing +// worker to inject transactions. +type TxBroadcaster interface { + BroadcastTx(ctx context.Context, tx *types.Transaction, sync uint8) (types.Hash, *types.TxResult, error) +} + +// TxQueryClient interface for querying transaction status. +type TxQueryClient interface { + TxQuery(ctx context.Context, txHash types.Hash) (*types.TxQueryResponse, error) +} diff --git a/extensions/tn_attestation/canonical.go b/extensions/tn_attestation/canonical.go new file mode 100644 index 000000000..0329e8d42 --- /dev/null +++ b/extensions/tn_attestation/canonical.go @@ -0,0 +1,114 @@ +package tn_attestation + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" +) + +// CanonicalPayload represents the eight attestation fields stored in result_canonical. +// The byte layout mirrors the SQL migration: fixed-width integers followed by +// length-prefixed blobs (little-endian 4-byte prefixes for variable sections). +// +// Layout: +// +// 1 byte version +// 1 byte algorithm +// 8 bytes block height (big-endian) +// 4 + n data provider (length-prefixed) +// 4 + m stream ID (length-prefixed) +// 2 bytes action ID (big-endian) +// 4 + k arguments (length-prefixed) +// 4 + r result (length-prefixed) +type CanonicalPayload struct { + Version uint8 + Algorithm uint8 + BlockHeight uint64 + DataProvider []byte + StreamID []byte + ActionID uint16 + Args []byte + Result []byte + + raw []byte +} + +// ParseCanonicalPayload decodes the canonical payload into structured fields. +// The function validates every length prefix and returns descriptive errors so +// future maintainers can diagnose storage corruption quickly. +func ParseCanonicalPayload(data []byte) (*CanonicalPayload, error) { + if len(data) < 1+1+8+2 { + return nil, fmt.Errorf("canonical payload too short: got %d bytes", len(data)) + } + + cursor := 0 + payload := &CanonicalPayload{ + Version: data[cursor], + Algorithm: data[cursor+1], + } + cursor += 2 + + payload.BlockHeight = binary.BigEndian.Uint64(data[cursor : cursor+8]) + cursor += 8 + + var err error + if payload.DataProvider, cursor, err = readLengthPrefixed(data, cursor); err != nil { + return nil, fmt.Errorf("decode data_provider: %w", err) + } + if payload.StreamID, cursor, err = readLengthPrefixed(data, cursor); err != nil { + return nil, fmt.Errorf("decode stream_id: %w", err) + } + + if len(data) < cursor+2 { + return nil, fmt.Errorf("canonical payload truncated before action_id") + } + payload.ActionID = binary.BigEndian.Uint16(data[cursor : cursor+2]) + cursor += 2 + + if payload.Args, cursor, err = readLengthPrefixed(data, cursor); err != nil { + return nil, fmt.Errorf("decode args: %w", err) + } + if payload.Result, cursor, err = readLengthPrefixed(data, cursor); err != nil { + return nil, fmt.Errorf("decode result: %w", err) + } + + if cursor != len(data) { + return nil, fmt.Errorf("canonical payload has %d trailing bytes", len(data)-cursor) + } + + payload.raw = append(payload.raw[:0], data...) // ensure private copy + return payload, nil +} + +// SigningBytes returns the backing canonical bytes that must be covered by the +// validator's signature (fields 1 through 8). Callers should treat the slice as +// immutable. +func (p *CanonicalPayload) SigningBytes() []byte { + return p.raw +} + +// SigningDigest computes sha256(SigningBytes()) to match the on-chain verifier +// expectations. The digest is returned as a value to prevent accidental reuse of +// the backing slice. +func (p *CanonicalPayload) SigningDigest() [sha256.Size]byte { + return sha256.Sum256(p.SigningBytes()) +} + +// readLengthPrefixed decodes a little-endian uint32 length followed by that many bytes. +func readLengthPrefixed(data []byte, cursor int) ([]byte, int, error) { + if len(data) < cursor+4 { + return nil, cursor, fmt.Errorf("truncated length prefix at offset %d", cursor) + } + + length := binary.LittleEndian.Uint32(data[cursor : cursor+4]) + cursor += 4 + + if len(data) < cursor+int(length) { + return nil, cursor, fmt.Errorf("declared length %d exceeds remaining %d bytes", length, len(data)-cursor) + } + + chunk := data[cursor : cursor+int(length)] + cursor += int(length) + return bytes.Clone(chunk), cursor, nil +} diff --git a/extensions/tn_attestation/canonical_test.go b/extensions/tn_attestation/canonical_test.go new file mode 100644 index 000000000..5d5c08f7a --- /dev/null +++ b/extensions/tn_attestation/canonical_test.go @@ -0,0 +1,94 @@ +package tn_attestation + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseCanonicalPayload_Success(t *testing.T) { + version := uint8(1) + algo := uint8(1) + height := uint64(12345) + actionID := uint16(9) + dataProvider := []byte("provider-1") + streamID := []byte("stream-xyz") + args := []byte{0x01, 0x02, 0x03} + result := []byte{0xAA, 0xBB} + + raw := buildCanonical(version, algo, height, dataProvider, streamID, actionID, args, result) + + payload, err := ParseCanonicalPayload(raw) + require.NoError(t, err) + require.NotNil(t, payload) + + require.Equal(t, version, payload.Version) + require.Equal(t, algo, payload.Algorithm) + require.Equal(t, height, payload.BlockHeight) + require.Equal(t, dataProvider, payload.DataProvider) + require.Equal(t, streamID, payload.StreamID) + require.Equal(t, actionID, payload.ActionID) + require.Equal(t, args, payload.Args) + require.Equal(t, result, payload.Result) + + // Signing digest should equal sha256(raw) + expectedDigest := sha256.Sum256(raw) + require.Equal(t, expectedDigest, payload.SigningDigest()) + require.True(t, bytes.Equal(raw, payload.SigningBytes())) +} + +func TestParseCanonicalPayload_TruncatedPrefix(t *testing.T) { + base := buildCanonical(1, 1, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02}) + // Corrupt by chopping last byte + corrupted := base[:len(base)-1] + + _, err := ParseCanonicalPayload(corrupted) + require.Error(t, err) + require.Contains(t, err.Error(), "decode result") +} + +func TestParseCanonicalPayload_ExtraBytes(t *testing.T) { + base := buildCanonical(1, 1, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02}) + extra := append(base, []byte{0xFF, 0xFF}...) + + _, err := ParseCanonicalPayload(extra) + require.Error(t, err) + require.Contains(t, err.Error(), "trailing bytes") +} + +// buildCanonical mirrors the SQL encoder to generate canonical payloads. +func buildCanonical(version, algo uint8, height uint64, provider, stream []byte, actionID uint16, args, result []byte) []byte { + buf := bytes.NewBuffer(nil) + buf.WriteByte(version) + buf.WriteByte(algo) + + heightBytes := make([]byte, 8) + binary.BigEndian.PutUint64(heightBytes, height) + buf.Write(heightBytes) + + lengthBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(lengthBytes, uint32(len(provider))) + buf.Write(lengthBytes) + buf.Write(provider) + + binary.LittleEndian.PutUint32(lengthBytes, uint32(len(stream))) + buf.Write(lengthBytes) + buf.Write(stream) + + actionBytes := make([]byte, 2) + binary.BigEndian.PutUint16(actionBytes, actionID) + buf.Write(actionBytes) + + binary.LittleEndian.PutUint32(lengthBytes, uint32(len(args))) + buf.Write(lengthBytes) + buf.Write(args) + + binary.LittleEndian.PutUint32(lengthBytes, uint32(len(result))) + buf.Write(lengthBytes) + buf.Write(result) + + return buf.Bytes() +} diff --git a/extensions/tn_attestation/constants.go b/extensions/tn_attestation/constants.go index 183b71dc4..fec8a15f9 100644 --- a/extensions/tn_attestation/constants.go +++ b/extensions/tn_attestation/constants.go @@ -1,3 +1,4 @@ package tn_attestation +// ExtensionName is the identifier for the attestation signing extension. const ExtensionName = "tn_attestation" diff --git a/extensions/tn_attestation/doc.go b/extensions/tn_attestation/doc.go new file mode 100644 index 000000000..339b775de --- /dev/null +++ b/extensions/tn_attestation/doc.go @@ -0,0 +1,15 @@ +// Package tn_attestation implements the attestation signing workflow for TN validators. +// +// When a user requests an attestation via request_attestation (SQL), the extension: +// 1. Queues the hash via queue_for_signing precompile (non-deterministic, leader-only) +// 2. Processes queued hashes on leader's EndBlock +// 3. Signs canonical payloads using the validator's secp256k1 key +// 4. Broadcasts sign_attestation transactions back to consensus +// +// Key components: +// - ValidatorSigner: Thread-safe secp256k1 signing with EVM compatibility +// - CanonicalPayload: Structured representation of the 8-field attestation format +// - Leader callbacks: OnAcquire, OnLose, OnEndBlock lifecycle hooks +// +// Initialize the extension by calling InitializeExtension() during node startup. +package tn_attestation diff --git a/extensions/tn_attestation/extension.go b/extensions/tn_attestation/extension.go new file mode 100644 index 000000000..1d7079c0f --- /dev/null +++ b/extensions/tn_attestation/extension.go @@ -0,0 +1,286 @@ +package tn_attestation + +import ( + "context" + "fmt" + "sync" + + "github.com/trufnetwork/kwil-db/common" + "github.com/trufnetwork/kwil-db/core/crypto/auth" + "github.com/trufnetwork/kwil-db/core/log" + sql "github.com/trufnetwork/kwil-db/node/types/sql" +) + +// signerExtension captures node-level wiring required for the attestation signer. +// The struct will evolve as we thread additional dependencies (engine, accounts, +// signer, broadcaster, etc.) through the extension during subsequent steps. +type signerExtension struct { + logger log.Logger + service *common.Service + + scanIntervalBlocks int64 + scanBatchLimit int64 + lastScanHeight int64 + isLeader bool + + engine common.Engine + db sql.DB + accounts common.Accounts + + broadcaster TxBroadcaster + txQueryClient TxQueryClient + nodeSigner auth.Signer + + statusOnce sync.Once + statusQueue chan txStatusWork + + processOverride func(context.Context, []string) + + mu sync.RWMutex +} + +var ( + extensionOnce sync.Once + extensionInst *signerExtension +) + +// getExtension returns the singleton instance, initialising it lazily so tests +// can replace or reset state as needed. +func getExtension() *signerExtension { + extensionOnce.Do(func() { + extensionInst = &signerExtension{ + logger: log.New(log.WithLevel(log.LevelInfo)).New(ExtensionName), + scanIntervalBlocks: 100, + scanBatchLimit: 100, + } + }) + return extensionInst +} + +// SetExtension allows tests to inject a pre-configured instance. +func SetExtension(ext *signerExtension) { + extensionInst = ext +} + +// Logger provides the extension logger, defaulting to a module-specific child of +// the global logger. +func (e *signerExtension) Logger() log.Logger { + e.mu.RLock() + defer e.mu.RUnlock() + return e.logger +} + +// Service retrieves the cached service pointer. The service includes configs, +// identity, and logger; storing it lets the extension re-use those resources +// outside hook invocations. +func (e *signerExtension) Service() *common.Service { + e.mu.RLock() + defer e.mu.RUnlock() + return e.service +} + +// setService captures the service and refreshes the module logger. +func (e *signerExtension) setService(svc *common.Service) { + e.mu.Lock() + defer e.mu.Unlock() + e.service = svc + if svc != nil && svc.Logger != nil { + e.logger = svc.Logger.New(ExtensionName) + } +} + +// applyConfig reads extension-specific config values from service.LocalConfig.Extensions[ExtensionName]. +// Supports scan_interval_blocks and scan_batch_limit for tuning fallback DB scan behavior. +func (e *signerExtension) applyConfig(service *common.Service) { + if service == nil || service.LocalConfig == nil { + return + } + if cfg, ok := service.LocalConfig.Extensions[ExtensionName]; ok { + if v, ok := cfg["scan_interval_blocks"]; ok && v != "" { + if parsed, err := parsePositiveInt64(v); err == nil { + e.setScanIntervalBlocks(parsed) + } else { + e.Logger().Warn("invalid scan_interval_blocks; using default", "value", v, "error", err) + } + } + if v, ok := cfg["scan_batch_limit"]; ok && v != "" { + if parsed, err := parsePositiveInt64(v); err == nil { + e.setScanBatchLimit(parsed) + } else { + e.Logger().Warn("invalid scan_batch_limit; using default", "value", v, "error", err) + } + } + } +} + +// setApp captures references to engine, database, and accounts subsystems from the app. +// Called during hook registration to wire runtime dependencies. +func (e *signerExtension) setApp(app *common.App) { + e.mu.Lock() + defer e.mu.Unlock() + if app != nil { + e.engine = app.Engine + e.db = app.DB + e.accounts = app.Accounts + } +} + +func (e *signerExtension) Engine() common.Engine { + e.mu.RLock() + defer e.mu.RUnlock() + return e.engine +} + +func (e *signerExtension) DB() sql.DB { + e.mu.RLock() + defer e.mu.RUnlock() + return e.db +} + +func (e *signerExtension) Accounts() common.Accounts { + e.mu.RLock() + defer e.mu.RUnlock() + return e.accounts +} + +func (e *signerExtension) setBroadcaster(b TxBroadcaster) { + e.mu.Lock() + defer e.mu.Unlock() + e.broadcaster = b +} + +func (e *signerExtension) Broadcaster() TxBroadcaster { + e.mu.RLock() + defer e.mu.RUnlock() + return e.broadcaster +} + +func (e *signerExtension) setTxQueryClient(q TxQueryClient) { + e.mu.Lock() + defer e.mu.Unlock() + e.txQueryClient = q +} + +func (e *signerExtension) TxQueryClient() TxQueryClient { + e.mu.RLock() + defer e.mu.RUnlock() + return e.txQueryClient +} + +func (e *signerExtension) setNodeSigner(s auth.Signer) { + e.mu.Lock() + defer e.mu.Unlock() + e.nodeSigner = s +} + +func (e *signerExtension) NodeSigner() auth.Signer { + e.mu.RLock() + defer e.mu.RUnlock() + return e.nodeSigner +} + +func (e *signerExtension) setProcessOverride(fn func(context.Context, []string)) { + e.mu.Lock() + defer e.mu.Unlock() + e.processOverride = fn +} + +func (e *signerExtension) setLeader(isLeader bool, height int64) { + e.mu.Lock() + defer e.mu.Unlock() + e.isLeader = isLeader + if isLeader && height > 0 { + e.lastScanHeight = height + } +} + +func (e *signerExtension) Leader() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.isLeader +} + +func (e *signerExtension) setScanIntervalBlocks(v int64) { + e.mu.Lock() + defer e.mu.Unlock() + if v > 0 { + e.scanIntervalBlocks = v + } +} + +func (e *signerExtension) setScanBatchLimit(v int64) { + e.mu.Lock() + defer e.mu.Unlock() + if v > 0 { + e.scanBatchLimit = v + } +} + +func (e *signerExtension) ScanIntervalBlocks() int64 { + e.mu.RLock() + defer e.mu.RUnlock() + if e.scanIntervalBlocks <= 0 { + return 100 + } + return e.scanIntervalBlocks +} + +func (e *signerExtension) ScanBatchLimit() int64 { + e.mu.RLock() + defer e.mu.RUnlock() + if e.scanBatchLimit <= 0 { + return 100 + } + return e.scanBatchLimit +} + +func (e *signerExtension) recordScanHeight(height int64) { + e.mu.Lock() + defer e.mu.Unlock() + if height > e.lastScanHeight { + e.lastScanHeight = height + } +} + +func (e *signerExtension) LastScanHeight() int64 { + e.mu.RLock() + defer e.mu.RUnlock() + return e.lastScanHeight +} + +func (e *signerExtension) shouldPerformScan(height int64) bool { + interval := e.ScanIntervalBlocks() + if interval <= 0 { + return false + } + + last := e.LastScanHeight() + if last == 0 { + e.recordScanHeight(height) + return true + } + + if height-last >= interval { + e.recordScanHeight(height) + return true + } + return false +} + +func parsePositiveInt64(raw string) (int64, error) { + var v int64 + _, err := fmt.Sscan(raw, &v) + if err != nil { + return 0, err + } + if v <= 0 { + return 0, fmt.Errorf("value must be positive, got %d", v) + } + return v, nil +} + +// getTxQueryClient retrieves the query client for transaction status polling. +// Returns nil if broadcaster not yet initialized. +func (e *signerExtension) getTxQueryClient() TxQueryClient { + return e.TxQueryClient() +} diff --git a/extensions/tn_attestation/harness_integration_test.go b/extensions/tn_attestation/harness_integration_test.go new file mode 100644 index 000000000..75da23e29 --- /dev/null +++ b/extensions/tn_attestation/harness_integration_test.go @@ -0,0 +1,438 @@ +//go:build kwiltest + +package tn_attestation + +import ( + "context" + "encoding/hex" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/trufnetwork/kwil-db/common" + "github.com/trufnetwork/kwil-db/config" + kcrypto "github.com/trufnetwork/kwil-db/core/crypto" + "github.com/trufnetwork/kwil-db/core/crypto/auth" + "github.com/trufnetwork/kwil-db/core/log" + ktypes "github.com/trufnetwork/kwil-db/core/types" + "github.com/trufnetwork/kwil-db/extensions/precompiles" + erc20shim "github.com/trufnetwork/kwil-db/node/exts/erc20-bridge/erc20" + orderedsync "github.com/trufnetwork/kwil-db/node/exts/ordered-sync" + kwilTesting "github.com/trufnetwork/kwil-db/testing" + databasesize "github.com/trufnetwork/node/extensions/database-size" + "github.com/trufnetwork/node/extensions/tn_cache" + "github.com/trufnetwork/node/extensions/tn_utils" + "github.com/trufnetwork/node/internal/migrations" + "github.com/trufnetwork/sdk-go/core/util" +) + +func init() { + // Register extension precompiles for tests (except tn_attestation which tests handle individually) + err := precompiles.RegisterInitializer(tn_cache.ExtensionName, tn_cache.InitializeCachePrecompile) + if err != nil { + panic("failed to register tn_cache precompiles: " + err.Error()) + } + + err = precompiles.RegisterInitializer(databasesize.ExtensionName, databasesize.InitializeDatabaseSizePrecompile) + if err != nil { + panic("failed to register database_size precompiles: " + err.Error()) + } + + tn_utils.InitializeExtension() + // Note: tn_attestation precompile is registered by individual tests via ensurePrecompileRegistered() +} + +func TestSigningWorkflowWithHarness(t *testing.T) { + // Integration test covering the complete production signing workflow: + // request_attestation (SQL) → prepareSigningWork (Go) → submitSignature (Go) + // → transaction marshaling → sign_attestation (SQL) with leader authorization. + // Tests that real migrations work correctly with transaction encoding/decoding. + const ( + testActionName = "harness_attestation_action" + testActionID = 21 + attestedValue = int64(9001) + ) + + // Reset extension singletons before test to avoid conflicts + orderedsync.ForTestingReset() + erc20shim.ForTestingResetSingleton() + erc20shim.ForTestingClearAllInstances(context.Background(), nil) + + // Ensure tn_attestation precompile is registered (needed for queue_for_signing in migrations). + // Track whether we registered it so we can clean up afterwards and not interfere with + // other tests that expect to perform the registration themselves. + registered := precompiles.RegisteredPrecompiles() + _, alreadyRegistered := registered[ExtensionName] + ensurePrecompileRegistered(t) + if !alreadyRegistered { + defer delete(precompiles.RegisteredPrecompiles(), ExtensionName) + } + + ownerAddr := util.Unsafe_NewEthereumAddressFromString("0x0000000000000000000000000000000000000a22") + requesterAddrValue := util.Unsafe_NewEthereumAddressFromString("0xabc0000000000000000000000000000000000a22") + requesterAddr := &requesterAddrValue + + options := &kwilTesting.Options{ + UseTestContainer: true, + SetupMetaStore: true, + } + + kwilTesting.RunSchemaTest(t, kwilTesting.SchemaTest{ + Name: "tn_attestation_signing_harness", + SeedScripts: migrations.GetSeedScriptPaths(), + Owner: ownerAddr.Address(), + FunctionTests: []kwilTesting.TestFunc{ + func(ctx context.Context, platform *kwilTesting.Platform) error { + platform.Deployer = ownerAddr.Bytes() + + // Provision a lightweight action and allowlist entry so the request + // mirrors production: the Go test intentionally exercises the exact SQL + // path that nodes run when users hit the public API. + require.NoError(t, setupTestAttestationAction(ctx, platform, testActionName, testActionID)) + + // Request the attestation through the live migration. This ensures the + // canonical payload we inspect later is produced by the SQL we ship. + dataProvider := []byte("provider-harness") + streamID := []byte("stream-harness") + argsBytes, err := tn_utils.EncodeActionArgs([]any{attestedValue}) + require.NoError(t, err) + + engineCtx := newHarnessEngineContext(ctx, platform, requesterAddr) + + var attestationHash []byte + _, err = platform.Engine.Call(engineCtx, platform.DB, "", "request_attestation", []any{ + dataProvider, + streamID, + testActionName, + argsBytes, + false, + int64(0), + }, func(row *common.Row) error { + if len(row.Values) != 1 { + return fmt.Errorf("expected single return value, got %d", len(row.Values)) + } + hash, ok := row.Values[0].([]byte) + if !ok { + return fmt.Errorf("expected BYTEA return, got %T", row.Values[0]) + } + attestationHash = append([]byte(nil), hash...) + return nil + }) + require.NoError(t, err, "request_attestation failed") + require.NotEmpty(t, attestationHash, "request_attestation should return attestation hash") + + // At this point we expect a single row inserted into the persisted + // table. Fetch it back and validate every column so future changes that + // alter canonical layout or metadata will trip this test. + stored := fetchAttestationRowHarness(t, ctx, platform, attestationHash) + require.Equal(t, attestationHash, stored.attestationHash) + require.Equal(t, requesterAddr.Bytes(), stored.requester) + require.NotEmpty(t, stored.resultCanonical, "canonical payload should be stored") + require.False(t, stored.encryptSig, "encrypt_sig must be false in MVP") + require.Nil(t, stored.signature, "signature should be NULL before signing") + require.Nil(t, stored.validatorPubKey, "validator_pubkey should be NULL before signing") + require.Nil(t, stored.signedHeight, "signed_height should be NULL before signing") + + // The canonical blob should round-trip through the Go parser; we assert + // the critical fields so the SQL encoder and Go decoder stay in lockstep. + payload, err := ParseCanonicalPayload(stored.resultCanonical) + require.NoError(t, err, "canonical payload should be parseable") + require.Equal(t, uint8(1), payload.Version) + require.Equal(t, uint8(1), payload.Algorithm) + require.Equal(t, dataProvider, payload.DataProvider) + require.Equal(t, streamID, payload.StreamID) + require.Equal(t, uint16(testActionID), payload.ActionID) + require.Equal(t, argsBytes, payload.Args) + require.NotEmpty(t, payload.Result, "query result should be stored") + + // Finally ensure we can derive the digest that the signing service uses; + // downstream tests rely on this helper, and this assertion guarantees the + // canonical format remains stable. + digest := payload.SigningDigest() + require.Len(t, digest, 32, "digest should be 32 bytes (SHA-256)") + + // Phase 2: Prepare signing work - validator generates signature + privateKey, _, err := kcrypto.GenerateSecp256k1Key(nil) + require.NoError(t, err) + + ResetValidatorSignerForTesting() + t.Cleanup(ResetValidatorSignerForTesting) + require.NoError(t, InitializeValidatorSigner(privateKey)) + validatorSigner := GetValidatorSigner() + require.NotNil(t, validatorSigner) + + nodeSigner := auth.GetNodeSigner(privateKey) + require.NotNil(t, nodeSigner) + pubKey, ok := nodeSigner.PubKey().(*kcrypto.Secp256k1PublicKey) + require.True(t, ok, "unexpected validator pubkey type") + + // Setup extension with real dependencies + service := &common.Service{ + Logger: log.DiscardLogger, + GenesisConfig: &config.GenesisConfig{ChainID: "attestation-harness"}, + LocalConfig: &config.Config{}, + } + + ext := getExtension() + ext.setService(service) + ext.setApp(&common.App{ + Engine: platform.Engine, + DB: platform.DB, + Accounts: &signerAccountsStub{}, + Service: service, + }) + ext.setNodeSigner(nodeSigner) + + hashHex := hex.EncodeToString(attestationHash) + prepared, err := ext.prepareSigningWork(ctx, hashHex) + require.NoError(t, err) + require.Len(t, prepared, 1, "expected one prepared signature") + + // Verify signature was generated correctly + require.Equal(t, hashHex, prepared[0].HashHex) + require.Equal(t, attestationHash, prepared[0].Hash) + require.Equal(t, requesterAddr.Bytes(), prepared[0].Requester) + require.Len(t, prepared[0].Signature, 65, "EVM signature is 65 bytes") + require.Equal(t, stored.createdHeight, prepared[0].CreatedHeight) + + // Phase 3: Submit signature via production flow (tests transaction marshaling) + const signHeight = int64(42) + + // Create test broadcaster that unmarshals transaction and executes sign_attestation + broadcaster := &harnessExecutingBroadcaster{ + t: t, + platform: platform, + pubKey: pubKey, + nodeSigner: nodeSigner, + signHeight: signHeight, + } + ext.setBroadcaster(broadcaster) + + // Use production submitSignature - this marshals the transaction, + // the broadcaster unmarshals it, and executes the real SQL action + err = ext.submitSignature(ctx, prepared[0]) + require.NoError(t, err, "submitSignature should succeed") + + // Verify the broadcaster was called and executed successfully + require.Equal(t, 1, broadcaster.calls, "should broadcast exactly once") + + // Verify signed state in database + signedRow := fetchAttestationRowHarness(t, ctx, platform, attestationHash) + require.NotNil(t, signedRow.signature, "signature should be recorded") + require.Equal(t, prepared[0].Signature, signedRow.signature) + require.NotNil(t, signedRow.validatorPubKey, "validator pubkey should be recorded") + require.Equal(t, nodeSigner.CompactID(), signedRow.validatorPubKey, "validator pubkey should match node signer identity") + require.NotNil(t, signedRow.signedHeight, "signed height should be recorded") + require.Equal(t, signHeight, *signedRow.signedHeight) + + return nil + }, + }, + }, options) +} + +// setupTestAttestationAction creates a test action and registers it in the attestation allowlist +func setupTestAttestationAction(ctx context.Context, platform *kwilTesting.Platform, actionName string, actionID int) error { + engineCtx := &common.EngineContext{ + TxContext: &common.TxContext{ + Ctx: ctx, + Signer: platform.Deployer, + Caller: string(platform.Deployer), + TxID: platform.Txid(), + BlockContext: &common.BlockContext{ + Height: 1, + }, + }, + OverrideAuthz: true, + } + + createAction := ` +CREATE OR REPLACE ACTION ` + actionName + `( + $value INT8 +) PUBLIC VIEW RETURNS TABLE(result INT8) { + RETURN NEXT $value; +};` + + if err := platform.Engine.Execute(engineCtx, platform.DB, createAction, nil, nil); err != nil { + return fmt.Errorf("create action: %w", err) + } + + insertAllowlist := ` +INSERT INTO attestation_actions(action_name, action_id) +VALUES ($action_name, $action_id) +ON CONFLICT (action_name) DO UPDATE SET action_id = EXCLUDED.action_id;` + + params := map[string]any{ + "action_name": actionName, + "action_id": actionID, + } + + if err := platform.Engine.Execute(engineCtx, platform.DB, insertAllowlist, params, nil); err != nil { + return fmt.Errorf("insert attestation action allowlist: %w", err) + } + + return nil +} + +func newHarnessEngineContext(ctx context.Context, platform *kwilTesting.Platform, requester *util.EthereumAddress) *common.EngineContext { + return &common.EngineContext{ + TxContext: &common.TxContext{ + Ctx: ctx, + Signer: requester.Bytes(), + Caller: requester.Address(), + TxID: platform.Txid(), + BlockContext: &common.BlockContext{ + Height: 1, + }, + }, + } +} + +func fetchAttestationRowHarness(t *testing.T, ctx context.Context, platform *kwilTesting.Platform, hash []byte) harnessAttestationRow { + engineCtx := &common.EngineContext{ + TxContext: &common.TxContext{ + Ctx: ctx, + Signer: platform.Deployer, + Caller: string(platform.Deployer), + TxID: platform.Txid(), + BlockContext: &common.BlockContext{ + Height: 1, + }, + }, + OverrideAuthz: true, + } + + var rowData harnessAttestationRow + err := platform.Engine.Execute(engineCtx, platform.DB, ` +SELECT requester, attestation_hash, result_canonical, encrypt_sig, signature, validator_pubkey, signed_height, created_height +FROM attestations +WHERE attestation_hash = $hash; +`, map[string]any{"hash": hash}, func(row *common.Row) error { + rowData.requester = append([]byte(nil), row.Values[0].([]byte)...) + rowData.attestationHash = append([]byte(nil), row.Values[1].([]byte)...) + rowData.resultCanonical = append([]byte(nil), row.Values[2].([]byte)...) + rowData.encryptSig = row.Values[3].(bool) + if row.Values[4] != nil { + rowData.signature = append([]byte(nil), row.Values[4].([]byte)...) + } + if row.Values[5] != nil { + rowData.validatorPubKey = append([]byte(nil), row.Values[5].([]byte)...) + } + if row.Values[6] != nil { + height := row.Values[6].(int64) + rowData.signedHeight = &height + } + rowData.createdHeight = row.Values[7].(int64) + return nil + }) + require.NoError(t, err) + return rowData +} + +type harnessAttestationRow struct { + requester []byte + attestationHash []byte + resultCanonical []byte + encryptSig bool + signature []byte + validatorPubKey []byte + signedHeight *int64 + createdHeight int64 +} + +type harnessExecutingBroadcaster struct { + t *testing.T + platform *kwilTesting.Platform + pubKey *kcrypto.Secp256k1PublicKey + nodeSigner auth.Signer + signHeight int64 + calls int +} + +func (b *harnessExecutingBroadcaster) BroadcastTx(ctx context.Context, tx *ktypes.Transaction, sync uint8) (ktypes.Hash, *ktypes.TxResult, error) { + b.calls++ + + // Parse transaction payload + payload := new(ktypes.ActionExecution) + if err := payload.UnmarshalBinary(tx.Body.Payload); err != nil { + return ktypes.Hash{}, nil, err + } + + require.Equal(b.t, "sign_attestation", payload.Action) + require.Len(b.t, payload.Arguments, 1) + require.Len(b.t, payload.Arguments[0], 4) + + // Decode arguments + hashBytes := b.decodeByteArg(payload.Arguments[0][0]) + requesterBytes := b.decodeByteArg(payload.Arguments[0][1]) + createdHeight := b.decodeInt64Arg(payload.Arguments[0][2]) + sigBytes := b.decodeByteArg(payload.Arguments[0][3]) + + // Get caller identifier for leader check + // For leader authorization to work, Signer must be the Ethereum address derived from the proposer's public key + signer := b.nodeSigner.CompactID() + caller, err := auth.GetNodeIdentifier(b.pubKey) + require.NoError(b.t, err) + + // Create engine context with leader as proposer + txCtx := &common.TxContext{ + Ctx: ctx, + BlockContext: &common.BlockContext{ + Height: b.signHeight, + Proposer: b.pubKey, + }, + Signer: signer, + Caller: caller, + TxID: b.platform.Txid(), + Authenticator: auth.Secp256k1Auth, + } + + // Execute real sign_attestation action from migrations + res, err := b.platform.Engine.Call( + &common.EngineContext{TxContext: txCtx}, + b.platform.DB, + "", + "sign_attestation", + []any{hashBytes, requesterBytes, createdHeight, sigBytes}, + func(*common.Row) error { return nil }, + ) + + require.NoError(b.t, err) + require.NotNil(b.t, res) + if res.Error != nil { + b.t.Fatalf("sign_attestation failed: %v", res.Error) + } + + return ktypes.Hash{}, &ktypes.TxResult{Code: uint32(ktypes.CodeOk)}, nil +} + +func (b *harnessExecutingBroadcaster) decodeByteArg(arg *ktypes.EncodedValue) []byte { + val, err := arg.Decode() + require.NoError(b.t, err) + switch typed := val.(type) { + case []byte: + return typed + case *[]byte: + require.NotNil(b.t, typed) + return *typed + default: + b.t.Fatalf("unexpected byte arg type %T", val) + return nil + } +} + +func (b *harnessExecutingBroadcaster) decodeInt64Arg(arg *ktypes.EncodedValue) int64 { + val, err := arg.Decode() + require.NoError(b.t, err) + switch typed := val.(type) { + case int64: + return typed + case *int64: + require.NotNil(b.t, typed) + return *typed + default: + b.t.Fatalf("unexpected int64 arg type %T", val) + return 0 + } +} diff --git a/extensions/tn_attestation/integration_test.go b/extensions/tn_attestation/integration_test.go new file mode 100644 index 000000000..ec1ae9660 --- /dev/null +++ b/extensions/tn_attestation/integration_test.go @@ -0,0 +1,363 @@ +package tn_attestation + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "math/big" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/trufnetwork/kwil-db/common" + "github.com/trufnetwork/kwil-db/config" + kcrypto "github.com/trufnetwork/kwil-db/core/crypto" + "github.com/trufnetwork/kwil-db/core/crypto/auth" + "github.com/trufnetwork/kwil-db/core/log" + ktypes "github.com/trufnetwork/kwil-db/core/types" + "github.com/trufnetwork/kwil-db/node/types/sql" +) + +func TestSigningWorkflowIntegration(t *testing.T) { + t.Helper() + + t.Run("QueuePath", func(t *testing.T) { + runSigningIntegration(t, true) + }) + + t.Run("FallbackPath", func(t *testing.T) { + runSigningIntegration(t, false) + }) +} + +func runSigningIntegration(t *testing.T, useQueue bool) { + t.Helper() + + resetIntegrationState() + + privateKey, publicKey, err := kcrypto.GenerateSecp256k1Key(nil) + require.NoError(t, err) + require.NoError(t, InitializeValidatorSigner(privateKey)) + defer ResetValidatorSignerForTesting() + + // Canonical payload mirrors the SQL construction to exercise the full pipeline. + version := uint8(1) + algo := uint8(1) + height := uint64(77) + actionID := uint16(5) + dataProvider := []byte("provider-queue-flow") + streamID := []byte("stream-queue-flow") + args := []byte{0x01, 0x02} + result := []byte{0x03} + + canonical := buildCanonicalPayload(version, algo, height, dataProvider, streamID, actionID, args, result) + payload, err := ParseCanonicalPayload(canonical) + require.NoError(t, err) + + hash := computeAttestationHash(payload) + hashHex := hex.EncodeToString(hash[:]) + requester := []byte("requester-1") + + engine := &integrationEngineStub{ + rows: []*common.Row{ + { + Values: []any{ + hash[:], + requester, + canonical, + int64(123), + }, + }, + }, + hashRows: []*common.Row{ + {Values: []any{hashHex}}, + }, + } + + ext := getExtension() + ext.logger = log.DiscardLogger + ext.service = &common.Service{ + Logger: log.DiscardLogger, + Identity: publicKey.Bytes(), + GenesisConfig: &config.GenesisConfig{ + ChainID: "integration-test-chain", + }, + } + ext.engine = engine + ext.db = integrationDBStub{} + ext.accounts = &signerAccountsStub{} + ext.scanIntervalBlocks = 1 + ext.scanBatchLimit = 10 + ext.nodeSigner = auth.GetNodeSigner(privateKey) + + broadcaster := &captureBroadcaster{} + ext.broadcaster = broadcaster + + queue := GetAttestationQueue() + queue.Clear() + if useQueue { + queue.Enqueue(hashHex) + } + + ctx := context.Background() + records, err := ext.fetchUnsignedAttestations(ctx, hash[:]) + require.NoError(t, err) + require.Truef(t, engine.served, "engine.ExecuteWithoutEngineCtx was not invoked (statement=%s)", engine.lastStmt) + require.Lenf(t, records, 1, "fetchUnsignedAttestations returned no rows (statement=%s)", engine.lastStmt) + + prepared, err := ext.prepareSigningWork(ctx, hashHex) + require.NoError(t, err) + require.Len(t, prepared, 1, "expected signing work to be prepared") + + if useQueue { + ext.processAttestationHashes(ctx, []string{hashHex}) + } else { + hashes, err := ext.fetchPendingHashes(ctx, 10) + require.NoError(t, err) + require.Equal(t, []string{hashHex}, hashes) + ext.processAttestationHashes(ctx, hashes) + } + + require.Equal(t, 1, broadcaster.calls, "expected single broadcast") + require.NoError(t, broadcaster.lastErr) + require.Len(t, broadcaster.hashes, 1) + require.Equal(t, hashHex, broadcaster.hashes[0]) + require.Len(t, broadcaster.heights, 1) + require.Equal(t, int64(123), broadcaster.heights[0]) + require.Len(t, broadcaster.signatures, 1) + require.Len(t, broadcaster.signatures[0], 65, "expected 65-byte signature") +} + +func resetIntegrationState() { + extensionOnce = sync.Once{} + SetExtension(nil) + queueOnce = sync.Once{} + attestationQueueSingleton = nil +} + +type captureBroadcaster struct { + hashes []string + signatures [][]byte + heights []int64 + calls int + lastErr error +} + +func (b *captureBroadcaster) BroadcastTx(ctx context.Context, tx *ktypes.Transaction, sync uint8) (ktypes.Hash, *ktypes.TxResult, error) { + b.calls++ + + payload := new(ktypes.ActionExecution) + if err := payload.UnmarshalBinary(tx.Body.Payload); err != nil { + b.lastErr = err + return ktypes.Hash{}, nil, err + } + + if len(payload.Arguments) == 0 || len(payload.Arguments[0]) != 4 { + err := fmt.Errorf("unexpected argument shape") + b.lastErr = err + return ktypes.Hash{}, nil, err + } + + val, err := payload.Arguments[0][0].Decode() + if err != nil { + b.lastErr = err + return ktypes.Hash{}, nil, err + } + var hashBytes []byte + switch typed := val.(type) { + case []byte: + hashBytes = typed + case *[]byte: + if typed == nil { + err := fmt.Errorf("hash argument was null") + b.lastErr = err + return ktypes.Hash{}, nil, err + } + hashBytes = *typed + default: + err := fmt.Errorf("hash argument type %T", val) + b.lastErr = err + return ktypes.Hash{}, nil, err + } + b.hashes = append(b.hashes, hex.EncodeToString(hashBytes)) + + heightVal, err := payload.Arguments[0][2].Decode() + if err != nil { + b.lastErr = err + return ktypes.Hash{}, nil, err + } + var createdHeight int64 + switch typed := heightVal.(type) { + case int64: + createdHeight = typed + case *int64: + if typed == nil { + err := fmt.Errorf("created_height argument was null") + b.lastErr = err + return ktypes.Hash{}, nil, err + } + createdHeight = *typed + default: + err := fmt.Errorf("created_height argument type %T", heightVal) + b.lastErr = err + return ktypes.Hash{}, nil, err + } + b.heights = append(b.heights, createdHeight) + + sigVal, err := payload.Arguments[0][3].Decode() + if err != nil { + b.lastErr = err + return ktypes.Hash{}, nil, err + } + var sigBytes []byte + switch typed := sigVal.(type) { + case []byte: + sigBytes = typed + case *[]byte: + if typed == nil { + err := fmt.Errorf("signature argument was null") + b.lastErr = err + return ktypes.Hash{}, nil, err + } + sigBytes = *typed + default: + err := fmt.Errorf("signature argument type %T", sigVal) + b.lastErr = err + return ktypes.Hash{}, nil, err + } + b.signatures = append(b.signatures, bytes.Clone(sigBytes)) + + return ktypes.Hash{}, &ktypes.TxResult{Code: uint32(ktypes.CodeOk)}, nil +} + +type integrationEngineStub struct { + rows []*common.Row + hashRows []*common.Row + lastStmt string + served bool +} + +func (s *integrationEngineStub) Call(*common.EngineContext, sql.DB, string, string, []any, func(*common.Row) error) (*common.CallResult, error) { + panic("Call not implemented") +} + +func (s *integrationEngineStub) CallWithoutEngineCtx(context.Context, sql.DB, string, string, []any, func(*common.Row) error) (*common.CallResult, error) { + panic("CallWithoutEngineCtx not implemented") +} + +func (s *integrationEngineStub) Execute(*common.EngineContext, sql.DB, string, map[string]any, func(*common.Row) error) error { + panic("Execute not implemented") +} + +func (s *integrationEngineStub) ExecuteWithoutEngineCtx(ctx context.Context, db sql.DB, statement string, params map[string]any, fn func(*common.Row) error) error { + s.lastStmt = statement + if strings.Contains(statement, "GROUP BY attestation_hash") { + for _, row := range s.hashRows { + if err := fn(row); err != nil { + return err + } + } + return nil + } + for _, row := range s.rows { + s.served = true + if err := fn(row); err != nil { + return err + } + } + return nil +} + +type integrationDBStub struct{} + +func (integrationDBStub) Execute(context.Context, string, ...any) (*sql.ResultSet, error) { + return nil, nil +} + +func (integrationDBStub) BeginTx(context.Context) (sql.Tx, error) { + return nil, fmt.Errorf("transactions not supported in stub") +} + +type signerAccountsStub struct{} + +func (signerAccountsStub) Credit(context.Context, sql.Executor, *ktypes.AccountID, *big.Int) error { + return nil +} + +func (signerAccountsStub) Transfer(context.Context, sql.TxMaker, *ktypes.AccountID, *ktypes.AccountID, *big.Int) error { + return nil +} + +func (signerAccountsStub) GetAccount(context.Context, sql.Executor, *ktypes.AccountID) (*ktypes.Account, error) { + return nil, fmt.Errorf("not found") +} + +func (signerAccountsStub) ApplySpend(context.Context, sql.Executor, *ktypes.AccountID, *big.Int, int64) error { + return nil +} + +func buildCanonicalPayload(version, algo uint8, blockHeight uint64, dataProvider, streamID []byte, actionID uint16, args, result []byte) []byte { + versionBytes := []byte{version} + algoBytes := []byte{algo} + + heightBytes := make([]byte, 8) + binaryBigEndianPutUint64(heightBytes, blockHeight) + + actionBytes := make([]byte, 2) + binaryBigEndianPutUint16(actionBytes, actionID) + + segments := [][]byte{ + versionBytes, + algoBytes, + heightBytes, + lengthPrefixLittleEndian(dataProvider), + lengthPrefixLittleEndian(streamID), + actionBytes, + lengthPrefixLittleEndian(args), + lengthPrefixLittleEndian(result), + } + + var buf bytes.Buffer + for _, seg := range segments { + buf.Write(seg) + } + return buf.Bytes() +} + +func lengthPrefixLittleEndian(data []byte) []byte { + if data == nil { + data = []byte{} + } + prefixed := make([]byte, 4+len(data)) + binaryLittleEndianPutUint32(prefixed[:4], uint32(len(data))) + copy(prefixed[4:], data) + return prefixed +} + +func binaryBigEndianPutUint64(b []byte, v uint64) { + _ = b[7] + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} + +func binaryBigEndianPutUint16(b []byte, v uint16) { + _ = b[1] + b[0] = byte(v >> 8) + b[1] = byte(v) +} + +func binaryLittleEndianPutUint32(b []byte, v uint32) { + _ = b[3] + b[0] = byte(v) + b[1] = byte(v >> 8) + b[2] = byte(v >> 16) + b[3] = byte(v >> 24) +} diff --git a/extensions/tn_attestation/precompile.go b/extensions/tn_attestation/precompile.go index a4e3811af..5286c8519 100644 --- a/extensions/tn_attestation/precompile.go +++ b/extensions/tn_attestation/precompile.go @@ -15,7 +15,7 @@ func registerPrecompile() error { return precompiles.RegisterPrecompile(ExtensionName, precompiles.Precompile{ // No cache needed: this precompile only affects leader's in-memory state, // which is intentionally non-deterministic. All validators return nil (deterministic). - Cache: nil, + Cache: nil, Methods: []precompiles.Method{ { Name: "queue_for_signing", diff --git a/extensions/tn_attestation/processor.go b/extensions/tn_attestation/processor.go new file mode 100644 index 000000000..60a954ea5 --- /dev/null +++ b/extensions/tn_attestation/processor.go @@ -0,0 +1,199 @@ +package tn_attestation + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "strings" + + "github.com/trufnetwork/kwil-db/common" +) + +type attestationRecord struct { + hash []byte + requester []byte + canonical []byte + createdHeight int64 +} + +// PreparedSignature captures the data needed to call sign_attestation once +// broadcasting is wired: the attestation hash, the generated signature, and +// metadata for logging and auditing. +type PreparedSignature struct { + HashHex string + Hash []byte + Requester []byte + Signature []byte + Payload *CanonicalPayload + CreatedHeight int64 +} + +func (e *signerExtension) fetchUnsignedAttestations(ctx context.Context, hash []byte) ([]attestationRecord, error) { + engine := e.Engine() + db := e.DB() + if engine == nil || db == nil { + return nil, fmt.Errorf("attestation extension not initialised with engine/db") + } + + // Returns multiple rows per hash: composite key is (hash, requester, created_height). + // Different requesters can request identical attestations. + records := []attestationRecord{} + err := engine.ExecuteWithoutEngineCtx( + ctx, + db, + `SELECT attestation_hash, requester, result_canonical, created_height + FROM attestations + WHERE attestation_hash = $hash AND signature IS NULL + ORDER BY created_height ASC`, + map[string]any{"hash": hash}, + func(row *common.Row) error { + if len(row.Values) < 4 { + return fmt.Errorf("unexpected attestation row format: got %d columns", len(row.Values)) + } + + rec := attestationRecord{ + hash: bytesClone(row.Values[0].([]byte)), + requester: bytesCloneOrNil(row.Values[1]), + canonical: bytesClone(row.Values[2].([]byte)), + createdHeight: row.Values[3].(int64), + } + records = append(records, rec) + return nil + }, + ) + if err != nil { + return nil, err + } + + return records, nil +} + +func (e *signerExtension) prepareSigningWork(ctx context.Context, hashHex string) ([]*PreparedSignature, error) { + hashHex = strings.TrimPrefix(strings.ToLower(strings.TrimSpace(hashHex)), "0x") + if hashHex == "" { + return nil, fmt.Errorf("attestation hash cannot be empty") + } + + hashBytes, err := hex.DecodeString(hashHex) + if err != nil { + return nil, fmt.Errorf("invalid attestation hash %q: %w", hashHex, err) + } + if len(hashBytes) != sha256.Size { + return nil, fmt.Errorf("attestation hash must be %d bytes, got %d", sha256.Size, len(hashBytes)) + } + + records, err := e.fetchUnsignedAttestations(ctx, hashBytes) + if err != nil { + return nil, err + } + if len(records) == 0 { + return nil, nil + } + + signer := GetValidatorSigner() + if signer == nil { + return nil, fmt.Errorf("validator signer not initialised") + } + + prepared := make([]*PreparedSignature, 0, len(records)) + for _, rec := range records { + payload, err := ParseCanonicalPayload(rec.canonical) + if err != nil { + return nil, fmt.Errorf("parse canonical payload: %w", err) + } + + // Validate stored hash matches caller inputs; SQL computes it from request parameters. + expectedHash := computeAttestationHash(payload) + if !bytes.Equal(expectedHash[:], rec.hash) { + return nil, fmt.Errorf("attestation hash mismatch: expected %x, db %x", expectedHash, rec.hash) + } + + digest := payload.SigningDigest() + signature, err := signer.SignDigest(digest[:]) + if err != nil { + return nil, fmt.Errorf("sign digest: %w", err) + } + + prepared = append(prepared, &PreparedSignature{ + HashHex: hashHex, + Hash: bytesClone(rec.hash), + Requester: bytesCloneOrNil(rec.requester), + Signature: signature, + Payload: payload, + CreatedHeight: rec.createdHeight, + }) + } + + return prepared, nil +} + +func (e *signerExtension) fetchPendingHashes(ctx context.Context, limit int) ([]string, error) { + engine := e.Engine() + db := e.DB() + if engine == nil || db == nil { + return nil, fmt.Errorf("attestation extension not initialised with engine/db") + } + if limit <= 0 { + limit = int(e.ScanBatchLimit()) + } + + hashes := make([]string, 0, limit) + err := engine.ExecuteWithoutEngineCtx( + ctx, + db, + `SELECT encode(attestation_hash, 'hex') AS hash + FROM attestations + WHERE signature IS NULL + GROUP BY attestation_hash + ORDER BY MIN(created_height) ASC + LIMIT $limit`, + map[string]any{"limit": limit}, + func(row *common.Row) error { + if len(row.Values) == 0 { + return nil + } + hash, ok := row.Values[0].(string) + if !ok { + return fmt.Errorf("unexpected hash column type %T", row.Values[0]) + } + hash = strings.TrimSpace(hash) + if hash != "" { + hashes = append(hashes, hash) + } + return nil + }, + ) + if err != nil { + return nil, err + } + return hashes, nil +} + +func computeAttestationHash(p *CanonicalPayload) [sha256.Size]byte { + buf := bytes.NewBuffer(nil) + buf.WriteByte(p.Version) + buf.WriteByte(p.Algorithm) + buf.Write(p.DataProvider) + buf.Write(p.StreamID) + + var actionBytes [2]byte + binary.BigEndian.PutUint16(actionBytes[:], p.ActionID) + buf.Write(actionBytes[:]) + buf.Write(p.Args) + + return sha256.Sum256(buf.Bytes()) +} + +func bytesClone(b []byte) []byte { + return bytes.Clone(b) +} + +func bytesCloneOrNil(v any) []byte { + if v == nil { + return nil + } + return bytes.Clone(v.([]byte)) +} diff --git a/extensions/tn_attestation/processor_test.go b/extensions/tn_attestation/processor_test.go new file mode 100644 index 000000000..f13f2ca95 --- /dev/null +++ b/extensions/tn_attestation/processor_test.go @@ -0,0 +1,374 @@ +package tn_attestation + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "math/big" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/trufnetwork/kwil-db/common" + "github.com/trufnetwork/kwil-db/config" + kwilcrypto "github.com/trufnetwork/kwil-db/core/crypto" + "github.com/trufnetwork/kwil-db/core/crypto/auth" + "github.com/trufnetwork/kwil-db/core/log" + ktypes "github.com/trufnetwork/kwil-db/core/types" + nodesql "github.com/trufnetwork/kwil-db/node/types/sql" +) + +func TestComputeAttestationHash(t *testing.T) { + const ( + version = uint8(1) + algorithm = uint8(1) + height = uint64(99) + actionID = uint16(7) + ) + dataProvider := []byte("provider") + streamID := []byte("stream") + args := []byte{0x01, 0x02} + result := []byte{0x03, 0x04} + + canonical := buildCanonical(version, algorithm, height, dataProvider, streamID, actionID, args, result) + payload, err := ParseCanonicalPayload(canonical) + require.NoError(t, err) + + expected := sha256.Sum256(buildHashMaterial(version, algorithm, dataProvider, streamID, actionID, args)) + actual := computeAttestationHash(payload) + assert.Equal(t, expected, actual) + + payload.raw = nil + actual = computeAttestationHash(payload) + assert.Equal(t, expected, actual) +} + +func TestPrepareSigningWork(t *testing.T) { + t.Cleanup(ResetValidatorSignerForTesting) + + privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) + require.NoError(t, err) + require.NoError(t, InitializeValidatorSigner(privateKey)) + + version := uint8(1) + algo := uint8(1) + height := uint64(77) + actionID := uint16(5) + dataProvider := []byte("provider-1") + streamID := []byte("stream-abc") + args := []byte{0x01, 0x02} + result := []byte{0xAA} + + canonical := buildCanonical(version, algo, height, dataProvider, streamID, actionID, args, result) + payload, err := ParseCanonicalPayload(canonical) + require.NoError(t, err) + + hash := computeAttestationHash(payload) + + engine := &stubEngine{ + rows: []*common.Row{ + { + Values: []any{ + hash[:], + []byte("requester"), + canonical, + int64(123), + }, + }, + }, + } + + ext := &signerExtension{ + logger: log.DiscardLogger, + scanIntervalBlocks: 100, + } + ext.setApp(&common.App{ + Engine: engine, + DB: stubDB{}, + }) + + prepared, err := ext.prepareSigningWork(context.Background(), hex.EncodeToString(hash[:])) + require.NoError(t, err) + require.Len(t, prepared, 1) + + ps := prepared[0] + assert.Equal(t, hash[:], ps.Hash) + assert.Equal(t, payload, ps.Payload) + assert.Equal(t, int64(123), ps.CreatedHeight) + assert.Len(t, ps.Signature, 65) +} + +func TestSubmitSignature(t *testing.T) { + t.Cleanup(ResetValidatorSignerForTesting) + + privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) + require.NoError(t, err) + require.NoError(t, InitializeValidatorSigner(privateKey)) + + version := uint8(1) + algo := uint8(1) + height := uint64(77) + actionID := uint16(5) + dataProvider := []byte("provider-1") + streamID := []byte("stream-abc") + args := []byte{0x01, 0x02} + result := []byte{0xAA} + + canonical := buildCanonical(version, algo, height, dataProvider, streamID, actionID, args, result) + payload, err := ParseCanonicalPayload(canonical) + require.NoError(t, err) + + hash := computeAttestationHash(payload) + + engine := &stubEngine{ + rows: []*common.Row{ + { + Values: []any{ + hash[:], + []byte("requester"), + canonical, + int64(123), + }, + }, + }, + } + + service := &common.Service{ + Logger: log.DiscardLogger, + GenesisConfig: &config.GenesisConfig{ChainID: "test-chain"}, + LocalConfig: &config.Config{}, + } + + accounts := &stubAccounts{ + acct: &ktypes.Account{Nonce: 7}, + } + + broadcaster := &recordingBroadcaster{} + + ext := &signerExtension{ + logger: log.DiscardLogger, + scanIntervalBlocks: 100, + } + ext.setService(service) + ext.setApp(&common.App{ + Engine: engine, + DB: stubDB{}, + Accounts: accounts, + Service: service, + }) + ext.setNodeSigner(auth.GetNodeSigner(privateKey)) + ext.setBroadcaster(broadcaster) + + prepared, err := ext.prepareSigningWork(context.Background(), hex.EncodeToString(hash[:])) + require.NoError(t, err) + require.Len(t, prepared, 1) + + err = ext.submitSignature(context.Background(), prepared[0]) + require.NoError(t, err) + + assert.Equal(t, 1, broadcaster.calls) + require.NotNil(t, broadcaster.lastTx) + var decoded ktypes.ActionExecution + require.NoError(t, decoded.UnmarshalBinary(broadcaster.lastTx.Body.Payload)) + assert.Equal(t, "main", decoded.Namespace) + assert.Equal(t, "sign_attestation", decoded.Action) + require.Len(t, decoded.Arguments, 1) + require.Len(t, decoded.Arguments[0], 4) + + hashBytes := decodeBytesArg(t, decoded.Arguments[0][0], "hash") + assert.Equal(t, hash[:], hashBytes) + + requesterBytes := decodeBytesArg(t, decoded.Arguments[0][1], "requester") + assert.Equal(t, []byte("requester"), requesterBytes) + + createdHeight := decodeInt64Arg(t, decoded.Arguments[0][2], "created_height") + assert.Equal(t, int64(123), createdHeight) + + signatureBytes := decodeBytesArg(t, decoded.Arguments[0][3], "signature") + assert.Len(t, signatureBytes, 65) +} + +func decodeBytesArg(t *testing.T, arg *ktypes.EncodedValue, fieldName string) []byte { + t.Helper() + val, err := arg.Decode() + require.NoError(t, err) + switch typed := val.(type) { + case []byte: + return typed + case *[]byte: + require.NotNil(t, typed, "%s argument pointer was nil", fieldName) + return *typed + default: + t.Fatalf("unexpected %s argument type %T", fieldName, val) + return nil + } +} + +func decodeInt64Arg(t *testing.T, arg *ktypes.EncodedValue, fieldName string) int64 { + t.Helper() + val, err := arg.Decode() + require.NoError(t, err) + switch typed := val.(type) { + case int64: + return typed + case *int64: + require.NotNil(t, typed, "%s argument pointer was nil", fieldName) + return *typed + default: + t.Fatalf("unexpected %s argument type %T", fieldName, val) + return 0 + } +} + +func buildHashMaterial(version, algo uint8, dataProvider, streamID []byte, actionID uint16, args []byte) []byte { + buf := bytes.NewBuffer(nil) + buf.WriteByte(version) + buf.WriteByte(algo) + buf.Write(dataProvider) + buf.Write(streamID) + + var actionBytes [2]byte + binary.BigEndian.PutUint16(actionBytes[:], actionID) + buf.Write(actionBytes[:]) + buf.Write(args) + + return buf.Bytes() +} + +func TestFetchPendingHashes(t *testing.T) { + ext := &signerExtension{ + logger: log.DiscardLogger, + scanIntervalBlocks: 100, + scanBatchLimit: 5, + } + + engine := &stubEngine{ + hashRows: []*common.Row{ + {Values: []any{"aaa"}}, + {Values: []any{"bbb"}}, + {Values: []any{"ccc"}}, + }, + } + + ext.setApp(&common.App{ + Engine: engine, + DB: stubDB{}, + }) + + hashes, err := ext.fetchPendingHashes(context.Background(), 2) + require.NoError(t, err) + assert.Equal(t, []string{"aaa", "bbb"}, hashes) + + hashes, err = ext.fetchPendingHashes(context.Background(), 0) + require.NoError(t, err) + assert.Equal(t, []string{"aaa", "bbb", "ccc"}, hashes) +} + +type stubEngine struct { + rows []*common.Row + hashRows []*common.Row +} + +func (s *stubEngine) Call(ctx *common.EngineContext, db nodesql.DB, namespace, action string, args []any, resultFn func(*common.Row) error) (*common.CallResult, error) { + panic("not implemented") +} + +func (s *stubEngine) CallWithoutEngineCtx(ctx context.Context, db nodesql.DB, namespace, action string, args []any, resultFn func(*common.Row) error) (*common.CallResult, error) { + panic("not implemented") +} + +func (s *stubEngine) Execute(ctx *common.EngineContext, db nodesql.DB, statement string, params map[string]any, fn func(*common.Row) error) error { + panic("not implemented") +} + +func (s *stubEngine) ExecuteWithoutEngineCtx(ctx context.Context, db nodesql.DB, statement string, params map[string]any, fn func(*common.Row) error) error { + rows := s.rows + if params != nil { + if limit, ok := params["limit"]; ok { + rows = s.hashRows + if n := toInt(limit); n >= 0 && n < len(rows) { + rows = rows[:n] + } + } + } else if strings.Contains(statement, "GROUP BY attestation_hash") { + // Fallback for callers that forget to pass params. + rows = s.hashRows + } + for _, row := range rows { + if err := fn(row); err != nil { + return err + } + } + return nil +} + +type stubDB struct{} + +func (stubDB) Execute(ctx context.Context, stmt string, args ...any) (*nodesql.ResultSet, error) { + return nil, fmt.Errorf("not implemented") +} + +func (stubDB) BeginTx(ctx context.Context) (nodesql.Tx, error) { + return nil, fmt.Errorf("not implemented") +} + +func toInt(v any) int { + switch val := v.(type) { + case int: + return val + case int32: + return int(val) + case int64: + return int(val) + case uint: + return int(val) + case uint32: + return int(val) + case uint64: + return int(val) + default: + return -1 + } +} + +type stubAccounts struct { + acct *ktypes.Account + err error +} + +func (s *stubAccounts) Credit(ctx context.Context, tx nodesql.Executor, account *ktypes.AccountID, balance *big.Int) error { + return nil +} + +func (s *stubAccounts) Transfer(ctx context.Context, tx nodesql.TxMaker, from, to *ktypes.AccountID, amt *big.Int) error { + return nil +} + +func (s *stubAccounts) GetAccount(ctx context.Context, tx nodesql.Executor, account *ktypes.AccountID) (*ktypes.Account, error) { + if s.err != nil { + return nil, s.err + } + if s.acct != nil { + return s.acct, nil + } + return nil, fmt.Errorf("account not found") +} + +func (s *stubAccounts) ApplySpend(ctx context.Context, tx nodesql.Executor, account *ktypes.AccountID, amount *big.Int, nonce int64) error { + return nil +} + +type recordingBroadcaster struct { + calls int + lastTx *ktypes.Transaction +} + +func (b *recordingBroadcaster) BroadcastTx(ctx context.Context, tx *ktypes.Transaction, sync uint8) (ktypes.Hash, *ktypes.TxResult, error) { + b.calls++ + b.lastTx = tx + return ktypes.Hash{}, &ktypes.TxResult{Code: uint32(ktypes.CodeOk)}, nil +} diff --git a/extensions/tn_attestation/signer.go b/extensions/tn_attestation/signer.go index 8caeec779..e7af096d3 100644 --- a/extensions/tn_attestation/signer.go +++ b/extensions/tn_attestation/signer.go @@ -42,12 +42,8 @@ func NewValidatorSigner(privateKey kwilcrypto.PrivateKey) (*ValidatorSigner, err }, nil } -// SignKeccak256 signs the keccak256 hash of the payload and returns a 65-byte EVM-compatible signature. -// The signature format is [R || S || V] where: -// - R: 32 bytes (signature R component) -// - S: 32 bytes (signature S component) -// - V: 1 byte (recovery ID, 27 or 28 for EVM compatibility) -func (s *ValidatorSigner) SignKeccak256(payload []byte) ([]byte, error) { +// SignDigest signs the provided 32-byte digest (already hashed) and returns a 65-byte EVM-compatible signature. +func (s *ValidatorSigner) SignDigest(digest []byte) ([]byte, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -55,6 +51,26 @@ func (s *ValidatorSigner) SignKeccak256(payload []byte) ([]byte, error) { return nil, fmt.Errorf("private key not initialized") } + if len(digest) != crypto.DigestLength { + return nil, fmt.Errorf("digest must be %d bytes, got %d", crypto.DigestLength, len(digest)) + } + + signature, err := crypto.Sign(digest, s.privateKey) + if err != nil { + return nil, fmt.Errorf("failed to sign digest: %w", err) + } + + // Convert V from {0,1} to {27,28} for EVM compatibility. + signature[64] += 27 + return signature, nil +} + +// SignKeccak256 signs the keccak256 hash of the payload and returns a 65-byte EVM-compatible signature. +// The signature format is [R || S || V] where: +// - R: 32 bytes (signature R component) +// - S: 32 bytes (signature S component) +// - V: 1 byte (recovery ID, 27 or 28 for EVM compatibility) +func (s *ValidatorSigner) SignKeccak256(payload []byte) ([]byte, error) { if len(payload) == 0 { return nil, fmt.Errorf("payload cannot be empty") } @@ -63,17 +79,7 @@ func (s *ValidatorSigner) SignKeccak256(payload []byte) ([]byte, error) { hash := crypto.Keccak256Hash(payload) // Sign the hash using secp256k1 - signature, err := crypto.Sign(hash.Bytes(), s.privateKey) - if err != nil { - return nil, fmt.Errorf("failed to sign payload: %w", err) - } - - // crypto.Sign returns 65-byte signature [R || S || V] where V is 0 or 1 - // EVM's ecrecover expects V as 27 or 28, so convert: - // V=0 (even Y) → 27, V=1 (odd Y) → 28 - signature[64] += 27 - - return signature, nil + return s.SignDigest(hash.Bytes()) } // PublicKey returns the public key associated with this signer (for verification). diff --git a/extensions/tn_attestation/signer_test.go b/extensions/tn_attestation/signer_test.go index aa75a381c..ae1a6f99b 100644 --- a/extensions/tn_attestation/signer_test.go +++ b/extensions/tn_attestation/signer_test.go @@ -1,6 +1,7 @@ package tn_attestation import ( + "crypto/sha256" "fmt" "sync" "testing" @@ -30,6 +31,34 @@ func TestValidatorSigner(t *testing.T) { assert.Contains(t, err.Error(), "private key cannot be nil") }) + t.Run("SignDigest", func(t *testing.T) { + privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) + require.NoError(t, err) + + signer, err := NewValidatorSigner(privateKey) + require.NoError(t, err) + + digest := sha256.Sum256([]byte("attestation payload")) + + signature, err := signer.SignDigest(digest[:]) + require.NoError(t, err) + assert.NotNil(t, signature) + assert.Equal(t, 65, len(signature)) + }) + + t.Run("SignDigestInvalidLength", func(t *testing.T) { + privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) + require.NoError(t, err) + + signer, err := NewValidatorSigner(privateKey) + require.NoError(t, err) + + signature, err := signer.SignDigest([]byte{}) + assert.Error(t, err) + assert.Nil(t, signature) + assert.Contains(t, err.Error(), "digest must be 32 bytes") + }) + t.Run("SignKeccak256", func(t *testing.T) { // Generate a test private key privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) @@ -62,7 +91,7 @@ func TestValidatorSigner(t *testing.T) { assert.Contains(t, err.Error(), "payload cannot be empty") }) - t.Run("SignatureVerification", func(t *testing.T) { + t.Run("SignatureVerificationWithDigest", func(t *testing.T) { // Generate a test private key privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) require.NoError(t, err) @@ -72,9 +101,10 @@ func TestValidatorSigner(t *testing.T) { // Test payload payload := []byte("test attestation payload") + digest := sha256.Sum256(payload) - // Sign the payload - signature, err := signer.SignKeccak256(payload) + // Sign the digest + signature, err := signer.SignDigest(digest[:]) require.NoError(t, err) // Verify V byte is EVM-compatible (27 or 28) @@ -87,8 +117,7 @@ func TestValidatorSigner(t *testing.T) { copy(testSig, signature) testSig[64] -= 27 // Convert 27/28 → 0/1 - hash := crypto.Keccak256Hash(payload) - recoveredPubKey, err := crypto.SigToPub(hash.Bytes(), testSig) + recoveredPubKey, err := crypto.SigToPub(digest[:], testSig) require.NoError(t, err) // Verify the recovered public key matches the signer's public key @@ -125,7 +154,7 @@ func TestValidatorSigner(t *testing.T) { assert.True(t, address[:2] == "0x", "address should start with 0x") }) - t.Run("DeterministicSignature", func(t *testing.T) { + t.Run("DeterministicSignatureDigest", func(t *testing.T) { privateKey, _, err := kwilcrypto.GenerateSecp256k1Key(nil) require.NoError(t, err) @@ -133,12 +162,13 @@ func TestValidatorSigner(t *testing.T) { require.NoError(t, err) payload := []byte("deterministic test payload") + digest := sha256.Sum256(payload) // Sign the same payload twice - sig1, err := signer.SignKeccak256(payload) + sig1, err := signer.SignDigest(digest[:]) require.NoError(t, err) - sig2, err := signer.SignKeccak256(payload) + sig2, err := signer.SignDigest(digest[:]) require.NoError(t, err) // Signatures should be identical for the same payload and key @@ -161,7 +191,8 @@ func TestValidatorSigner(t *testing.T) { go func(idx int) { defer wg.Done() payload := []byte("concurrent test payload") - signature, err := signer.SignKeccak256(payload) + digest := sha256.Sum256(payload) + signature, err := signer.SignDigest(digest[:]) require.NoError(t, err) results[idx] = signature }(i) @@ -257,9 +288,10 @@ func TestEVMCompatibility(t *testing.T) { // Create test payload (simulating attestation structure) payload := []byte("version|algo|dataProvider|streamId|actionId|args|result") + digest := sha256.Sum256(payload) // Sign the payload - signature, err := signer.SignKeccak256(payload) + signature, err := signer.SignDigest(digest[:]) require.NoError(t, err) // Verify signature format is EVM-compatible @@ -270,16 +302,11 @@ func TestEVMCompatibility(t *testing.T) { assert.True(t, v == 27 || v == 28, "V must be 27 or 28 for EVM compatibility, got %d", v) // Recover the signer address from the signature (simulating Solidity ecrecover) - hash := crypto.Keccak256Hash(payload) - - // Note: Go's crypto.Ecrecover expects V as 0/1, but our signature has V as 27/28 (EVM format) - // In real usage, Solidity's ecrecover accepts 27/28 directly - // For testing with Go's crypto.Ecrecover, we need to convert back temporarily testSig := make([]byte, len(signature)) copy(testSig, signature) testSig[64] -= 27 // Convert 27/28 → 0/1 for Go's Ecrecover - recoveredPubKey, err := crypto.Ecrecover(hash.Bytes(), testSig) + recoveredPubKey, err := crypto.Ecrecover(digest[:], testSig) require.NoError(t, err) assert.NotNil(t, recoveredPubKey) @@ -303,7 +330,8 @@ func TestEVMCompatibility(t *testing.T) { // Test multiple signatures to ensure V is always 27 or 28 for i := 0; i < 10; i++ { payload := []byte(fmt.Sprintf("test payload %d", i)) - signature, err := signer.SignKeccak256(payload) + digest := sha256.Sum256(payload) + signature, err := signer.SignDigest(digest[:]) require.NoError(t, err) // Signature must be 65 bytes diff --git a/extensions/tn_attestation/status_worker_test.go b/extensions/tn_attestation/status_worker_test.go new file mode 100644 index 000000000..d76037aa5 --- /dev/null +++ b/extensions/tn_attestation/status_worker_test.go @@ -0,0 +1,141 @@ +package tn_attestation + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/trufnetwork/kwil-db/core/log" + ktypes "github.com/trufnetwork/kwil-db/core/types" +) + +type fakeTxQueryClient struct { + mu sync.Mutex + responses []*ktypes.TxQueryResponse + errs []error + expected int + calls int + done chan struct{} +} + +// newFakeTxQueryClient builds a deterministic TxQuery client that returns the provided responses/errors. +// expected indicates how many calls we should observe before closing the done channel; 0 auto-expands +// to the longer of the responses/errors slices (at least 1). +func newFakeTxQueryClient(resps []*ktypes.TxQueryResponse, errs []error, expected int) *fakeTxQueryClient { + if expected == 0 { + expected = len(resps) + if len(errs) > expected { + expected = len(errs) + } + if expected == 0 { + expected = 1 + } + } + return &fakeTxQueryClient{ + responses: resps, + errs: errs, + expected: expected, + done: make(chan struct{}), + } +} + +func (f *fakeTxQueryClient) TxQuery(ctx context.Context, txHash ktypes.Hash) (*ktypes.TxQueryResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + + var resp *ktypes.TxQueryResponse + if len(f.responses) > 0 { + resp = f.responses[0] + f.responses = f.responses[1:] + } + + var err error + if len(f.errs) > 0 { + err = f.errs[0] + f.errs = f.errs[1:] + } + + f.calls++ + if f.calls >= f.expected { + select { + case <-f.done: + default: + close(f.done) + } + } + + return resp, err +} + +func (f *fakeTxQueryClient) Calls() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls +} + +func TestStatusWorkerProcessesSuccess(t *testing.T) { + origDelays := statusRetryDelays + delays := make([]time.Duration, statusMaxAttempts) + for i := range delays { + delays[i] = time.Millisecond + } + statusRetryDelays = delays + defer func() { statusRetryDelays = origDelays }() + + ext := &signerExtension{ + logger: log.DiscardLogger, + } + + client := newFakeTxQueryClient([]*ktypes.TxQueryResponse{ + {Height: 10, Result: &ktypes.TxResult{Code: uint32(ktypes.CodeOk)}}, + }, nil, 1) + + ext.setTxQueryClient(client) + ext.startStatusWorker() + ext.enqueueStatusCheck(ktypes.Hash{}, "success-attestation", []byte("requester")) + + select { + case <-client.done: + case <-time.After(time.Second): + t.Fatal("transaction status worker did not complete in time") + } + + require.Equal(t, 1, client.Calls()) + close(ext.statusQueue) +} + +func TestStatusWorkerRetriesOnFailure(t *testing.T) { + origDelays := statusRetryDelays + delays := make([]time.Duration, statusMaxAttempts) + for i := range delays { + delays[i] = time.Millisecond + } + statusRetryDelays = delays + defer func() { statusRetryDelays = origDelays }() + + ext := &signerExtension{ + logger: log.DiscardLogger, + } + + errs := make([]error, statusMaxAttempts) + for i := range errs { + errs[i] = fmt.Errorf("tx not found") + } + client := newFakeTxQueryClient(nil, errs, statusMaxAttempts) + + ext.setTxQueryClient(client) + ext.startStatusWorker() + ext.enqueueStatusCheck(ktypes.Hash{}, "fail-attestation", []byte("requester")) + + select { + case <-client.done: + case <-time.After(2 * time.Second): + t.Fatal("transaction status worker did not exhaust retries in time") + } + + require.Equal(t, statusMaxAttempts, client.Calls()) + close(ext.statusQueue) +} diff --git a/extensions/tn_attestation/tn_attestation.go b/extensions/tn_attestation/tn_attestation.go index af1331e85..04a761081 100644 --- a/extensions/tn_attestation/tn_attestation.go +++ b/extensions/tn_attestation/tn_attestation.go @@ -8,6 +8,7 @@ import ( appconf "github.com/trufnetwork/kwil-db/app/node/conf" "github.com/trufnetwork/kwil-db/common" "github.com/trufnetwork/kwil-db/config" + "github.com/trufnetwork/kwil-db/core/crypto/auth" "github.com/trufnetwork/kwil-db/extensions/hooks" "github.com/trufnetwork/node/extensions/leaderwatch" ) @@ -15,7 +16,7 @@ import ( // InitializeExtension registers the tn_attestation extension. // This includes: // - Registering the queue_for_signing() precompile -// - Registering leader watch callbacks for signing worker // TODO: WIP +// - Registering leader watch callbacks for the signing workflow func InitializeExtension() { // Register the precompile for queue_for_signing() method if err := registerPrecompile(); err != nil { @@ -27,7 +28,7 @@ func InitializeExtension() { panic(fmt.Sprintf("failed to register %s engine ready hook: %v", ExtensionName, err)) } - // Register leader watch callbacks (for Issue 6 - leader signing worker) + // Register leader watch callbacks for signing workflow lifecycle if err := leaderwatch.Register(ExtensionName, leaderwatch.Callbacks{ OnAcquire: onLeaderAcquire, OnLose: onLeaderLose, @@ -44,7 +45,12 @@ func engineReadyHook(ctx context.Context, app *common.App) error { return nil } - logger := app.Service.Logger.New(ExtensionName) + ext := getExtension() + ext.setService(app.Service) + ext.setApp(app) + ext.applyConfig(app.Service) + + logger := ext.Logger() // Load the validator's private key from the node key file rootDir := appconf.RootDir() @@ -70,6 +76,10 @@ func engineReadyHook(ctx context.Context, app *common.App) error { return fmt.Errorf("failed to initialize validator signer: %w", err) } + if ext.NodeSigner() == nil { + ext.setNodeSigner(auth.GetNodeSigner(privateKey)) + } + // Log the validator address for debugging signer := GetValidatorSigner() if signer != nil { @@ -78,6 +88,8 @@ func engineReadyHook(ctx context.Context, app *common.App) error { "validator_address", signer.Address()) } + ext.ensureBroadcaster(app.Service) + return nil } @@ -88,9 +100,21 @@ func onLeaderAcquire(ctx context.Context, app *common.App, block *common.BlockCo return } - logger := app.Service.Logger.New(ExtensionName) + ext := getExtension() + ext.setService(app.Service) + if block != nil { + ext.setLeader(true, block.Height) + } + + logger := ext.Logger() logger.Info("tn_attestation: acquired leadership") + queue := GetAttestationQueue() + if n := queue.Len(); n > 0 { + logger.Debug("tn_attestation: clearing residual attestation queue after leader acquisition", "dropped", n) + } + queue.Clear() + // TODO: Implement signing worker startup // Reference implementation: // ext := GetExtension() @@ -104,9 +128,21 @@ func onLeaderLose(ctx context.Context, app *common.App, block *common.BlockConte return } - logger := app.Service.Logger.New(ExtensionName) + ext := getExtension() + ext.setService(app.Service) + if block != nil { + ext.setLeader(false, block.Height) + } + + logger := ext.Logger() logger.Info("tn_attestation: lost leadership") + queue := GetAttestationQueue() + if n := queue.Len(); n > 0 { + logger.Debug("tn_attestation: clearing attestation queue on leader loss", "dropped", n) + } + queue.Clear() + // TODO: Implement signing worker shutdown // Reference implementation: // ext := GetExtension() @@ -114,37 +150,42 @@ func onLeaderLose(ctx context.Context, app *common.App, block *common.BlockConte } // onLeaderEndBlock is called on every EndBlock when the node is the leader. -// Currently dequeues and logs hashes to prevent unbounded memory growth. -// TODO: Implement actual signing and submission of attestations. +// It processes queued attestation hashes and, in later phases, also performs +// periodic scans to recover any missed notifications. func onLeaderEndBlock(ctx context.Context, app *common.App, block *common.BlockContext) { if app == nil || app.Service == nil { return } + ext := getExtension() + ext.setService(app.Service) + + if !ext.Leader() { + return + } + // Dequeue all pending attestation hashes to prevent unbounded growth queue := GetAttestationQueue() hashes := queue.DequeueAll() - // If there are hashes, log them (signing implementation pending in Issue 6) if len(hashes) > 0 { - logger := app.Service.Logger.New(ExtensionName) - logger.Info("tn_attestation: dequeued attestations for signing", + logger := ext.Logger() + logger.Info("tn_attestation: processing queued attestations", "count", len(hashes), - "block_height", block.Height, - "note", "signing implementation pending (Issue 6)") - - // TODO: Implement actual signing and submission - // Reference implementation: - // ext := GetExtension() - // for _, hash := range hashes { - // signature, err := ext.signAttestation(ctx, app, hash) - // if err != nil { - // logger.Error("failed to sign attestation", "hash", hash, "error", err) - // continue - // } - // if err := ext.submitSignature(ctx, app, hash, signature); err != nil { - // logger.Error("failed to submit signature", "hash", hash, "error", err) - // } - // } + "block_height", block.Height) + ext.processAttestationHashes(ctx, hashes) + } + + if block != nil && ext.shouldPerformScan(block.Height) { + logger := ext.Logger() + pending, err := ext.fetchPendingHashes(ctx, int(ext.ScanBatchLimit())) + if err != nil { + logger.Error("tn_attestation: fallback scan failed", "error", err) + } else if len(pending) > 0 { + logger.Info("tn_attestation: fallback scan found unsigned attestations", + "count", len(pending), + "block_height", block.Height) + ext.processAttestationHashes(ctx, pending) + } } } diff --git a/extensions/tn_attestation/tn_attestation_test.go b/extensions/tn_attestation/tn_attestation_test.go index 3379ff3ec..83fab249c 100644 --- a/extensions/tn_attestation/tn_attestation_test.go +++ b/extensions/tn_attestation/tn_attestation_test.go @@ -3,15 +3,18 @@ package tn_attestation import ( "context" "fmt" + "strings" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/trufnetwork/kwil-db/common" + "github.com/trufnetwork/kwil-db/config" "github.com/trufnetwork/kwil-db/core/crypto" "github.com/trufnetwork/kwil-db/core/log" "github.com/trufnetwork/kwil-db/extensions/precompiles" + nodesql "github.com/trufnetwork/kwil-db/node/types/sql" ) // ensurePrecompileRegistered ensures the precompile is registered before use. @@ -450,6 +453,129 @@ func TestEngineReadyHook(t *testing.T) { assert.Nil(t, signer, "signer should not be initialized without key file") }) - // Note: Full integration test with actual key loading is deferred to Issue #1209 - // where it will be tested as part of end-to-end leader signing workflow +} + +func TestLeaderLifecycleState(t *testing.T) { + original := getExtension() + SetExtension(&signerExtension{ + logger: log.DiscardLogger, + scanIntervalBlocks: 100, + scanBatchLimit: 100, + }) + defer SetExtension(original) + + app := &common.App{ + Service: &common.Service{ + Logger: log.DiscardLogger, + }, + } + block := &common.BlockContext{Height: 10} + + onLeaderAcquire(context.Background(), app, block) + ext := getExtension() + assert.True(t, ext.Leader(), "extension should mark leadership on acquire") + assert.Equal(t, int64(10), ext.LastScanHeight(), "last scan height should seed from acquire") + + queue := GetAttestationQueue() + queue.Clear() + queue.Enqueue("hashA") + queue.Enqueue("hashB") + + onLeaderEndBlock(context.Background(), app, &common.BlockContext{Height: 11}) + assert.Equal(t, 0, queue.Len(), "leader end block should dequeue hashes") + + onLeaderLose(context.Background(), app, &common.BlockContext{Height: 12}) + assert.False(t, ext.Leader(), "extension should unset leadership on lose") + + queue.Enqueue("hashC") + onLeaderEndBlock(context.Background(), app, &common.BlockContext{Height: 13}) + assert.Equal(t, 1, queue.Len(), "non-leader end block should not tamper with queue") + queue.Clear() +} + +func TestOnLeaderEndBlockFallbackScan(t *testing.T) { + original := getExtension() + queue := GetAttestationQueue() + queue.Clear() + + ext := &signerExtension{ + logger: log.DiscardLogger, + scanIntervalBlocks: 1, + scanBatchLimit: 5, + } + SetExtension(ext) + t.Cleanup(func() { + SetExtension(original) + queue.Clear() + }) + + service := &common.Service{ + Logger: log.DiscardLogger, + LocalConfig: &config.Config{}, + } + + engine := &fallbackEngine{hashes: []string{"abc"}} + ext.setService(service) + ext.setApp(&common.App{ + Service: service, + Engine: engine, + DB: fallbackDB{}, + }) + ext.setLeader(true, 0) + + processed := make([][]string, 0) + ext.setProcessOverride(func(_ context.Context, hashes []string) { + processed = append(processed, append([]string(nil), hashes...)) + }) + + onLeaderEndBlock(context.Background(), &common.App{Service: service}, &common.BlockContext{Height: 1}) + + require.Len(t, processed, 1) + assert.Equal(t, []string{"abc"}, processed[0]) + assert.True(t, strings.Contains(engine.lastStatement, "GROUP BY attestation_hash")) +} + +type fallbackEngine struct { + hashes []string + lastStatement string +} + +func (f *fallbackEngine) Call(ctx *common.EngineContext, db nodesql.DB, namespace, action string, args []any, resultFn func(*common.Row) error) (*common.CallResult, error) { + panic("not implemented") +} + +func (f *fallbackEngine) CallWithoutEngineCtx(ctx context.Context, db nodesql.DB, namespace, action string, args []any, resultFn func(*common.Row) error) (*common.CallResult, error) { + panic("not implemented") +} + +func (f *fallbackEngine) Execute(ctx *common.EngineContext, db nodesql.DB, statement string, params map[string]any, fn func(*common.Row) error) error { + panic("not implemented") +} + +func (f *fallbackEngine) ExecuteWithoutEngineCtx(ctx context.Context, db nodesql.DB, statement string, params map[string]any, fn func(*common.Row) error) error { + f.lastStatement = statement + if strings.Contains(statement, "GROUP BY attestation_hash") { + rows := f.hashes + if limit, ok := params["limit"]; ok { + if n := toInt(limit); n >= 0 && n < len(rows) { + rows = rows[:n] + } + } + for _, h := range rows { + if err := fn(&common.Row{Values: []any{h}}); err != nil { + return err + } + } + } + return nil +} + +type fallbackDB struct{} + +func (fallbackDB) Execute(ctx context.Context, stmt string, args ...any) (*nodesql.ResultSet, error) { + return nil, fmt.Errorf("not implemented") +} + +func (fallbackDB) BeginTx(ctx context.Context) (nodesql.Tx, error) { + return nil, fmt.Errorf("not implemented") } diff --git a/extensions/tn_attestation/worker.go b/extensions/tn_attestation/worker.go new file mode 100644 index 000000000..0f24b36c1 --- /dev/null +++ b/extensions/tn_attestation/worker.go @@ -0,0 +1,304 @@ +// This file implements the attestation signing worker and async transaction monitoring. +// +// Transaction status checking runs in a dedicated background goroutine to prevent blocking +// EndBlock processing. The worker polls for confirmation over ~2 minutes, logging outcomes +// for operational visibility without impacting consensus performance. +package tn_attestation + +import ( + "context" + "fmt" + "strings" + "time" + + ktypes "github.com/trufnetwork/kwil-db/core/types" +) + +const ( + statusMaxAttempts = 12 + statusWorkerTimeout = 2 * time.Minute +) + +var statusRetryDelays = []time.Duration{2 * time.Second, 5 * time.Second, 10 * time.Second} + +// txStatusWork queues a broadcast transaction for async status monitoring. +// Avoids blocking consensus by deferring potentially slow RPC queries to a worker goroutine. +type txStatusWork struct { + hash ktypes.Hash // Transaction hash from broadcast + attestationHash string // Original attestation hash for log correlation + requester []byte // Requester address for audit trail +} + +// processAttestationHashes iterates through every dequeued hash, prepares the +// canonical payload(s) for signing, and submits signatures back through +// consensus. All failures are logged and do not abort the remainder of the +// batch so we can make steady progress even when an individual record is +// problematic. +func (e *signerExtension) processAttestationHashes(ctx context.Context, hashes []string) { + if len(hashes) == 0 { + return + } + + e.mu.RLock() + override := e.processOverride + e.mu.RUnlock() + if override != nil { + override(ctx, hashes) + return + } + + logger := e.Logger() + for _, hashHex := range hashes { + prepared, err := e.prepareSigningWork(ctx, hashHex) + if err != nil { + logger.Error("tn_attestation: failed to prepare signing payload", "hash", hashHex, "error", err) + continue + } + if len(prepared) == 0 { + logger.Debug("tn_attestation: no unsigned rows for hash", "hash", hashHex) + continue + } + for _, item := range prepared { + if err := e.submitSignature(ctx, item); err != nil { + logger.Error("tn_attestation: submit signature failed", "hash", hashHex, "requester", fmt.Sprintf("%x", item.Requester), "error", err) + continue + } + logger.Debug("tn_attestation: attestation signed", "hash", hashHex, "requester", fmt.Sprintf("%x", item.Requester)) + } + } +} + +func (e *signerExtension) submitSignature(ctx context.Context, item *PreparedSignature) error { + if item == nil { + return fmt.Errorf("prepared signature is nil") + } + if len(item.Requester) == 0 { + return fmt.Errorf("requester not available for attestation hash %s", item.HashHex) + } + + service := e.Service() + if service == nil || service.GenesisConfig == nil { + return fmt.Errorf("service or genesis config not available") + } + if service.GenesisConfig.ChainID == "" { + return fmt.Errorf("chain id not configured") + } + + broadcaster := e.Broadcaster() + if broadcaster == nil { + return fmt.Errorf("transaction broadcaster unavailable") + } + + signer := e.NodeSigner() + if signer == nil { + return fmt.Errorf("node signer not initialised") + } + + accountID, err := ktypes.GetSignerAccount(signer) + if err != nil { + return fmt.Errorf("derive account id: %w", err) + } + + accounts := e.Accounts() + if accounts == nil { + return fmt.Errorf("accounts subsystem unavailable") + } + + db := e.DB() + if db == nil { + return fmt.Errorf("database handle unavailable") + } + + // Nonce handling: First signature ever gets nonce=1 (account doesn't exist yet). + // Subsequent signatures increment from last recorded nonce. We tolerate "not found" + // because leader's first-ever signature creates the account on-chain. + account, err := accounts.GetAccount(ctx, db, accountID) + var nonce uint64 = 1 + if err != nil { + msg := strings.ToLower(err.Error()) + if !strings.Contains(msg, "not found") && !strings.Contains(msg, "no rows") { + return fmt.Errorf("get account: %w", err) + } + } else { + nonce = uint64(account.Nonce + 1) + } + + hashArg, err := ktypes.EncodeValue(item.Hash) + if err != nil { + return fmt.Errorf("encode hash argument: %w", err) + } + requesterArg, err := ktypes.EncodeValue(item.Requester) + if err != nil { + return fmt.Errorf("encode requester argument: %w", err) + } + heightArg, err := ktypes.EncodeValue(item.CreatedHeight) + if err != nil { + return fmt.Errorf("encode created_height argument: %w", err) + } + signatureArg, err := ktypes.EncodeValue(item.Signature) + if err != nil { + return fmt.Errorf("encode signature argument: %w", err) + } + + payload := &ktypes.ActionExecution{ + Namespace: "main", + Action: "sign_attestation", + Arguments: [][]*ktypes.EncodedValue{{ + hashArg, requesterArg, heightArg, signatureArg, + }}, + } + + tx, err := ktypes.CreateNodeTransaction(payload, service.GenesisConfig.ChainID, nonce) + if err != nil { + return fmt.Errorf("create tx: %w", err) + } + if err := tx.Sign(signer); err != nil { + return fmt.Errorf("sign tx: %w", err) + } + + txHash, _, err := broadcaster.BroadcastTx(ctx, tx, 0) // Use BroadcastWaitAccept to avoid blocking consensus + if err != nil { + return fmt.Errorf("broadcast tx: %w", err) + } + + logger := e.Logger() + logger.Info("tn_attestation: signature broadcast", + "hash", item.HashHex, + "tx_hash", txHash, + "requester", fmt.Sprintf("%x", item.Requester)) + + // Queue asynchronous transaction status check for logging + e.enqueueStatusCheck(txHash, item.HashHex, item.Requester) + + return nil +} + +// startStatusWorker initializes the background goroutine that monitors transaction status. +// Uses sync.Once to prevent goroutine leaks across leader transitions. The 128-entry +// buffer provides headroom during burst signing without blocking EndBlock. +// +// KEY INSIGHT: Worker survives leader changes intentionally. When node loses leadership, +// the queue still drains pending status checks for txs already broadcast. This ensures +// operators see outcomes for all signatures, not just those from the current term. +func (e *signerExtension) startStatusWorker() { + e.statusOnce.Do(func() { + if e.statusQueue == nil { + e.statusQueue = make(chan txStatusWork, 128) + } + go e.runStatusWorker() + }) +} + +// enqueueStatusCheck queues a transaction for async status monitoring. +// Drops entries if the queue is full to avoid blocking the signing workflow. +func (e *signerExtension) enqueueStatusCheck(txHash ktypes.Hash, attestationHash string, requester []byte) { + if e.getTxQueryClient() == nil { + return + } + if e.statusQueue == nil { + e.startStatusWorker() + } + work := txStatusWork{ + hash: txHash, + attestationHash: attestationHash, + requester: append([]byte(nil), requester...), + } + select { + case e.statusQueue <- work: + default: + e.Logger().Warn("tn_attestation: transaction status queue full, dropping entry", + "hash", attestationHash, + "tx_hash", txHash) + } +} + +// runStatusWorker consumes queued transactions and monitors each for confirmation. +// Runs for the lifetime of the extension process, surviving leader transitions. +func (e *signerExtension) runStatusWorker() { + for work := range e.statusQueue { + ctx, cancel := context.WithTimeout(context.Background(), statusWorkerTimeout) + e.monitorTransaction(ctx, work) + cancel() + } +} + +// monitorTransaction polls for transaction confirmation with exponential backoff. +// Queries up to 12 times over ~2 minutes (2s, 5s, then 10s intervals) to handle +// network delays and block production variance. Logs final outcome for observability. +func (e *signerExtension) monitorTransaction(ctx context.Context, work txStatusWork) { + client := e.getTxQueryClient() + if client == nil { + return + } + + delays := make([]time.Duration, len(statusRetryDelays)) + copy(delays, statusRetryDelays) + if len(delays) < statusMaxAttempts { + extra := make([]time.Duration, statusMaxAttempts-len(delays)) + for i := range extra { + extra[i] = 10 * time.Second + } + delays = append(delays, extra...) + } + + logger := e.Logger() + + for attempt, delay := range delays { + if attempt > 0 { + select { + case <-ctx.Done(): + logger.Warn("tn_attestation: transaction status check cancelled", + "hash", work.attestationHash, + "tx_hash", work.hash) + return + case <-time.After(delay): + } + } + + resp, err := client.TxQuery(ctx, work.hash) + if err != nil { + if attempt == len(delays)-1 { + logger.Warn("tn_attestation: transaction status unknown", + "hash", work.attestationHash, + "tx_hash", work.hash, + "attempt", attempt+1, + "requester", fmt.Sprintf("%x", work.requester), + "error", err) + } + continue + } + + if resp.Height <= 0 { + continue + } + + if resp.Result != nil && resp.Result.Code == uint32(ktypes.CodeOk) { + logger.Info("tn_attestation: transaction confirmed", + "hash", work.attestationHash, + "tx_hash", work.hash, + "height", resp.Height, + "requester", fmt.Sprintf("%x", work.requester)) + } else { + code := uint32(0) + logMsg := "transaction result missing" + if resp.Result != nil { + code = resp.Result.Code + logMsg = resp.Result.Log + } + logger.Error("tn_attestation: transaction failed", + "hash", work.attestationHash, + "tx_hash", work.hash, + "height", resp.Height, + "code", code, + "log", logMsg, + "requester", fmt.Sprintf("%x", work.requester)) + } + return +} + +logger.Warn("tn_attestation: transaction status unresolved after retries", + "hash", work.attestationHash, + "tx_hash", work.hash, + "requester", fmt.Sprintf("%x", work.requester), + "attempts", len(delays)) +} diff --git a/internal/migrations/023-attestation-schema.sql b/internal/migrations/023-attestation-schema.sql index d6dfc6d03..08a135843 100644 --- a/internal/migrations/023-attestation-schema.sql +++ b/internal/migrations/023-attestation-schema.sql @@ -2,7 +2,7 @@ * ATTESTATION SCHEMA MIGRATION * * Creates the essential tables needed for the attestation system: - * - attestations: Stores attestation requests and signatures with composite PK (requester, attestation_hash) + * - attestations: Stores attestation requests and signatures * - attestation_actions: Allowlist of actions permitted for attestation with normalized IDs */ diff --git a/internal/migrations/024-attestation-actions.sql b/internal/migrations/024-attestation-actions.sql index 497ac77a1..8045d6160 100644 --- a/internal/migrations/024-attestation-actions.sql +++ b/internal/migrations/024-attestation-actions.sql @@ -1,14 +1,3 @@ -/* - * ATTESTATION ACTIONS MIGRATION - * - * Current scope: - * - request_attestation: User requests signed attestation of query results - * - * Placeholders: - * - sign_attestation – TODO - * - get_signed_attestation / list_attestations – TODO - */ - -- ============================================================================= -- CORE ATTESTATION ACTIONS -- ============================================================================= @@ -55,7 +44,6 @@ $max_fee INT8 -- some args per action $query_result := tn_utils.call_dispatch($action_name, $args_bytes); - -- Calculate attestation hash from (version|algo|data_provider|stream_id|action_id|args) $version := 1; $algo := 1; -- secp256k1 -- Serialize canonical payload (version through result) using tn_utils helpers @@ -64,18 +52,6 @@ $max_fee INT8 $height_bytes := tn_utils.encode_uint64($created_height::INT); $action_id_bytes := tn_utils.encode_uint16($action_id::INT); - -- Build hash material in canonical order (no length prefixes) to match - -- the engine-side hashing utilities used by the signing service. - $hash_input := tn_utils.bytea_join(ARRAY[ - $version_bytes, - $algo_bytes, - $data_provider, - $stream_id, - $action_id_bytes, - $args_bytes - ], NULL); - $attestation_hash := digest($hash_input, 'sha256'); - -- Canonical payload mirrors Go helpers: each field length-prefixed so the -- validator can recover every component without ambiguity. $result_canonical := tn_utils.bytea_join(ARRAY[ @@ -89,6 +65,18 @@ $max_fee INT8 tn_utils.bytea_length_prefix($query_result) ], NULL); + -- Build hash material in canonical order using caller-provided inputs only. + -- This keeps the hash deterministic for clients (excludes block height and result). + $hash_input := tn_utils.bytea_join(ARRAY[ + $version_bytes, + $algo_bytes, + $data_provider, + $stream_id, + $action_id_bytes, + $args_bytes + ], NULL); + $attestation_hash := digest($hash_input, 'sha256'); + -- Store unsigned attestation INSERT INTO attestations ( attestation_hash, requester, result_canonical, encrypt_sig, @@ -98,16 +86,75 @@ $max_fee INT8 $created_height, NULL, NULL, NULL ); - -- Queue for signing (no-op on non-leader validators; handled by precompile) - tn_attestation.queue_for_signing(encode($attestation_hash, 'hex')); - - RETURN $attestation_hash; +-- Queue for signing (no-op on non-leader validators; handled by precompile) +tn_attestation.queue_for_signing(encode($attestation_hash, 'hex')); + +RETURN $attestation_hash; }; -- ----------------------------------------------------------------------------- --- TODO: sign_attestation --- Placeholder to avoid merge conflicts with the signing workflow. --- CREATE OR REPLACE ACTION sign_attestation(...) { ... }; +-- Leader-only action for recording validator signatures on attestations. +CREATE OR REPLACE ACTION sign_attestation( + $attestation_hash BYTEA, + $requester BYTEA, + $created_height INT8, + $signature BYTEA +) PUBLIC { + -- Only the current leader may submit signatures on-chain. + IF @leader_sender IS NULL OR @signer IS NULL OR @leader_sender != @signer { + $leader_hex TEXT := 'unknown'; + $signer_hex TEXT := 'unknown'; + IF @leader_sender IS NOT NULL { + $leader_hex := encode(@leader_sender, 'hex')::TEXT; + } + IF @signer IS NOT NULL { + $signer_hex := encode(@signer, 'hex')::TEXT; + } + ERROR('Only the current block leader may sign attestations. leader=' || $leader_hex || ' signer=' || $signer_hex); + } + + IF $attestation_hash IS NULL { + ERROR('Attestation hash is required'); + } + IF $requester IS NULL { + ERROR('Requester is required'); + } + IF $created_height IS NULL { + ERROR('Created height is required'); + } + IF $signature IS NULL { + ERROR('Signature is required'); + } + + -- Ensure attestation exists and has not been signed yet. + $found BOOL := FALSE; + FOR $row IN + SELECT signature + FROM attestations + WHERE attestation_hash = $attestation_hash + AND requester = $requester + AND created_height = $created_height + LIMIT 1 + { + $found := TRUE; + IF $row.signature IS NOT NULL { + ERROR('Attestation already signed for requester at height ' || $created_height::TEXT); + } + } + IF NOT $found { + ERROR('Attestation not found for requester at height ' || $created_height::TEXT); + } + + -- Record signature, validator identity, and the height at which it was signed. + UPDATE attestations + SET signature = $signature, + validator_pubkey = @signer, + signed_height = @height + WHERE attestation_hash = $attestation_hash + AND requester = $requester + AND created_height = $created_height + AND signature IS NULL; +}; -- TODO: get_signed_attestation / list_attestations -- CREATE OR REPLACE ACTION get_signed_attestation(...) { ... }; diff --git a/tests/streams/attestation/attestation_request_test.go b/tests/streams/attestation/attestation_request_test.go index 17568b6f5..d095eaee3 100644 --- a/tests/streams/attestation/attestation_request_test.go +++ b/tests/streams/attestation/attestation_request_test.go @@ -141,6 +141,7 @@ func runAttestationHappyPath(t *testing.T, ctx context.Context, platform *kwilTe } type attestationRow struct { + requester []byte attestationHash []byte resultCanonical []byte encryptSig bool @@ -156,24 +157,25 @@ func fetchAttestationRow(t *testing.T, ctx context.Context, platform *kwilTestin var rowData attestationRow err = platform.Engine.Execute(engineCtx, platform.DB, ` -SELECT attestation_hash, result_canonical, encrypt_sig, signature, validator_pubkey, signed_height, created_height +SELECT requester, attestation_hash, result_canonical, encrypt_sig, signature, validator_pubkey, signed_height, created_height FROM attestations WHERE attestation_hash = $hash; `, map[string]any{"hash": hash}, func(row *common.Row) error { - rowData.attestationHash = append([]byte(nil), row.Values[0].([]byte)...) - rowData.resultCanonical = append([]byte(nil), row.Values[1].([]byte)...) - rowData.encryptSig = row.Values[2].(bool) - if row.Values[3] != nil { - rowData.signature = append([]byte(nil), row.Values[3].([]byte)...) - } + rowData.requester = append([]byte(nil), row.Values[0].([]byte)...) + rowData.attestationHash = append([]byte(nil), row.Values[1].([]byte)...) + rowData.resultCanonical = append([]byte(nil), row.Values[2].([]byte)...) + rowData.encryptSig = row.Values[3].(bool) if row.Values[4] != nil { - rowData.validatorPubKey = append([]byte(nil), row.Values[4].([]byte)...) + rowData.signature = append([]byte(nil), row.Values[4].([]byte)...) } if row.Values[5] != nil { - height := row.Values[5].(int64) + rowData.validatorPubKey = append([]byte(nil), row.Values[5].([]byte)...) + } + if row.Values[6] != nil { + height := row.Values[6].(int64) rowData.signedHeight = &height } - rowData.createdHeight = row.Values[6].(int64) + rowData.createdHeight = row.Values[7].(int64) return nil }) require.NoError(t, err)