Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
58fb761
refactor: update attestation schema and actions for clarity
outerlook Oct 10, 2025
42e9c39
feat: implement tn_attestation extension for signing and processing a…
outerlook Oct 10, 2025
18e4505
refactor: update tn_attestation integration tests and SQL action vali…
outerlook Oct 10, 2025
d77ecf9
docs: add documentation for tn_attestation package and enhance constants
outerlook Oct 10, 2025
a5b560d
refactor: enhance tn_attestation extension with improved broadcaster …
outerlook Oct 10, 2025
baf6494
refactor: enhance transaction status monitoring in tn_attestation ext…
outerlook Oct 11, 2025
ee54906
test: add unit tests for transaction status worker in tn_attestation …
outerlook Oct 11, 2025
0976278
test: add documentation for newFakeTxQueryClient function in status_w…
outerlook Oct 11, 2025
7118ae7
docs: enhance comments in processor and worker for clarity
outerlook Oct 11, 2025
a4e8c6c
refactor: optimize canonical payload handling in tn_attestation exten…
outerlook Oct 13, 2025
961c465
test: enhance ComputeAttestationHash tests in processor_test.go
outerlook Oct 13, 2025
c0a4f80
test: refactor TestSubmitSignature for improved clarity and reusability
outerlook Oct 13, 2025
0ce7e98
refactor: improve computeAttestationHash function and related tests
outerlook Oct 13, 2025
4497ef2
test: streamline TestComputeAttestationHash for improved clarity
outerlook Oct 13, 2025
9dcf571
fix: update sign_attestation action to ensure signature is not null
outerlook Oct 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions extensions/tn_attestation/broadcast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package tn_attestation

import (
"context"
"fmt"
"net"
"net/url"
"strings"

"github.com/trufnetwork/kwil-db/common"
rpcclient "github.com/trufnetwork/kwil-db/core/rpc/client"
userjsonrpc "github.com/trufnetwork/kwil-db/core/rpc/client/user/jsonrpc"
"github.com/trufnetwork/kwil-db/core/types"
)

// ensureBroadcaster initializes the RPC client for transaction submission if not already set.
// Called during startup and leader acquisition to ensure the leader can broadcast sign_attestation
// transactions. Prefers extension-specific rpc_url config, falling back to node's RPC endpoint.
func (e *signerExtension) ensureBroadcaster(service *common.Service) {
if service == nil || service.LocalConfig == nil {
return
}
if e.Broadcaster() != nil && e.TxQueryClient() != nil {
return
}

endpoint := ""
if cfg, ok := service.LocalConfig.Extensions[ExtensionName]; ok {
if v := strings.TrimSpace(cfg["rpc_url"]); v != "" {
endpoint = v
}
}

if endpoint == "" && service.LocalConfig.RPC.ListenAddress != "" {
endpoint = service.LocalConfig.RPC.ListenAddress
}
if endpoint == "" {
e.Logger().Warn("tn_attestation: cannot build broadcaster (no rpc endpoint configured)")
return
}

u, err := normalizeListenAddressForClient(endpoint)
if err != nil {
e.Logger().Warn("tn_attestation: invalid rpc endpoint", "endpoint", endpoint, "error", err)
return
}

broadcaster, queryClient := makeBroadcasterFromURL(u)
e.setBroadcaster(broadcaster)
e.setTxQueryClient(queryClient)
e.startStatusWorker()
}

// normalizeListenAddressForClient converts a server bind address (e.g., "0.0.0.0:8080")
// to a client-usable localhost URL. Needed because the extension runs on the same node
// but cannot connect to wildcard addresses like 0.0.0.0 or [::].
func normalizeListenAddressForClient(listen string) (*url.URL, error) {
if listen == "" {
return nil, fmt.Errorf("empty listen address")
}
endpoint := listen
if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
endpoint = "http://" + endpoint
}
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
clean := strings.Trim(host, "[]")
if clean == "" {
u.Host = net.JoinHostPort("127.0.0.1", port)
} else if ip := net.ParseIP(clean); ip != nil && ip.IsUnspecified() {
u.Host = net.JoinHostPort("127.0.0.1", port)
}
}
return u, nil
}

// makeBroadcasterFromURL creates broadcaster and query client from the normalized RPC endpoint.
// Returns the same instance for both interfaces to share a single connection pool.
func makeBroadcasterFromURL(u *url.URL) (TxBroadcaster, TxQueryClient) {
client := userjsonrpc.NewClient(u)
br := &jsonRPCBroadcaster{client: client}
return br, br
}

type jsonRPCBroadcaster struct {
client *userjsonrpc.Client
}

