Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions src/coinjoin/coinjoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,17 @@ bool CoinJoin::IsCollateralValid(ChainstateManager& chainman, const llmq::CInsta
}
}

LOCK(::cs_main);
CCoinsViewMemPool viewMemPool(&chainman.ActiveChainstate().CoinsTip(), mempool);

for (const auto& txin : txCollateral.vin) {
Coin coin;
auto mempoolTx = mempool.get(txin.prevout.hash);
if (mempoolTx != nullptr) {
if (mempool.isSpent(txin.prevout) || !isman.IsLocked(txin.prevout.hash)) {
LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- spent or non-locked mempool input! txin=%s\n", txin.ToString());
return false;
}
nValueIn += mempoolTx->vout[txin.prevout.n].nValue;
} else if (GetUTXOCoin(chainman.ActiveChainstate(), txin.prevout, coin)) {
nValueIn += coin.out.nValue;
} else {
LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- Unknown inputs in collateral transaction, txCollateral=%s", txCollateral.ToString()); /* Continued */
if (!viewMemPool.GetCoin(txin.prevout, coin) || coin.IsSpent() ||
(coin.nHeight == MEMPOOL_HEIGHT && !isman.IsLocked(txin.prevout.hash))) {
LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- missing, spent or non-locked mempool input! txin=%s\n", txin.ToString());
return false;
}
nValueIn += coin.out.nValue;
}

//collateral transactions are required to pay out a small fee to the miners
Expand All @@ -326,12 +322,9 @@ bool CoinJoin::IsCollateralValid(ChainstateManager& chainman, const llmq::CInsta

LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- %s", txCollateral.ToString()); /* Continued */

{
LOCK(::cs_main);
if (!ATMPIfSaneFee(chainman, MakeTransactionRef(txCollateral), /*test_accept=*/true)) {
LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- didn't pass ATMPIfSaneFee()\n");
return false;
}
if (!ATMPIfSaneFee(chainman, MakeTransactionRef(txCollateral), /*test_accept=*/true)) {
LogPrint(BCLog::COINJOIN, "CoinJoin::IsCollateralValid -- didn't pass ATMPIfSaneFee()\n");
return false;
}

return true;
Expand Down
26 changes: 17 additions & 9 deletions src/llmq/dkgsession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,21 +410,20 @@ bool CDKGSession::PreVerifyMessage(const CDKGJustification& qj, bool& retBan) co
}

std::set<size_t> contributionsSet;
for (const auto& p : qj.contributions) {
if (p.index > members.size()) {
for (const auto& [index, skContribution] : qj.contributions) {
if (GetMemberAtIndex(index) == nullptr) {
logger.Batch("invalid contribution index");
retBan = true;
return false;
}

if (!contributionsSet.emplace(p.index).second) {
if (!contributionsSet.emplace(index).second) {
logger.Batch("duplicate contribution index");
retBan = true;
return false;
}

const auto& skShare = p.key;
if (!skShare.IsValid()) {
if (!skContribution.IsValid()) {
logger.Batch("invalid contribution");
retBan = true;
return false;
Expand Down Expand Up @@ -482,8 +481,9 @@ std::optional<CInv> CDKGSession::ReceiveMessage(const CDKGJustification& qj)
return inv;
}

for (const auto& p : qj.contributions) {
const auto& member2 = members[p.index];
for (const auto& [index, skContribution] : qj.contributions) {
const auto* member2 = GetMemberAtIndex(index);
assert(member2);

if (member->complaintsFromOthers.count(member2->dmn->proTxHash) == 0) {
logger.Batch("got justification from %s for %s even though he didn't complain",
Expand All @@ -499,14 +499,16 @@ std::optional<CInv> CDKGSession::ReceiveMessage(const CDKGJustification& qj)

std::list<std::future<bool>> futures;
for (const auto& [index, skContribution] : qj.contributions) {
const auto& member2 = members[index];
const auto* member2 = GetMemberAtIndex(index);
assert(member2);

// watch out to not bail out before these async calls finish (they rely on valid references)
futures.emplace_back(blsWorker.AsyncVerifyContributionShare(member2->id, receivedVvecs[member->idx], skContribution));
}
auto resultIt = futures.begin();
for (const auto& [index, skContribution] : qj.contributions) {
const auto& member2 = members[index];
const auto* member2 = GetMemberAtIndex(index);
assert(member2);

bool result = (resultIt++)->get();
if (!result) {
Expand Down Expand Up @@ -683,6 +685,12 @@ CDKGMember* CDKGSession::GetMember(const uint256& proTxHash) const
return members[it->second].get();
}

CDKGMember* CDKGSession::GetMemberAtIndex(size_t index) const
{
if (index >= members.size()) return nullptr;
return members[index].get();
}

void CDKGSession::MarkBadMember(size_t idx)
{
auto* member = members.at(idx).get();
Expand Down
1 change: 1 addition & 0 deletions src/llmq/dkgsession.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ class CDKGSession
public:
[[nodiscard]] bool AreWeMember() const { return !myProTxHash.IsNull(); }
[[nodiscard]] CDKGMember* GetMember(const uint256& proTxHash) const;
[[nodiscard]] CDKGMember* GetMemberAtIndex(size_t index) const;
[[nodiscard]] std::optional<size_t> GetMyMemberIndex() const { return myIdx; }
[[nodiscard]] const Uint256HashSet& RelayMembers() const { return relayMembers; }
[[nodiscard]] const CBlockIndex* BlockIndex() const { return m_quorum_base_block_index; }
Expand Down
15 changes: 8 additions & 7 deletions src/serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <list>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <string.h>
Expand Down Expand Up @@ -459,10 +460,10 @@ void ReadFixedBitSet(Stream& s, std::vector<bool>& vec, size_t size)
template<typename Stream>
void WriteFixedVarIntsBitSet(Stream& s, const std::vector<bool>& vec, size_t size)
{
int32_t last = -1;
for (int32_t i = 0; i < (int32_t)vec.size(); i++) {
std::optional<size_t> last;
for (size_t i = 0; i < vec.size(); i++) {
if (vec[i]) {
WriteVarInt<Stream, VarIntMode::DEFAULT, uint32_t>(s, (uint32_t)(i - last));
WriteVarInt<Stream, VarIntMode::DEFAULT, uint32_t>(s, static_cast<uint32_t>(last ? (i - *last) : (i + 1)));
last = i;
}
}
Expand All @@ -474,17 +475,17 @@ void ReadFixedVarIntsBitSet(Stream& s, std::vector<bool>& vec, size_t size)
{
vec.assign(size, false);

int32_t last = -1;
std::optional<size_t> last;
while(true) {
uint32_t offset = ReadVarInt<Stream, VarIntMode::DEFAULT, uint32_t>(s);
if (offset == 0) {
break;
}
int32_t idx = last + offset;
if (idx >= int32_t(size)) {
size_t idx = last ? (*last + offset) : (static_cast<size_t>(offset) - 1);
if (idx >= size) {
throw std::ios_base::failure("out of bounds index");
}
if (last != -1 && idx <= last) {
if (last.has_value() && idx <= *last) {
throw std::ios_base::failure("offset overflow");
}
vec[idx] = true;
Expand Down
Loading