diff --git a/state/accountsDB.go b/state/accountsDB.go index 9d57fcc06a5..c913938443d 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -346,54 +346,68 @@ func (adb *AccountsDB) saveCode(newAcc, oldAcc baseAccountHandler) error { return nil } - unmodifiedOldCodeEntry, err := adb.updateOldCodeEntry(oldCodeHash) + oldCodeEntry, err := adb.getCodeEntry(oldCodeHash) if err != nil { return err } - err = adb.updateNewCodeEntry(newCodeHash, newCode) + newCodeEntry, err := adb.getCodeEntry(newCodeHash) if err != nil { return err } - entry, err := NewJournalEntryCode(unmodifiedOldCodeEntry, oldCodeHash, newCodeHash, adb.mainTrie, adb.marshaller) + entry, err := NewJournalEntryCode(oldCodeEntry, oldCodeHash, newCodeEntry, newCodeHash, adb.mainTrie, adb.marshaller) if err != nil { return err } adb.journalize(entry) + err = adb.updateOldCodeEntry(oldCodeHash, oldCodeEntry) + if err != nil { + return err + } + + err = adb.updateNewCodeEntry(newCodeHash, newCodeEntry, newCode) + if err != nil { + return err + } + newAcc.SetCodeHash(newCodeHash) return nil } -func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error) { - oldCodeEntry, err := getCodeEntry(oldCodeHash, adb.mainTrie, adb.marshaller) +func (adb *AccountsDB) getCodeEntry(hash []byte) (*CodeEntry, error) { + codeEntry, err := getCodeEntry(hash, adb.mainTrie, adb.marshaller) if err != nil { return nil, err } - if oldCodeEntry == nil { - return nil, nil - } - sc := &stateChange.StateAccess{ Type: stateChange.Read, - MainTrieKey: oldCodeHash, + MainTrieKey: hash, MainTrieVal: nil, Operation: stateChange.GetCode, DataTrieChanges: nil, } adb.stateAccessesCollector.AddStateAccess(sc) - unmodifiedOldCodeEntry := &CodeEntry{ + return codeEntry, nil +} + +func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte, oldCodeEntry *CodeEntry) error { + if oldCodeEntry == nil { + return nil + } + + codeEntryClone := &CodeEntry{ Code: oldCodeEntry.Code, NumReferences: oldCodeEntry.NumReferences, } - if oldCodeEntry.NumReferences <= 1 { - err = adb.mainTrie.Delete(oldCodeHash) + if codeEntryClone.NumReferences <= 1 { + err := adb.mainTrie.Delete(oldCodeHash) if err != nil { - return nil, err + return err } sc1 := &stateChange.StateAccess{ @@ -405,16 +419,16 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error } adb.stateAccessesCollector.AddStateAccess(sc1) - return unmodifiedOldCodeEntry, nil + return nil } - oldCodeEntry.NumReferences-- - codeEntryBytes, err := saveCodeEntry(oldCodeHash, oldCodeEntry, adb.mainTrie, adb.marshaller) + codeEntryClone.NumReferences-- + codeEntryBytes, err := saveCodeEntry(oldCodeHash, codeEntryClone, adb.mainTrie, adb.marshaller) if err != nil { - return nil, err + return err } - sc = &stateChange.StateAccess{ + sc := &stateChange.StateAccess{ Type: stateChange.Write, MainTrieKey: oldCodeHash, MainTrieVal: codeEntryBytes, @@ -423,41 +437,27 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error } adb.stateAccessesCollector.AddStateAccess(sc) - return unmodifiedOldCodeEntry, nil + return nil } -func (adb *AccountsDB) updateNewCodeEntry(newCodeHash []byte, newCode []byte) error { +func (adb *AccountsDB) updateNewCodeEntry(newCodeHash []byte, newCodeEntry *CodeEntry, newCode []byte) error { if len(newCode) == 0 { return nil } - newCodeEntry, err := getCodeEntry(newCodeHash, adb.mainTrie, adb.marshaller) - if err != nil { - return err - } - - sc := &stateChange.StateAccess{ - Type: stateChange.Read, - MainTrieKey: newCodeHash, - MainTrieVal: nil, - Operation: stateChange.GetCode, - DataTrieChanges: nil, - } - adb.stateAccessesCollector.AddStateAccess(sc) - - if newCodeEntry == nil { - newCodeEntry = &CodeEntry{ - Code: newCode, - } + codeEntry := &CodeEntry{} + codeEntry.Code = newCode + if newCodeEntry != nil { + codeEntry.NumReferences = newCodeEntry.NumReferences } - newCodeEntry.NumReferences++ + codeEntry.NumReferences++ - codeEntryBytes, err := saveCodeEntry(newCodeHash, newCodeEntry, adb.mainTrie, adb.marshaller) + codeEntryBytes, err := saveCodeEntry(newCodeHash, codeEntry, adb.mainTrie, adb.marshaller) if err != nil { return err } - sc = &stateChange.StateAccess{ + sc := &stateChange.StateAccess{ Type: stateChange.Write, MainTrieKey: newCodeHash, MainTrieVal: codeEntryBytes, @@ -676,18 +676,19 @@ func (adb *AccountsDB) removeDataTrie(baseAcc baseAccountHandler) error { func (adb *AccountsDB) removeCode(baseAcc baseAccountHandler) error { oldCodeHash := baseAcc.GetCodeHash() - unmodifiedOldCodeEntry, err := adb.updateOldCodeEntry(oldCodeHash) + + oldCodeEntry, err := adb.getCodeEntry(oldCodeHash) if err != nil { return err } - codeChangeEntry, err := NewJournalEntryCode(unmodifiedOldCodeEntry, oldCodeHash, nil, adb.mainTrie, adb.marshaller) + entry, err := NewJournalEntryCode(oldCodeEntry, oldCodeHash, nil, nil, adb.mainTrie, adb.marshaller) if err != nil { return err } - adb.journalize(codeChangeEntry) + adb.journalize(entry) - return nil + return adb.updateOldCodeEntry(oldCodeHash, oldCodeEntry) } // LoadAccount fetches the account based on the address. Creates an empty account if the account is missing. @@ -853,12 +854,14 @@ func (adb *AccountsDB) RevertToSnapshot(snapshot int) error { for i := len(adb.entries) - 1; i >= snapshot; i-- { account, err := adb.entries[i].Revert() if err != nil { + adb.entries = adb.entries[:i+1] return err } if !check.IfNil(account) { _, err = adb.saveAccountToTrie(account, adb.mainTrie) if err != nil { + adb.entries = adb.entries[:i+1] return err } } diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 2f2125c7737..a248628223e 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -3467,3 +3467,384 @@ func testAccountLoadInParallel( wg.Wait() } + +func TestAccountsDB_SaveAccountShouldRevertOldCodeEntryWhenUpdateNewCodeEntryFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected update new code entry error") + marshaller := &marshallerMock.MarshalizerMock{} + hasher := &hashingMocks.HasherMock{} + + oldCode := []byte("old code") + oldCodeHash := hasher.Compute(string(oldCode)) + oldCodeEntry := &state.CodeEntry{ + Code: oldCode, + NumReferences: 2, + } + oldCodeEntryBytes, _ := marshaller.Marshal(oldCodeEntry) + + newCode := []byte("new code") + newCodeHash := hasher.Compute(string(newCode)) + + codeEntries := map[string][]byte{ + string(oldCodeHash): oldCodeEntryBytes, + } + + accountAddress := generateRandomByteArray(32) + acc := stateMock.NewAccountWrapMock(accountAddress) + acc.SetCode(oldCode) + acc.SetCodeHash(oldCodeHash) + marshalledAcc, err := marshaller.Marshal(acc) + require.Nil(t, err) + + trieStub := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, accountAddress) { + return marshalledAcc, 0, nil + } + return codeEntries[string(key)], 0, nil + }, + UpdateCalled: func(key, value []byte) error { + if bytes.Equal(key, newCodeHash) { + if len(value) == 0 { + return nil + } + return expectedErr + } + + codeEntries[string(key)] = value + return nil + }, + DeleteCalled: func(key []byte) error { + delete(codeEntries, string(key)) + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(trieStub) + + dummyAcc := generateAccount() + err = adb.SaveAccount(dummyAcc) + require.Nil(t, err) + snapshot := adb.JournalLen() + require.NotEqual(t, 0, snapshot) + + oldAcc, err := adb.LoadAccount(accountAddress) + require.Nil(t, err) + oldAcc.(state.UserAccountHandler).SetCode(newCode) + + err = adb.SaveAccount(oldAcc) + require.ErrorIs(t, err, expectedErr) + + err = adb.RevertToSnapshot(snapshot) + require.Nil(t, err) + + var restoredOldCodeEntry state.CodeEntry + err = marshaller.Unmarshal(&restoredOldCodeEntry, codeEntries[string(oldCodeHash)]) + require.Nil(t, err) + assert.Equal(t, oldCode, restoredOldCodeEntry.Code) + assert.Equal(t, uint32(2), restoredOldCodeEntry.NumReferences) +} + +func TestAccountsDB_RevertShouldNotModifyNewCodeWhenOldCodeUpdateFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected old code update error") + marshaller := &marshallerMock.MarshalizerMock{} + hasher := &hashingMocks.HasherMock{} + + oldCode := []byte("old code") + oldCodeHash := hasher.Compute(string(oldCode)) + oldCodeEntryBytes, err := marshaller.Marshal(&state.CodeEntry{ + Code: oldCode, + NumReferences: 2, + }) + require.NoError(t, err) + + newCode := []byte("new code") + newCodeHash := hasher.Compute(string(newCode)) + newCodeEntryBytes, err := marshaller.Marshal(&state.CodeEntry{ + Code: newCode, + NumReferences: 3, + }) + require.NoError(t, err) + + codeEntries := map[string][]byte{ + string(oldCodeHash): oldCodeEntryBytes, + string(newCodeHash): newCodeEntryBytes, + } + + accountAddress := generateRandomByteArray(32) + acc := stateMock.NewAccountWrapMock(accountAddress) + acc.SetCode(oldCode) + acc.SetCodeHash(oldCodeHash) + + marshalledAcc, err := marshaller.Marshal(acc) + require.NoError(t, err) + + failOldCodeUpdate := true + trieStub := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, accountAddress) { + return marshalledAcc, 0, nil + } + + return codeEntries[string(key)], 0, nil + }, + UpdateCalled: func(key, value []byte) error { + if bytes.Equal(key, oldCodeHash) && failOldCodeUpdate { + failOldCodeUpdate = false + return expectedErr + } + + codeEntries[string(key)] = value + return nil + }, + DeleteCalled: func(key []byte) error { + delete(codeEntries, string(key)) + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(trieStub) + + err = adb.SaveAccount(generateAccount()) + require.NoError(t, err) + + snapshot := adb.JournalLen() + require.NotZero(t, snapshot) + + loadedAcc, err := adb.LoadAccount(accountAddress) + require.NoError(t, err) + + loadedAcc.(state.UserAccountHandler).SetCode(newCode) + + err = adb.SaveAccount(loadedAcc) + require.ErrorIs(t, err, expectedErr) + + err = adb.RevertToSnapshot(snapshot) + require.NoError(t, err) + + checkCodeEntry(oldCodeHash, oldCode, 2, marshaller, trieStub, t) + checkCodeEntry(newCodeHash, newCode, 3, marshaller, trieStub, t) +} + +func TestAccountsDB_RevertShouldNotModifyNewCodeWhenNewCodeUpdateFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected new code update error") + marshaller := &marshallerMock.MarshalizerMock{} + hasher := &hashingMocks.HasherMock{} + + oldCode := []byte("old code") + oldCodeHash := hasher.Compute(string(oldCode)) + oldCodeEntryBytes, err := marshaller.Marshal(&state.CodeEntry{ + Code: oldCode, + NumReferences: 2, + }) + require.NoError(t, err) + + newCode := []byte("new code") + newCodeHash := hasher.Compute(string(newCode)) + newCodeEntryBytes, err := marshaller.Marshal(&state.CodeEntry{ + Code: newCode, + NumReferences: 3, + }) + require.NoError(t, err) + + codeEntries := map[string][]byte{ + string(oldCodeHash): oldCodeEntryBytes, + string(newCodeHash): newCodeEntryBytes, + } + + accountAddress := generateRandomByteArray(32) + acc := stateMock.NewAccountWrapMock(accountAddress) + acc.SetCode(oldCode) + acc.SetCodeHash(oldCodeHash) + + marshalledAcc, err := marshaller.Marshal(acc) + require.NoError(t, err) + + failNewCodeUpdate := true + trieStub := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, accountAddress) { + return marshalledAcc, 0, nil + } + + return codeEntries[string(key)], 0, nil + }, + UpdateCalled: func(key, value []byte) error { + if bytes.Equal(key, newCodeHash) && failNewCodeUpdate { + failNewCodeUpdate = false + return expectedErr + } + + codeEntries[string(key)] = value + return nil + }, + DeleteCalled: func(key []byte) error { + delete(codeEntries, string(key)) + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(trieStub) + + err = adb.SaveAccount(generateAccount()) + require.NoError(t, err) + + snapshot := adb.JournalLen() + require.NotZero(t, snapshot) + + loadedAcc, err := adb.LoadAccount(accountAddress) + require.NoError(t, err) + + loadedAcc.(state.UserAccountHandler).SetCode(newCode) + + err = adb.SaveAccount(loadedAcc) + require.ErrorIs(t, err, expectedErr) + + err = adb.RevertToSnapshot(snapshot) + require.NoError(t, err) + + checkCodeEntry(oldCodeHash, oldCode, 2, marshaller, trieStub, t) + checkCodeEntry(newCodeHash, newCode, 3, marshaller, trieStub, t) +} + +type journalEntryStub struct { + revertCalled func() (vmcommon.AccountHandler, error) +} + +// Revert - +func (stub *journalEntryStub) Revert() (vmcommon.AccountHandler, error) { + if stub.revertCalled != nil { + return stub.revertCalled() + } + + return nil, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (stub *journalEntryStub) IsInterfaceNil() bool { + return stub == nil +} + +func TestAccountsDB_RevertToSnapshotShouldNotRetryCompletedEntries(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected revert error") + + trieStub := &trieMock.TrieStub{ + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + adb := generateAccountDBFromTrie(trieStub) + + adb.Journalize(&journalEntryStub{}) + snapshot := adb.JournalLen() + + failingCalls := 0 + adb.Journalize(&journalEntryStub{ + revertCalled: func() (vmcommon.AccountHandler, error) { + failingCalls++ + if failingCalls == 1 { + return nil, expectedErr + } + + return nil, nil + }, + }) + + completedCalls := 0 + adb.Journalize(&journalEntryStub{ + revertCalled: func() (vmcommon.AccountHandler, error) { + completedCalls++ + return nil, nil + }, + }) + + err := adb.RevertToSnapshot(snapshot) + require.ErrorIs(t, err, expectedErr) + + // The latest entry completed successfully and must already be removed. + assert.Equal(t, snapshot+1, adb.JournalLen()) + assert.Equal(t, 1, completedCalls) + assert.Equal(t, 1, failingCalls) + + err = adb.RevertToSnapshot(snapshot) + require.NoError(t, err) + + assert.Equal(t, snapshot, adb.JournalLen()) + assert.Equal(t, 1, completedCalls) + assert.Equal(t, 2, failingCalls) +} + +func TestAccountsDB_RevertToSnapshotShouldNotRetryCompletedEntriesWhenAccountSaveFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected account save error") + accountAddress := generateRandomByteArray(32) + account := stateMock.NewAccountWrapMock(accountAddress) + + failAccountSave := true + trieStub := &trieMock.TrieStub{ + UpdateCalled: func(key, _ []byte) error { + if bytes.Equal(key, accountAddress) && failAccountSave { + failAccountSave = false + return expectedErr + } + + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + adb := generateAccountDBFromTrie(trieStub) + + adb.Journalize(&journalEntryStub{}) + snapshot := adb.JournalLen() + + accountEntryCalls := 0 + adb.Journalize(&journalEntryStub{ + revertCalled: func() (vmcommon.AccountHandler, error) { + accountEntryCalls++ + return account, nil + }, + }) + + completedCalls := 0 + adb.Journalize(&journalEntryStub{ + revertCalled: func() (vmcommon.AccountHandler, error) { + completedCalls++ + return nil, nil + }, + }) + + err := adb.RevertToSnapshot(snapshot) + require.ErrorIs(t, err, expectedErr) + + // The completed entry is removed, while the entry whose account save + // failed remains available for retry. + assert.Equal(t, snapshot+1, adb.JournalLen()) + assert.Equal(t, 1, completedCalls) + assert.Equal(t, 1, accountEntryCalls) + + err = adb.RevertToSnapshot(snapshot) + require.NoError(t, err) + + assert.Equal(t, snapshot, adb.JournalLen()) + assert.Equal(t, 1, completedCalls) + assert.Equal(t, 2, accountEntryCalls) +} diff --git a/state/export_test.go b/state/export_test.go index bbc209312e4..d9a9e98eaff 100644 --- a/state/export_test.go +++ b/state/export_test.go @@ -106,3 +106,8 @@ type AccountHandlerWithDataTrieMigrationStatus interface { vmcommon.AccountHandler IsDataTrieMigrated() (bool, error) } + +// Journalize exposes journalize for tests. +func (adb *AccountsDB) Journalize(entry JournalEntry) { + adb.journalize(entry) +} diff --git a/state/journalEntries.go b/state/journalEntries.go index 4aa2b79e9e2..b90967faed7 100644 --- a/state/journalEntries.go +++ b/state/journalEntries.go @@ -13,6 +13,7 @@ import ( type journalEntryCode struct { oldCodeEntry *CodeEntry oldCodeHash []byte + newCodeEntry *CodeEntry newCodeHash []byte trie Updater marshalizer marshal.Marshalizer @@ -22,6 +23,7 @@ type journalEntryCode struct { func NewJournalEntryCode( oldCodeEntry *CodeEntry, oldCodeHash []byte, + newCodeEntry *CodeEntry, newCodeHash []byte, trie Updater, marshalizer marshal.Marshalizer, @@ -36,6 +38,7 @@ func NewJournalEntryCode( return &journalEntryCode{ oldCodeEntry: oldCodeEntry, oldCodeHash: oldCodeHash, + newCodeEntry: newCodeEntry, newCodeHash: newCodeHash, trie: trie, marshalizer: marshalizer, @@ -48,12 +51,12 @@ func (jea *journalEntryCode) Revert() (vmcommon.AccountHandler, error) { return nil, nil } - err := jea.revertOldCodeEntry() + err := jea.revertCodeEntry(jea.oldCodeHash, jea.oldCodeEntry) if err != nil { return nil, err } - err = jea.revertNewCodeEntry() + err = jea.revertCodeEntry(jea.newCodeHash, jea.newCodeEntry) if err != nil { return nil, err } @@ -61,45 +64,26 @@ func (jea *journalEntryCode) Revert() (vmcommon.AccountHandler, error) { return nil, nil } -func (jea *journalEntryCode) revertOldCodeEntry() error { - if len(jea.oldCodeHash) == 0 { +func (jea *journalEntryCode) revertCodeEntry( + codeHash []byte, + codeEntry *CodeEntry, +) error { + if len(codeHash) == 0 { return nil } - _, err := saveCodeEntry(jea.oldCodeHash, jea.oldCodeEntry, jea.trie, jea.marshalizer) - if err != nil { - return err - } - - return nil -} - -func (jea *journalEntryCode) revertNewCodeEntry() error { - newCodeEntry, err := getCodeEntry(jea.newCodeHash, jea.trie, jea.marshalizer) - if err != nil { - return err - } - - if newCodeEntry == nil { - return nil + if codeEntry == nil { + return jea.trie.Update(codeHash, nil) } - if newCodeEntry.NumReferences <= 1 { - err = jea.trie.Update(jea.newCodeHash, nil) - if err != nil { - return err - } - - return nil - } - - newCodeEntry.NumReferences-- - _, err = saveCodeEntry(jea.newCodeHash, newCodeEntry, jea.trie, jea.marshalizer) - if err != nil { - return err - } + _, err := saveCodeEntry( + codeHash, + codeEntry, + jea.trie, + jea.marshalizer, + ) - return nil + return err } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/journalEntries_test.go b/state/journalEntries_test.go index 7530f6cbfae..7f66a45a4e5 100644 --- a/state/journalEntries_test.go +++ b/state/journalEntries_test.go @@ -17,7 +17,7 @@ import ( func TestNewJournalEntryCode_NilUpdaterShouldErr(t *testing.T) { t.Parallel() - entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), nil, &marshallerMock.MarshalizerMock{}) + entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), &state.CodeEntry{}, []byte("code hash"), nil, &marshallerMock.MarshalizerMock{}) assert.True(t, check.IfNil(entry)) assert.Equal(t, state.ErrNilUpdater, err) } @@ -25,7 +25,7 @@ func TestNewJournalEntryCode_NilUpdaterShouldErr(t *testing.T) { func TestNewJournalEntryCode_NilMarshalizerShouldErr(t *testing.T) { t.Parallel() - entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), &trieMock.TrieStub{}, nil) + entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), &state.CodeEntry{}, []byte("code hash"), &trieMock.TrieStub{}, nil) assert.True(t, check.IfNil(entry)) assert.Equal(t, state.ErrNilMarshalizer, err) } @@ -33,7 +33,7 @@ func TestNewJournalEntryCode_NilMarshalizerShouldErr(t *testing.T) { func TestNewJournalEntryCode_OkParams(t *testing.T) { t.Parallel() - entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), &trieMock.TrieStub{}, &marshallerMock.MarshalizerMock{}) + entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), &state.CodeEntry{}, []byte("code hash"), &trieMock.TrieStub{}, &marshallerMock.MarshalizerMock{}) assert.Nil(t, err) assert.False(t, check.IfNil(entry)) } @@ -42,7 +42,7 @@ func TestJournalEntryCode_OldHashAndNewHashAreNil(t *testing.T) { t.Parallel() trieStub := &trieMock.TrieStub{} - entry, _ := state.NewJournalEntryCode(&state.CodeEntry{}, nil, nil, trieStub, &marshallerMock.MarshalizerMock{}) + entry, _ := state.NewJournalEntryCode(&state.CodeEntry{}, nil, &state.CodeEntry{}, nil, trieStub, &marshallerMock.MarshalizerMock{}) acc, err := entry.Revert() assert.Nil(t, err) @@ -72,6 +72,7 @@ func TestJournalEntryCode_OldHashIsNilAndNewHashIsNotNil(t *testing.T) { entry, _ := state.NewJournalEntryCode( &state.CodeEntry{}, nil, + &state.CodeEntry{}, []byte("newHash"), trieStub, marshalizer, diff --git a/state/trackableDataTrie/export_test.go b/state/trackableDataTrie/export_test.go index ae44bbbdf7e..607fd6c2d5c 100644 --- a/state/trackableDataTrie/export_test.go +++ b/state/trackableDataTrie/export_test.go @@ -27,3 +27,12 @@ func (tdt *trackableDataTrie) GetValueForVersion(key []byte, val []byte, version valWithMetadata, _ := tdt.getValueForVersion(key, val, version) return valWithMetadata } + +// SetDirtyData - +func (tdt *trackableDataTrie) SetDirtyData(index int, key string, value []byte, newVersion core.TrieNodeVersion) { + tdt.dirtyData[key] = dirtyData{ + index: index, + value: value, + newVersion: newVersion, + } +} diff --git a/state/trackableDataTrie/trackableDataTrie.go b/state/trackableDataTrie/trackableDataTrie.go index 8b6a7333496..500362b839a 100644 --- a/state/trackableDataTrie/trackableDataTrie.go +++ b/state/trackableDataTrie/trackableDataTrie.go @@ -258,29 +258,40 @@ func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]*stateChang return tdt.updateTrie(dtr) } +func (tdt *trackableDataTrie) rollbackAppliedUpdates(dtr state.DataTrie, oldValues []core.TrieData) { + for i := len(oldValues) - 1; i >= 0; i-- { + trieUpdate := oldValues[i] + err := dtr.UpdateWithVersion(trieUpdate.Key, trieUpdate.Value, trieUpdate.Version) + if err != nil { + log.Error("could not apply rollback updates", "err", err, "key", trieUpdate.Key, "account", tdt.identifier) + } + } +} + func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.DataTrieChange, []core.TrieData, error) { oldValues := make([]core.TrieData, len(tdt.dirtyData)) newData := make([]*stateChange.DataTrieChange, len(tdt.dirtyData)) deletedKeys := make([]*stateChange.DataTrieChange, 0) + trieUpdates := make([]core.TrieData, 0) index := 0 for key, dataEntry := range tdt.dirtyData { oldVal, _, err := tdt.retrieveValueFromTrie([]byte(key)) if err != nil { + tdt.rollbackAppliedUpdates(dtr, trieUpdates) return nil, nil, err } oldValues[index] = oldVal wasDeleted, err := tdt.deleteOldEntryIfMigrated([]byte(key), dataEntry, oldVal) if err != nil { + tdt.rollbackAppliedUpdates(dtr, trieUpdates) return nil, nil, err } if wasDeleted { - originalVal, err := tdt.getValueNotSpecifiedVersion([]byte(key), oldVal.Value) - if err != nil { - return nil, nil, err - } + trieUpdates = append(trieUpdates, oldVal) + originalVal := tdt.getValueNotSpecifiedVersion([]byte(key), oldVal.Value) deletedKeys = append(deletedKeys, &stateChange.DataTrieChange{ @@ -295,6 +306,7 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.Dat dataTrieKey, err := tdt.modifyTrie([]byte(key), dataEntry, oldVal, dtr) if err != nil { + tdt.rollbackAppliedUpdates(dtr, trieUpdates) return nil, nil, err } @@ -302,10 +314,14 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.Dat isFirstMigration := oldVal.Version == core.NotSpecified && dataEntry.newVersion == core.AutoBalanceEnabled if isFirstMigration && len(dataTrieKey) != 0 { - oldValues = append(oldValues, core.TrieData{ + deletedData := core.TrieData{ Key: dataTrieKey, Value: nil, - }) + } + oldValues = append(oldValues, deletedData) + trieUpdates = append(trieUpdates, deletedData) + } else if len(dataTrieKey) != 0 && !wasDeleted { + trieUpdates = append(trieUpdates, oldVal) } if len(dataTrieKey) == 0 { @@ -313,6 +329,7 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.Dat } if dataEntry.index > len(newData)-1 { + tdt.rollbackAppliedUpdates(dtr, trieUpdates) return nil, nil, fmt.Errorf("index out of range") } @@ -324,6 +341,7 @@ func (tdt *trackableDataTrie) updateTrie(dtr state.DataTrie) ([]*stateChange.Dat version = oldVal.Version val, err = tdt.getValueWithoutMetadata([]byte(key), oldVal) if err != nil { + tdt.rollbackAppliedUpdates(dtr, trieUpdates) return nil, nil, err } } @@ -434,7 +452,7 @@ func (tdt *trackableDataTrie) getValueWithoutMetadata(key []byte, trieData core. return tdt.getValueAutoBalanceVersion(trieData.Value) } - return tdt.getValueNotSpecifiedVersion(key, trieData.Value) + return tdt.getValueNotSpecifiedVersion(key, trieData.Value), nil } func (tdt *trackableDataTrie) getValueAutoBalanceVersion(val []byte) ([]byte, error) { @@ -447,11 +465,11 @@ func (tdt *trackableDataTrie) getValueAutoBalanceVersion(val []byte) ([]byte, er return dataTrieVal.Value, nil } -func (tdt *trackableDataTrie) getValueNotSpecifiedVersion(key []byte, val []byte) ([]byte, error) { +func (tdt *trackableDataTrie) getValueNotSpecifiedVersion(key []byte, val []byte) []byte { tailLength := len(key) + len(tdt.identifier) trimmedValue, _ := common.TrimSuffixFromValue(val, tailLength) - return trimmedValue, nil + return trimmedValue } func (tdt *trackableDataTrie) deleteOldEntryIfMigrated(key []byte, newData dirtyData, oldEntry core.TrieData) (bool, error) { diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index b68a18edc65..87d764487eb 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -11,6 +11,7 @@ import ( vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/multiversx/mx-chain-go/common" errorsCommon "github.com/multiversx/mx-chain-go/errors" @@ -1184,3 +1185,318 @@ func TestTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { tdt.SetDataTrie(newTrie) assert.Equal(t, newTrie, tdt.DataTrie()) } + +func TestTrackableDataTrie_SaveDirtyDataShouldRollbackPreviousUpdateWhenLaterGetFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected get error") + firstKey := []byte("key1") + secondKey := []byte("key2") + firstOldValue := []byte("old1") + firstNewValue := []byte("new1") + identifier := []byte("identifier") + + getCalls := 0 + updateCalls := 0 + rollbackCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + getCalls++ + if bytes.Equal(key, firstKey) { + return append(firstOldValue, append(firstKey, identifier...)...), 0, nil + } + if bytes.Equal(key, secondKey) { + return nil, 0, expectedErr + } + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + updateCalls++ + if updateCalls == 1 { + assert.Equal(t, firstKey, key) + assert.Equal(t, append(firstNewValue, append(firstKey, identifier...)...), value) + assert.Equal(t, core.NotSpecified, version) + return nil + } + + assert.Equal(t, firstKey, key) + assert.Equal(t, append(firstOldValue, append(firstKey, identifier...)...), value) + assert.Equal(t, core.NotSpecified, version) + rollbackCalled = true + return nil + }, + } + + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + identifier, + &hashingMocks.HasherMock{}, + &marshallerMock.MarshalizerMock{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &stateMock.StateAccessesCollectorStub{}, + ) + tdt.SetDataTrie(trie) + + _ = tdt.SaveKeyValue(firstKey, firstNewValue) + _ = tdt.SaveKeyValue(secondKey, []byte("new2")) + + _, _, err := tdt.SaveDirtyData(trie) + require.ErrorIs(t, err, expectedErr) + assert.Equal(t, 2, getCalls) + assert.True(t, rollbackCalled) +} + +func TestTrackableDataTrie_SaveDirtyDataShouldRollbackPreviousUpdateWhenLaterUpdateFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected update error") + firstKey := []byte("key1") + secondKey := []byte("key2") + firstOldValue := []byte("old1") + firstNewValue := []byte("new1") + identifier := []byte("identifier") + + updateCalls := 0 + rollbackCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, firstKey) { + return append(firstOldValue, append(firstKey, identifier...)...), 0, nil + } + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + updateCalls++ + switch updateCalls { + case 1: + assert.Equal(t, firstKey, key) + assert.Equal(t, append(firstNewValue, append(firstKey, identifier...)...), value) + assert.Equal(t, core.NotSpecified, version) + return nil + case 2: + assert.Equal(t, secondKey, key) + return expectedErr + case 3: + assert.Equal(t, firstKey, key) + assert.Equal(t, append(firstOldValue, append(firstKey, identifier...)...), value) + assert.Equal(t, core.NotSpecified, version) + rollbackCalled = true + return nil + default: + require.Fail(t, "unexpected update call") + return nil + } + }, + } + + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + identifier, + &hashingMocks.HasherMock{}, + &marshallerMock.MarshalizerMock{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &stateMock.StateAccessesCollectorStub{}, + ) + tdt.SetDataTrie(trie) + + _ = tdt.SaveKeyValue(firstKey, firstNewValue) + _ = tdt.SaveKeyValue(secondKey, []byte("new2")) + + _, _, err := tdt.SaveDirtyData(trie) + require.ErrorIs(t, err, expectedErr) + assert.True(t, rollbackCalled) +} + +func TestTrackableDataTrie_SaveDirtyDataShouldRollbackMigrationDeleteWhenMetadataBuildFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected marshal error") + key := []byte("key") + oldValue := []byte("old") + identifier := []byte("identifier") + hasher := &hashingMocks.HasherMock{} + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, + } + + deleteCalled := false + rollbackCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(getKey []byte) ([]byte, uint32, error) { + if bytes.Equal(getKey, hasher.Compute(string(key))) { + return nil, 0, nil + } + assert.Equal(t, key, getKey) + return append(oldValue, append(key, identifier...)...), 0, nil + }, + DeleteCalled: func(deleteKey []byte) error { + assert.Equal(t, key, deleteKey) + deleteCalled = true + return nil + }, + UpdateWithVersionCalled: func(updateKey, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, key, updateKey) + assert.Equal(t, append(oldValue, append(key, identifier...)...), value) + assert.Equal(t, core.NotSpecified, version) + rollbackCalled = true + return nil + }, + } + + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + identifier, + hasher, + &marshallerMock.MarshalizerStub{ + MarshalCalled: func(_ interface{}) ([]byte, error) { + return nil, expectedErr + }, + }, + enableEpochsHandler, + &stateMock.StateAccessesCollectorStub{}, + ) + tdt.SetDataTrie(trie) + + _ = tdt.SaveKeyValue(key, []byte("new")) + + _, _, err := tdt.SaveDirtyData(trie) + require.ErrorIs(t, err, expectedErr) + assert.True(t, deleteCalled) + assert.True(t, rollbackCalled) +} + +func TestTrackableDataTrie_SaveDirtyDataShouldRollbackMigrationDeleteAndNewKeyWhenPostUpdateFails(t *testing.T) { + t.Parallel() + + key := []byte("key") + oldValue := []byte("old") + newValue := []byte("new") + identifier := []byte("identifier") + hasher := &hashingMocks.HasherMock{} + hashedKey := hasher.Compute(string(key)) + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, + } + + updateCalls := 0 + rollbackHashedKeyCalled := false + rollbackOldKeyCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(getKey []byte) ([]byte, uint32, error) { + if bytes.Equal(getKey, hashedKey) { + return nil, 0, nil + } + assert.Equal(t, key, getKey) + return append(oldValue, append(key, identifier...)...), 0, nil + }, + DeleteCalled: func(deleteKey []byte) error { + assert.Equal(t, key, deleteKey) + return nil + }, + UpdateWithVersionCalled: func(updateKey, value []byte, version core.TrieNodeVersion) error { + updateCalls++ + switch updateCalls { + case 1: + assert.Equal(t, hashedKey, updateKey) + assert.Equal(t, core.AutoBalanceEnabled, version) + return nil + case 2: + assert.Equal(t, hashedKey, updateKey) + assert.Nil(t, value) + rollbackHashedKeyCalled = true + return nil + case 3: + assert.Equal(t, key, updateKey) + assert.Equal(t, append(oldValue, append(key, identifier...)...), value) + assert.Equal(t, core.NotSpecified, version) + rollbackOldKeyCalled = true + return nil + default: + require.Fail(t, "unexpected update call") + return nil + } + }, + } + + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + identifier, + hasher, + &marshallerMock.MarshalizerMock{}, + enableEpochsHandler, + &stateMock.StateAccessesCollectorStub{}, + ) + tdt.SetDataTrie(trie) + + _ = tdt.SaveKeyValue(key, newValue) + + // Force the post-update index check to fail after old-key delete and new-key update. + tdt.SetDirtyData(10, string(key), newValue, core.AutoBalanceEnabled) + + _, _, err := tdt.SaveDirtyData(trie) + require.ErrorContains(t, err, "index out of range") + assert.True(t, rollbackHashedKeyCalled) + assert.True(t, rollbackOldKeyCalled) +} + +func TestTrackableDataTrie_SaveDirtyDataShouldRollbackDeleteWhenMetadataDecodeFails(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected unmarshal error") + key := []byte("key") + hasher := &hashingMocks.HasherMock{} + hashedKey := hasher.Compute(string(key)) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AutoBalanceDataTriesFlag + }, + } + + serializedOldValue := []byte("invalid metadata") + deleteCalled := false + rollbackCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(getKey []byte) ([]byte, uint32, error) { + assert.Equal(t, hashedKey, getKey) + return serializedOldValue, 0, nil + }, + DeleteCalled: func(deleteKey []byte) error { + assert.Equal(t, hashedKey, deleteKey) + deleteCalled = true + return nil + }, + UpdateWithVersionCalled: func(updateKey, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, hashedKey, updateKey) + assert.Equal(t, serializedOldValue, value) + assert.Equal(t, core.AutoBalanceEnabled, version) + rollbackCalled = true + return nil + }, + } + + tdt, _ := trackableDataTrie.NewTrackableDataTrie( + []byte("identifier"), + hasher, + &marshallerMock.MarshalizerStub{ + UnmarshalCalled: func(_ interface{}, _ []byte) error { + return expectedErr + }, + }, + enableEpochsHandler, + &stateMock.StateAccessesCollectorStub{}, + ) + tdt.SetDataTrie(trie) + + _ = tdt.SaveKeyValue(key, nil) + + _, _, err := tdt.SaveDirtyData(trie) + require.ErrorIs(t, err, expectedErr) + assert.True(t, deleteCalled) + assert.True(t, rollbackCalled) +}