func (b *jsonRPCBroadcaster) BroadcastTx(ctx context.Context, tx *types.Transaction, sync uint8) (types.Hash, *types.TxResult, error) {
mode := rpcclient.BroadcastWaitAccept
if sync == uint8(rpcclient.BroadcastWaitCommit) || sync == 1 {
mode = rpcclient.BroadcastWaitCommit
}

hash, err := b.client.Broadcast(ctx, tx, mode)
if err != nil {
return types.Hash{}, nil, err
}

if mode == rpcclient.BroadcastWaitAccept {
return hash, nil, nil
}

resp, err := b.client.TxQuery(ctx, hash)
if err != nil {
return hash, nil, fmt.Errorf("tx query failed: %w", err)
}
if resp == nil || resp.Result == nil {
return hash, nil, fmt.Errorf("transaction result missing")
}

return hash, resp.Result, nil
}

func (b *jsonRPCBroadcaster) TxQuery(ctx context.Context, txHash types.Hash) (*types.TxQueryResponse, error) {
return b.client.TxQuery(ctx, txHash)
}

// TxBroadcaster matches the subset of the JSON-RPC client used by the signing
// worker to inject transactions.
type TxBroadcaster interface {
BroadcastTx(ctx context.Context, tx *types.Transaction, sync uint8) (types.Hash, *types.TxResult, error)
}

// TxQueryClient interface for querying transaction status.
type TxQueryClient interface {
TxQuery(ctx context.Context, txHash types.Hash) (*types.TxQueryResponse, error)
}
114 changes: 114 additions & 0 deletions extensions/tn_attestation/canonical.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package tn_attestation

import (
"bytes"
"crypto/sha256"
"encoding/binary"
"fmt"
)

// CanonicalPayload represents the eight attestation fields stored in result_canonical.
// The byte layout mirrors the SQL migration: fixed-width integers followed by
// length-prefixed blobs (little-endian 4-byte prefixes for variable sections).
//
// Layout:
//
// 1 byte version
// 1 byte algorithm
// 8 bytes block height (big-endian)
// 4 + n data provider (length-prefixed)
// 4 + m stream ID (length-prefixed)
// 2 bytes action ID (big-endian)
// 4 + k arguments (length-prefixed)
// 4 + r result (length-prefixed)
type CanonicalPayload struct {
Version uint8
Algorithm uint8
BlockHeight uint64
DataProvider []byte
StreamID []byte
ActionID uint16
Args []byte
Result []byte

raw []byte
}

// ParseCanonicalPayload decodes the canonical payload into structured fields.
// The function validates every length prefix and returns descriptive errors so
// future maintainers can diagnose storage corruption quickly.
func ParseCanonicalPayload(data []byte) (*CanonicalPayload, error) {
if len(data) < 1+1+8+2 {
return nil, fmt.Errorf("canonical payload too short: got %d bytes", len(data))
}

cursor := 0
payload := &CanonicalPayload{
Version: data[cursor],
Algorithm: data[cursor+1],
}
cursor += 2

payload.BlockHeight = binary.BigEndian.Uint64(data[cursor : cursor+8])
cursor += 8

var err error
if payload.DataProvider, cursor, err = readLengthPrefixed(data, cursor); err != nil {
return nil, fmt.Errorf("decode data_provider: %w", err)
}
if payload.StreamID, cursor, err = readLengthPrefixed(data, cursor); err != nil {
return nil, fmt.Errorf("decode stream_id: %w", err)
}

if len(data) < cursor+2 {
return nil, fmt.Errorf("canonical payload truncated before action_id")
}
payload.ActionID = binary.BigEndian.Uint16(data[cursor : cursor+2])
cursor += 2

if payload.Args, cursor, err = readLengthPrefixed(data, cursor); err != nil {
return nil, fmt.Errorf("decode args: %w", err)
}
if payload.Result, cursor, err = readLengthPrefixed(data, cursor); err != nil {
return nil, fmt.Errorf("decode result: %w", err)
}

if cursor != len(data) {
return nil, fmt.Errorf("canonical payload has %d trailing bytes", len(data)-cursor)
}

payload.raw = append(payload.raw[:0], data...) // ensure private copy
return payload, nil
}

// SigningBytes returns the backing canonical bytes that must be covered by the
// validator's signature (fields 1 through 8). Callers should treat the slice as
// immutable.
func (p *CanonicalPayload) SigningBytes() []byte {
return p.raw
}

// SigningDigest computes sha256(SigningBytes()) to match the on-chain verifier
// expectations. The digest is returned as a value to prevent accidental reuse of
// the backing slice.
func (p *CanonicalPayload) SigningDigest() [sha256.Size]byte {
return sha256.Sum256(p.SigningBytes())
}

