Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions broadcastclient/broadcastclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,131 @@ func TestServerMissingFeedServerVersion(t *testing.T) {
}
}

type accumulatingTransactionStreamer struct {
mu sync.Mutex
messages []*message.BroadcastFeedMessage
}

func (ts *accumulatingTransactionStreamer) AddBroadcastMessages(feedMessages []*message.BroadcastFeedMessage) error {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.messages = append(ts.messages, feedMessages...)
return nil
}

func (ts *accumulatingTransactionStreamer) getMessages() []*message.BroadcastFeedMessage {
ts.mu.Lock()
defer ts.mu.Unlock()
result := make([]*message.BroadcastFeedMessage, len(ts.messages))
copy(result, ts.messages)
return result
}

// awaitCount waits until at least count messages have been received.
// The timeout is a safety net to prevent the test from hanging.
func (ts *accumulatingTransactionStreamer) awaitCount(t *testing.T, count int, timeout time.Duration) {
t.Helper()
deadline := time.After(timeout)
for {
if len(ts.getMessages()) >= count {
return
}
select {
case <-deadline:
t.Fatalf("timed out waiting for %d messages, got %d", count, len(ts.getMessages()))
case <-time.After(10 * time.Millisecond):
}
}
}

func TestInvalidSignatureMessagesAreSkipped(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

chainId := uint64(9742)

// Trusted key: broadcaster signs with this, client trusts this
trustedKey, err := crypto.GenerateKey()
Require(t, err)
trustedAddr := crypto.PubkeyToAddress(trustedKey.PublicKey)
trustedSigner := signature.DataSignerFromPrivateKey(trustedKey)

// Untrusted key: used to create messages with invalid signatures
untrustedKey, err := crypto.GenerateKey()
Require(t, err)
untrustedSigner := signature.DataSignerFromPrivateKey(untrustedKey)

feedErrChan := make(chan error, 10)
trustedBroadcaster := broadcaster.NewBroadcaster(func() *wsbroadcastserver.BroadcasterConfig { return &wsbroadcastserver.DefaultTestBroadcasterConfig }, chainId, feedErrChan, trustedSigner)

Require(t, trustedBroadcaster.Initialize())
Require(t, trustedBroadcaster.Start(ctx))
defer trustedBroadcaster.StopAndWait()

// Second broadcaster (not started) used only to create messages signed with the untrusted key
untrustedBroadcaster := broadcaster.NewBroadcaster(func() *wsbroadcastserver.BroadcasterConfig { return &wsbroadcastserver.DefaultTestBroadcasterConfig }, chainId, make(chan error, 1), untrustedSigner)

ts := &accumulatingTransactionStreamer{}

clientFeedErrChan := make(chan error, 10)
broadcastClient, err := newTestBroadcastClient(
DefaultTestConfig,
trustedBroadcaster.ListenerAddr(),
chainId,
0,
ts,
nil,
clientFeedErrChan,
&trustedAddr,
t,
)
Require(t, err)
broadcastClient.Start(ctx)
defer broadcastClient.StopAndWait()

// Batch 1: valid messages (seq 0, 1) - should be delivered.
Require(t, trustedBroadcaster.BroadcastFeedMessages(feedMessage(t, trustedBroadcaster, 0)))
Require(t, trustedBroadcaster.BroadcastFeedMessages(feedMessage(t, trustedBroadcaster, 1)))
ts.awaitCount(t, 2, 10*time.Second)

// Batch 2: invalid messages (seq 2, 3) signed with untrusted key - should be skipped.
Require(t, trustedBroadcaster.BroadcastFeedMessages(feedMessage(t, untrustedBroadcaster, 2)))
Require(t, trustedBroadcaster.BroadcastFeedMessages(feedMessage(t, untrustedBroadcaster, 3)))

// Sentinel (seq 2): a valid message that deterministically proves the client has
// processed and skipped the invalid messages. WebSocket messages are ordered, so the
// sentinel can only arrive after the invalid ones have been processed (and skipped).
// Invalid messages don't advance nextSeqNum (stays at 2), so the sentinel is seq 2.
Require(t, trustedBroadcaster.BroadcastFeedMessages(feedMessage(t, trustedBroadcaster, 2)))
ts.awaitCount(t, 3, 10*time.Second)

// Verify: only valid messages were delivered, and all have trusted signatures.
got := ts.getMessages()
if len(got) != 3 {
t.Fatalf("expected 3 messages, got %d", len(got))
}
for i, msg := range got {
if msg.SequenceNumber != arbutil.MessageIndex(i) { // nolint: gosec
t.Fatalf("message %d: unexpected seq number: %d", i, msg.SequenceNumber)
}
hash := msg.SignatureHash(chainId)
sigPub, err := crypto.SigToPub(hash.Bytes(), msg.Signature)
Require(t, err)
signerAddr := crypto.PubkeyToAddress(*sigPub)
if signerAddr != trustedAddr {
t.Fatalf("message %d (seq %d): signed by %s, expected trusted signer %s", i, msg.SequenceNumber, signerAddr, trustedAddr)
}
}

// Verify no fatal errors occurred (invalid signatures are non-fatal since NIT-4017)
select {
case err := <-clientFeedErrChan:
t.Fatalf("unexpected fatal feed error: %v", err)
default:
}
}

func TestBroadcastClientReconnectsOnServerDisconnect(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
Expand Down
2 changes: 2 additions & 0 deletions changelog/pmikolajczyk-nit-4034.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
### Ignored
- Add test for invalid feed signature handling.
Loading