diff --git a/src/coinjoin/coinjoin.cpp b/src/coinjoin/coinjoin.cpp index 9c5d99ea3f81..33fbc43d6b0e 100644 --- a/src/coinjoin/coinjoin.cpp +++ b/src/coinjoin/coinjoin.cpp @@ -301,21 +301,17 @@ bool CoinJoin::IsCollateralValid(ChainstateManager& chainman, const llmq::CInsta } } + LOCK(::cs_main); + CCoinsViewMemPool viewMemPool(&chainman.ActiveChainstate().CoinsTip(), mempool); + for (const auto& txin : txCollateral.vin) { Coin coin; - auto mempoolTx = mempool.get(txin.prevout.hash); - if (mempoolTx != nullptr) { - if (mempool.isSpent(txin.prevout) || !isman.IsLocked(txin.prevout.hash)) { - LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- spent or non-locked mempool input! txin=%s\n", txin.ToString()); - return false; - } - nValueIn += mempoolTx->vout[txin.prevout.n].nValue; - } else if (GetUTXOCoin(chainman.ActiveChainstate(), txin.prevout, coin)) { - nValueIn += coin.out.nValue; - } else { - LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- Unknown inputs in collateral transaction, txCollateral=%s", txCollateral.ToString()); /* Continued */ + if (!viewMemPool.GetCoin(txin.prevout, coin) || coin.IsSpent() || + (coin.nHeight == MEMPOOL_HEIGHT && !isman.IsLocked(txin.prevout.hash))) { + LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- missing, spent or non-locked mempool input! txin=%s\n", txin.ToString()); return false; } + nValueIn += coin.out.nValue; } //collateral transactions are required to pay out a small fee to the miners @@ -326,12 +322,9 @@ bool CoinJoin::IsCollateralValid(ChainstateManager& chainman, const llmq::CInsta LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- %s", txCollateral.ToString()); /* Continued */ - { - LOCK(::cs_main); - if (!ATMPIfSaneFee(chainman, MakeTransactionRef(txCollateral), /*test_accept=*/true)) { - LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- didn't pass ATMPIfSaneFee()\n"); - return false; - } + if (!ATMPIfSaneFee(chainman, MakeTransactionRef(txCollateral), /*test_accept=*/true)) { + LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- didn't pass ATMPIfSaneFee()\n"); + return false; } return true; diff --git a/src/llmq/dkgsession.cpp b/src/llmq/dkgsession.cpp index fba7e8b38d39..db1cb05311c4 100644 --- a/src/llmq/dkgsession.cpp +++ b/src/llmq/dkgsession.cpp @@ -410,21 +410,20 @@ bool CDKGSession::PreVerifyMessage(const CDKGJustification& qj, bool& retBan) co } std::set contributionsSet; - for (const auto& p : qj.contributions) { - if (p.index > members.size()) { + for (const auto& [index, skContribution] : qj.contributions) { + if (GetMemberAtIndex(index) == nullptr) { logger.Batch("invalid contribution index"); retBan = true; return false; } - if (!contributionsSet.emplace(p.index).second) { + if (!contributionsSet.emplace(index).second) { logger.Batch("duplicate contribution index"); retBan = true; return false; } - const auto& skShare = p.key; - if (!skShare.IsValid()) { + if (!skContribution.IsValid()) { logger.Batch("invalid contribution"); retBan = true; return false; @@ -482,8 +481,9 @@ std::optional CDKGSession::ReceiveMessage(const CDKGJustification& qj) return inv; } - for (const auto& p : qj.contributions) { - const auto& member2 = members[p.index]; + for (const auto& [index, skContribution] : qj.contributions) { + const auto* member2 = GetMemberAtIndex(index); + assert(member2); if (member->complaintsFromOthers.count(member2->dmn->proTxHash) == 0) { logger.Batch("got justification from %s for %s even though he didn't complain", @@ -499,14 +499,16 @@ std::optional CDKGSession::ReceiveMessage(const CDKGJustification& qj) std::list> futures; for (const auto& [index, skContribution] : qj.contributions) { - const auto& member2 = members[index]; + const auto* member2 = GetMemberAtIndex(index); + assert(member2); // watch out to not bail out before these async calls finish (they rely on valid references) futures.emplace_back(blsWorker.AsyncVerifyContributionShare(member2->id, receivedVvecs[member->idx], skContribution)); } auto resultIt = futures.begin(); for (const auto& [index, skContribution] : qj.contributions) { - const auto& member2 = members[index]; + const auto* member2 = GetMemberAtIndex(index); + assert(member2); bool result = (resultIt++)->get(); if (!result) { @@ -683,6 +685,12 @@ CDKGMember* CDKGSession::GetMember(const uint256& proTxHash) const return members[it->second].get(); } +CDKGMember* CDKGSession::GetMemberAtIndex(size_t index) const +{ + if (index >= members.size()) return nullptr; + return members[index].get(); +} + void CDKGSession::MarkBadMember(size_t idx) { auto* member = members.at(idx).get(); diff --git a/src/llmq/dkgsession.h b/src/llmq/dkgsession.h index f9a4ce756e92..95529b2bdde6 100644 --- a/src/llmq/dkgsession.h +++ b/src/llmq/dkgsession.h @@ -385,6 +385,7 @@ class CDKGSession public: [[nodiscard]] bool AreWeMember() const { return !myProTxHash.IsNull(); } [[nodiscard]] CDKGMember* GetMember(const uint256& proTxHash) const; + [[nodiscard]] CDKGMember* GetMemberAtIndex(size_t index) const; [[nodiscard]] std::optional GetMyMemberIndex() const { return myIdx; } [[nodiscard]] const Uint256HashSet& RelayMembers() const { return relayMembers; } [[nodiscard]] const CBlockIndex* BlockIndex() const { return m_quorum_base_block_index; } diff --git a/src/serialize.h b/src/serialize.h index 027f9d87c434..2e119d340081 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -459,10 +460,10 @@ void ReadFixedBitSet(Stream& s, std::vector& vec, size_t size) template void WriteFixedVarIntsBitSet(Stream& s, const std::vector& vec, size_t size) { - int32_t last = -1; - for (int32_t i = 0; i < (int32_t)vec.size(); i++) { + std::optional last; + for (size_t i = 0; i < vec.size(); i++) { if (vec[i]) { - WriteVarInt(s, (uint32_t)(i - last)); + WriteVarInt(s, static_cast(last ? (i - *last) : (i + 1))); last = i; } } @@ -474,17 +475,17 @@ void ReadFixedVarIntsBitSet(Stream& s, std::vector& vec, size_t size) { vec.assign(size, false); - int32_t last = -1; + std::optional last; while(true) { uint32_t offset = ReadVarInt(s); if (offset == 0) { break; } - int32_t idx = last + offset; - if (idx >= int32_t(size)) { + size_t idx = last ? (*last + offset) : (static_cast(offset) - 1); + if (idx >= size) { throw std::ios_base::failure("out of bounds index"); } - if (last != -1 && idx <= last) { + if (last.has_value() && idx <= *last) { throw std::ios_base::failure("offset overflow"); } vec[idx] = true;