// readLengthPrefixed decodes a little-endian uint32 length followed by that many bytes.
func readLengthPrefixed(data []byte, cursor int) ([]byte, int, error) {
if len(data) < cursor+4 {
return nil, cursor, fmt.Errorf("truncated length prefix at offset %d", cursor)
}

length := binary.LittleEndian.Uint32(data[cursor : cursor+4])
cursor += 4

if len(data) < cursor+int(length) {
return nil, cursor, fmt.Errorf("declared length %d exceeds remaining %d bytes", length, len(data)-cursor)
}

chunk := data[cursor : cursor+int(length)]
cursor += int(length)
return bytes.Clone(chunk), cursor, nil
}
94 changes: 94 additions & 0 deletions extensions/tn_attestation/canonical_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package tn_attestation

import (
"bytes"
"crypto/sha256"
"encoding/binary"
"testing"

"github.com/stretchr/testify/require"
)

func TestParseCanonicalPayload_Success(t *testing.T) {
version := uint8(1)
algo := uint8(1)
height := uint64(12345)
actionID := uint16(9)
dataProvider := []byte("provider-1")
streamID := []byte("stream-xyz")
args := []byte{0x01, 0x02, 0x03}
result := []byte{0xAA, 0xBB}

raw := buildCanonical(version, algo, height, dataProvider, streamID, actionID, args, result)

payload, err := ParseCanonicalPayload(raw)
require.NoError(t, err)
require.NotNil(t, payload)

require.Equal(t, version, payload.Version)
require.Equal(t, algo, payload.Algorithm)
require.Equal(t, height, payload.BlockHeight)
require.Equal(t, dataProvider, payload.DataProvider)
require.Equal(t, streamID, payload.StreamID)
require.Equal(t, actionID, payload.ActionID)
require.Equal(t, args, payload.Args)
require.Equal(t, result, payload.Result)

// Signing digest should equal sha256(raw)
expectedDigest := sha256.Sum256(raw)
require.Equal(t, expectedDigest, payload.SigningDigest())
require.True(t, bytes.Equal(raw, payload.SigningBytes()))
}

func TestParseCanonicalPayload_TruncatedPrefix(t *testing.T) {
base := buildCanonical(1, 1, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02})
// Corrupt by chopping last byte
corrupted := base[:len(base)-1]

_, err := ParseCanonicalPayload(corrupted)
require.Error(t, err)
require.Contains(t, err.Error(), "decode result")
}

func TestParseCanonicalPayload_ExtraBytes(t *testing.T) {
base := buildCanonical(1, 1, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02})
extra := append(base, []byte{0xFF, 0xFF}...)

_, err := ParseCanonicalPayload(extra)
require.Error(t, err)
require.Contains(t, err.Error(), "trailing bytes")
}

// buildCanonical mirrors the SQL encoder to generate canonical payloads.
func buildCanonical(version, algo uint8, height uint64, provider, stream []byte, actionID uint16, args, result []byte) []byte {
buf := bytes.NewBuffer(nil)
buf.WriteByte(version)
buf.WriteByte(algo)

heightBytes := make([]byte, 8)
binary.BigEndian.PutUint64(heightBytes, height)
buf.Write(heightBytes)

lengthBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(lengthBytes, uint32(len(provider)))
buf.Write(lengthBytes)
buf.Write(provider)

binary.LittleEndian.PutUint32(lengthBytes, uint32(len(stream)))
buf.Write(lengthBytes)
buf.Write(stream)

actionBytes := make([]byte, 2)
binary.BigEndian.PutUint16(actionBytes, actionID)
buf.Write(actionBytes)

binary.LittleEndian.PutUint32(lengthBytes, uint32(len(args)))
buf.Write(lengthBytes)
buf.Write(args)

binary.LittleEndian.PutUint32(lengthBytes, uint32(len(result)))
buf.Write(lengthBytes)
buf.Write(result)

return buf.Bytes()
}
1 change: 1 addition & 0 deletions extensions/tn_attestation/constants.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package tn_attestation

// ExtensionName is the identifier for the attestation signing extension.
const ExtensionName = "tn_attestation"
15 changes: 15 additions & 0 deletions extensions/tn_attestation/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Package tn_attestation implements the attestation signing workflow for TN validators.
//
// When a user requests an attestation via request_attestation (SQL), the extension:
// 1. Queues the hash via queue_for_signing precompile (non-deterministic, leader-only)
// 2. Processes queued hashes on leader's EndBlock
// 3. Signs canonical payloads using the validator's secp256k1 key
// 4. Broadcasts sign_attestation transactions back to consensus
//
// Key components:
// - ValidatorSigner: Thread-safe secp256k1 signing with EVM compatibility
// - CanonicalPayload: Structured representation of the 8-field attestation format
// - Leader callbacks: OnAcquire, OnLose, OnEndBlock lifecycle hooks
//
// Initialize the extension by calling InitializeExtension() during node startup.
package tn_attestation
Loading
Loading