From 667efd9154111b6fd4469703fe540a57519b7b80 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 15:50:36 +0100 Subject: [PATCH] Introduce default parameters for some appropriate functions and merge some signatures --- cpp/command/genbook.cpp | 4 +- cpp/command/selfplay.cpp | 4 +- cpp/core/config_parser.cpp | 127 ++++---------------------------- cpp/core/config_parser.h | 30 +++----- cpp/dataio/trainingwrite.cpp | 32 ++++---- cpp/dataio/trainingwrite.h | 4 +- cpp/game/board.cpp | 4 - cpp/game/board.h | 7 +- cpp/game/boardhistory.cpp | 4 - cpp/game/boardhistory.h | 3 +- cpp/search/asyncbot.cpp | 2 +- cpp/search/search.cpp | 5 +- cpp/search/search.h | 15 +--- cpp/tests/testboardbasic.cpp | 6 +- cpp/tests/testsymmetries.cpp | 4 +- cpp/tests/testtrainingwrite.cpp | 43 +++++------ 16 files changed, 80 insertions(+), 214 deletions(-) diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index 7d44ff201..bbe9fc435 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -525,8 +525,8 @@ int MainCmds::genbook(const vector& args) { if(!bonusInitialBoard.isEqualForTesting(book->getInitialHist().getRecentBoard(0), false, false)) throw StringError( "Book initial board and initial board in bonus sgf file do not match\n" + - Board::toStringSimple(book->getInitialHist().getRecentBoard(0),'\n') + "\n" + - Board::toStringSimple(bonusInitialBoard,'\n') + Board::toStringSimple(book->getInitialHist().getRecentBoard(0)) + "\n" + + Board::toStringSimple(bonusInitialBoard) ); if(bonusInitialPla != book->initialPla) throw StringError( diff --git a/cpp/command/selfplay.cpp b/cpp/command/selfplay.cpp index b03fb8c7d..8c712a037 100644 --- a/cpp/command/selfplay.cpp +++ b/cpp/command/selfplay.cpp @@ -212,8 +212,8 @@ int MainCmds::selfplay(const vector& args) { //Note that this inputsVersion passed here is NOT necessarily the same as the one used in the neural net self play, it //simply controls the input feature version for the written data - TrainingDataWriter* tdataWriter = new TrainingDataWriter( - tdataOutputDir, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLen, dataBoardLen, Global::uint64ToHexString(rand.nextUInt64())); + auto tdataWriter = new TrainingDataWriter( + tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLen, dataBoardLen, Global::uint64ToHexString(rand.nextUInt64())); ofstream* sgfOut = NULL; if(sgfOutputDir.length() > 0) { sgfOut = new ofstream(); diff --git a/cpp/core/config_parser.cpp b/cpp/core/config_parser.cpp index e78388b63..8acfd15e1 100644 --- a/cpp/core/config_parser.cpp +++ b/cpp/core/config_parser.cpp @@ -2,7 +2,6 @@ #include "../core/fileutils.h" -#include #include #include @@ -113,8 +112,6 @@ void ConfigParser::processIncludedFile(const std::string &fname) { baseDirs.pop_back(); } - - bool ConfigParser::parseKeyValue(const std::string& trimmedLine, std::string& key, std::string& value) { // Parse trimmed line, taking into account comments and quoting. key.clear(); @@ -598,14 +595,7 @@ enabled_t ConfigParser::getEnabled(const string& key) { return x; } -int ConfigParser::getInt(const string& key) { - string value = getString(key); - int x; - if(!Global::tryStringToInt(value,x)) - throw IOError("Could not parse '" + value + "' as int for key '" + key + "' in config file " + fileName); - return x; -} -int ConfigParser::getInt(const string& key, int min, int max) { +int ConfigParser::getInt(const string& key, const int min, const int max) { assert(min <= max); string value = getString(key); int x; @@ -615,19 +605,8 @@ int ConfigParser::getInt(const string& key, int min, int max) { throw IOError("Key '" + key + "' must be in the range " + Global::intToString(min) + " to " + Global::intToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getInts(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getInts(const string& key, int min, int max) { + +vector ConfigParser::getInts(const string& key, const int min, const int max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i> ConfigParser::getNonNegativeIntDashedPairs(const stri return ret; } - -int64_t ConfigParser::getInt64(const string& key) { - string value = getString(key); - int64_t x; - if(!Global::tryStringToInt64(value,x)) - throw IOError("Could not parse '" + value + "' as int64_t for key '" + key + "' in config file " + fileName); - return x; -} -int64_t ConfigParser::getInt64(const string& key, int64_t min, int64_t max) { +int64_t ConfigParser::getInt64(const string& key, const int64_t min, const int64_t max) { assert(min <= max); string value = getString(key); int64_t x; @@ -688,19 +659,8 @@ int64_t ConfigParser::getInt64(const string& key, int64_t min, int64_t max) { throw IOError("Key '" + key + "' must be in the range " + Global::int64ToString(min) + " to " + Global::int64ToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getInt64s(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getInt64s(const string& key, int64_t min, int64_t max) { + +vector ConfigParser::getInt64s(const string& key, const int64_t min, const int64_t max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getInt64s(const string& key, int64_t min, int64_t return ret; } - -uint64_t ConfigParser::getUInt64(const string& key) { - string value = getString(key); - uint64_t x; - if(!Global::tryStringToUInt64(value,x)) - throw IOError("Could not parse '" + value + "' as uint64_t for key '" + key + "' in config file " + fileName); - return x; -} -uint64_t ConfigParser::getUInt64(const string& key, uint64_t min, uint64_t max) { +uint64_t ConfigParser::getUInt64(const string& key, const uint64_t min, const uint64_t max) { assert(min <= max); string value = getString(key); uint64_t x; @@ -733,19 +685,8 @@ uint64_t ConfigParser::getUInt64(const string& key, uint64_t min, uint64_t max) throw IOError("Key '" + key + "' must be in the range " + Global::uint64ToString(min) + " to " + Global::uint64ToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getUInt64s(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getUInt64s(const string& key, uint64_t min, uint64_t max) { + +vector ConfigParser::getUInt64s(const string& key, const uint64_t min, const uint64_t max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getUInt64s(const string& key, uint64_t min, uint6 return ret; } - -float ConfigParser::getFloat(const string& key) { - string value = getString(key); - float x; - if(!Global::tryStringToFloat(value,x)) - throw IOError("Could not parse '" + value + "' as float for key '" + key + "' in config file " + fileName); - return x; -} -float ConfigParser::getFloat(const string& key, float min, float max) { +float ConfigParser::getFloat(const string& key, const float min, const float max) { assert(min <= max); string value = getString(key); float x; @@ -780,19 +713,8 @@ float ConfigParser::getFloat(const string& key, float min, float max) { throw IOError("Key '" + key + "' must be in the range " + Global::floatToString(min) + " to " + Global::floatToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getFloats(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getFloats(const string& key, float min, float max) { + +vector ConfigParser::getFloats(const string& key, const float min, const float max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getFloats(const string& key, float min, float max) { return ret; } - -double ConfigParser::getDouble(const string& key) { - string value = getString(key); - double x; - if(!Global::tryStringToDouble(value,x)) - throw IOError("Could not parse '" + value + "' as double for key '" + key + "' in config file " + fileName); - return x; -} -double ConfigParser::getDouble(const string& key, double min, double max) { +double ConfigParser::getDouble(const string& key, const double min, const double max) { assert(min <= max); string value = getString(key); double x; @@ -829,19 +743,8 @@ double ConfigParser::getDouble(const string& key, double min, double max) { throw IOError("Key '" + key + "' must be in the range " + Global::doubleToString(min) + " to " + Global::doubleToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getDoubles(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getDoubles(const string& key, double min, double max) { + +vector ConfigParser::getDoubles(const string& key, const double min, const double max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i& possibles); - int getInt(const std::string& key, int min, int max); - int64_t getInt64(const std::string& key, int64_t min, int64_t max); - uint64_t getUInt64(const std::string& key, uint64_t min, uint64_t max); - float getFloat(const std::string& key, float min, float max); - double getDouble(const std::string& key, double min, double max); + int getInt(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); + int64_t getInt64(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); + uint64_t getUInt64(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); + float getFloat(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); + double getDouble(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); std::vector getStrings(const std::string& key); std::vector getStringsNonEmptyTrim(const std::string& key); std::vector getBools(const std::string& key); - std::vector getInts(const std::string& key); - std::vector getInt64s(const std::string& key); - std::vector getUInt64s(const std::string& key); - std::vector getFloats(const std::string& key); - std::vector getDoubles(const std::string& key); std::vector getStrings(const std::string& key, const std::set& possibles); - std::vector getInts(const std::string& key, int min, int max); - std::vector getInt64s(const std::string& key, int64_t min, int64_t max); - std::vector getUInt64s(const std::string& key, uint64_t min, uint64_t max); - std::vector getFloats(const std::string& key, float min, float max); - std::vector getDoubles(const std::string& key, double min, double max); + std::vector getInts(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); + std::vector getInt64s(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); + std::vector getUInt64s(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); + std::vector getFloats(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); + std::vector getDoubles(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); std::vector> getNonNegativeIntDashedPairs(const std::string& key, int min, int max); diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index a8b03bd25..2808ed4ad 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -954,48 +954,48 @@ void TrainingWriteBuffers::writeToTextOstream(ostream& out) { //------------------------------------------------------------------------------------- -TrainingDataWriter::TrainingDataWriter(const string& outDir, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const string& randSeed) - : TrainingDataWriter(outDir,NULL,iVersion,maxRowsPerFile,firstFileMinRandProp,dataXLen,dataYLen,1,randSeed) -{} -TrainingDataWriter::TrainingDataWriter(ostream* dbgOut, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyEvery, const string& randSeed) - : TrainingDataWriter(string(),dbgOut,iVersion,maxRowsPerFile,firstFileMinRandProp,dataXLen,dataYLen,onlyEvery,randSeed) -{} - -TrainingDataWriter::TrainingDataWriter(const string& outDir, ostream* dbgOut, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyEvery, const string& randSeed) - :outputDir(outDir),inputsVersion(iVersion),rand(randSeed),writeBuffers(NULL),debugOut(dbgOut),debugOnlyWriteEvery(onlyEvery),rowCount(0) +TrainingDataWriter::TrainingDataWriter(const string& outDir, ostream* dbgOut, + const int iVersion, + const int maxRowsPerFile, + const double firstFileMinRandProp, + const int dataXLen, + const int dataYLen, + const string& randSeed, + const int onlyWriteEvery) + :outputDir(outDir),inputsVersion(iVersion),rand(randSeed),writeBuffers(nullptr),debugOut(dbgOut),debugOnlyWriteEvery(onlyWriteEvery),rowCount(0) { int numBinaryChannels; int numGlobalChannels; //Note that this inputsVersion is for data writing, it might be different than the inputsVersion used //to feed into a model during selfplay static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) { + if(iVersion == 3) { numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V3; numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V3; } - else if(inputsVersion == 4) { + else if(iVersion == 4) { numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V4; numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V4; } - else if(inputsVersion == 5) { + else if(iVersion == 5) { numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V5; numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V5; } - else if(inputsVersion == 6) { + else if(iVersion == 6) { numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V6; numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V6; } - else if(inputsVersion == 7) { + else if(iVersion == 7) { numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V7; numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V7; } else { - throw StringError("TrainingDataWriter: Unsupported inputs version: " + Global::intToString(inputsVersion)); + throw StringError("TrainingDataWriter: Unsupported inputs version: " + Global::intToString(iVersion)); } const bool hasMetadataInput = false; writeBuffers = new TrainingWriteBuffers( - inputsVersion, + iVersion, maxRowsPerFile, numBinaryChannels, numGlobalChannels, diff --git a/cpp/dataio/trainingwrite.h b/cpp/dataio/trainingwrite.h index 253c13258..0f309ced1 100644 --- a/cpp/dataio/trainingwrite.h +++ b/cpp/dataio/trainingwrite.h @@ -311,9 +311,7 @@ struct TrainingWriteBuffers { class TrainingDataWriter { public: - TrainingDataWriter(const std::string& outputDir, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed); - TrainingDataWriter(std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyWriteEvery, const std::string& randSeed); - TrainingDataWriter(const std::string& outputDir, std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyWriteEvery, const std::string& randSeed); + TrainingDataWriter(const std::string& outDir, std::ostream* dbgOut, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed, int onlyWriteEvery = 1); ~TrainingDataWriter(); void writeGame(const FinishedGameData& data); diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 5ad4686f4..44965beb7 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -2688,10 +2688,6 @@ string Board::toStringSimple(const Board& board, char lineDelimiter) { return s; } -Board Board::parseBoard(int xSize, int ySize, const string& s) { - return parseBoard(xSize,ySize,s,'\n'); -} - Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimiter) { Board board(xSize,ySize); vector lines = Global::split(Global::trim(s),lineDelimiter); diff --git a/cpp/game/board.h b/cpp/game/board.h index 4fdb2a259..78423521f 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -297,12 +297,11 @@ struct Board void checkConsistency() const; //For the moment, only used in testing since it does extra consistency checks. //If we need a version to be used in "prod", we could make an efficient version maybe as operator==. - bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const; + bool isEqualForTesting(const Board& other, bool checkNumCaptures = true, bool checkSimpleKo = true) const; - static Board parseBoard(int xSize, int ySize, const std::string& s); - static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter); + static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter = '\n'); static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist); - static std::string toStringSimple(const Board& board, char lineDelimiter); + static std::string toStringSimple(const Board& board, char lineDelimiter = '\n'); static nlohmann::json toJson(const Board& board); static Board ofJson(const nlohmann::json& data); diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 5b3c35a9b..22d39cefd 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -914,10 +914,6 @@ bool BoardHistory::makeBoardMoveTolerant(Board& board, Loc moveLoc, Player moveP return true; } -void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable) { - makeBoardMoveAssumeLegal(board,moveLoc,movePla,rootKoHashTable,false); -} - void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore) { Hash128 posHashBeforeMove = board.pos_hash; diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index 9a6e87323..a5479802f 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -158,8 +158,7 @@ struct BoardHistory { //even if the move violates superko or encore ko recapture prohibitions, or is past when the game is ended. //This allows for robustness when this code is being used for analysis or with external data sources. //preventEncore artifically prevents any move from entering or advancing the encore phase when using territory scoring. - void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable); - void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore); + void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore = false); //Make a move with legality checking, but be mostly tolerant and allow moves that can still be handled but that may not technically //be legal. This is intended for reading moves from SGFs and such where maybe we're getting moves that were played in a different //ruleset than ours. Returns true if successful, false if was illegal even unter tolerant rules. diff --git a/cpp/search/asyncbot.cpp b/cpp/search/asyncbot.cpp index 10e6cea37..148fd8cc9 100644 --- a/cpp/search/asyncbot.cpp +++ b/cpp/search/asyncbot.cpp @@ -54,7 +54,7 @@ AsyncBot::AsyncBot( analyzeCallback(), searchBegunCallback() { - search = new Search(params,nnEval,humanEval,l,randSeed); + search = new Search(params,nnEval,l,randSeed,humanEval); searchThread = std::thread(searchThreadLoop,this,l); } diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index e2713ae76..a02bba241 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -65,10 +65,7 @@ SearchThread::~SearchThread() { static const double VALUE_WEIGHT_DEGREES_OF_FREEDOM = 3.0; -Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed) - :Search(params,nnEval,NULL,lg,rSeed) -{} -Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed) +Search::Search(const SearchParams& params, NNEvaluator* nnEval, Logger* lg, const string& rSeed, NNEvaluator* humanEval) :rootPla(P_BLACK), rootBoard(), rootHistory(), diff --git a/cpp/search/search.h b/cpp/search/search.h index fe19c3f81..303d29df9 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -182,18 +182,11 @@ struct Search { //Note - randSeed controls a few things in the search, but a lot of the randomness actually comes from //random symmetries of the neural net evaluations, see nneval.h Search( - SearchParams params, + const SearchParams ¶ms, NNEvaluator* nnEval, - Logger* logger, - const std::string& randSeed - ); - Search( - SearchParams params, - NNEvaluator* nnEval, - NNEvaluator* humanEval, - Logger* logger, - const std::string& randSeed - ); + Logger* lg, + const std::string& rSeed, + NNEvaluator* humanEval = nullptr); ~Search(); Search(const Search&) = delete; diff --git a/cpp/tests/testboardbasic.cpp b/cpp/tests/testboardbasic.cpp index fa62b94e5..b44a5289a 100644 --- a/cpp/tests/testboardbasic.cpp +++ b/cpp/tests/testboardbasic.cpp @@ -2507,9 +2507,9 @@ oxxxxx.xo //if(rep < 100) // hist.printDebugInfo(cout,board); - testAssert(boardCopy.isEqualForTesting(board, true, true)); - testAssert(boardCopy.isEqualForTesting(histCopy.getRecentBoard(0), true, true)); - testAssert(histCopy.getRecentBoard(0).isEqualForTesting(hist.getRecentBoard(0), true, true)); + testAssert(boardCopy.isEqualForTesting(board)); + testAssert(boardCopy.isEqualForTesting(histCopy.getRecentBoard(0))); + testAssert(histCopy.getRecentBoard(0).isEqualForTesting(hist.getRecentBoard(0))); testAssert(BoardHistory::getSituationRulesAndKoHash(boardCopy,histCopy,pla,drawEquivalentWinsForWhite) == hist.getSituationRulesAndKoHash(board,hist,pla,drawEquivalentWinsForWhite)); testAssert(histCopy.currentSelfKomi(P_BLACK, drawEquivalentWinsForWhite) == hist.currentSelfKomi(P_BLACK, drawEquivalentWinsForWhite)); testAssert(histCopy.currentSelfKomi(P_WHITE, drawEquivalentWinsForWhite) == hist.currentSelfKomi(P_WHITE, drawEquivalentWinsForWhite)); diff --git a/cpp/tests/testsymmetries.cpp b/cpp/tests/testsymmetries.cpp index 5703ccf35..f686f9269 100644 --- a/cpp/tests/testsymmetries.cpp +++ b/cpp/tests/testsymmetries.cpp @@ -413,7 +413,7 @@ x.xxo.... Loc symLocComb = SymmetryHelpers::getSymLoc(loc,board,symmetryComposed); Loc symLocCombManual = SymmetryHelpers::getSymLoc(SymmetryHelpers::getSymLoc(loc,board,symmetry1),SymmetryHelpers::getSymBoard(board,symmetry1),symmetry2); out << "Symmetry " << symmetry1 << " + " << symmetry2 << " = " << symmetryComposed << endl; - testAssert(symBoardCombManual.isEqualForTesting(symBoardComb,true,true)); + testAssert(symBoardCombManual.isEqualForTesting(symBoardComb)); testAssert(symLocComb == symLocCombManual); } } @@ -588,7 +588,7 @@ x.xxo.... out << "SYMMETRY " << symmetry << endl; out << boardA << endl; out << boardB << endl; - testAssert(boardA.isEqualForTesting(boardB,true,true)); + testAssert(boardA.isEqualForTesting(boardB)); } string expected = R"%%( SYMMETRY 0 diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index 7c9e3292f..9c04fc5de 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -9,6 +9,15 @@ using namespace std; using namespace TestCommon; +static TrainingDataWriter createTestTrainingDataWriter( + const int inputVersion, + const int nnXLen, + const int nnYLen, + const string& seed, + const int onlyWriteEvery) { + return TrainingDataWriter(string(), &cout, inputVersion, 256, 1.0f, nnXLen, nnYLen, seed, onlyWriteEvery); +} + static NNEvaluator* startNNEval( const string& modelFile, const string& seed, Logger& logger, int defaultSymmetry, bool inputsUseNHWC, bool useNHWC, bool useFP16 @@ -66,13 +75,9 @@ void Tests::runTrainingWriteTests() { cout << "Running training write tests" << endl; NeuralNet::globalInitialize(); - int maxRows = 256; - double firstFileMinRandProp = 1.0; - int debugOnlyWriteEvery = 5; - - const bool logToStdout = true; - const bool logToStderr = false; - const bool logTime = false; + constexpr bool logToStdout = true; + constexpr bool logToStderr = false; + constexpr bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); auto run = [&]( @@ -82,7 +87,7 @@ void Tests::runTrainingWriteTests() { int boardXLen, int boardYLen, bool cheapLongSgf ) { - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, nnXLen, nnYLen, debugOnlyWriteEvery, seedBase+"dwriter"); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, nnXLen, nnYLen, seedBase + "dwriter", 5); NNEvaluator* nnEval = startNNEval("/dev/null",seedBase+"nneval",logger,0,inputsNHWC,useNHWC,false); @@ -1056,8 +1061,6 @@ xxxxxxxx. cout << "====================================================================================================" << endl; cout << "Testing turnnumber and early temperatures" << endl; - int maxRows = 256; - double firstFileMinRandProp = 1.0; int debugOnlyWriteEvery = 1; int inputsVersion = 7; @@ -1142,7 +1145,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board(9,9); @@ -1173,7 +1176,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board(9,9); @@ -1203,9 +1206,6 @@ xxxxxxxx. cout << "====================================================================================================" << endl; cout << "Testing no result" << endl; - int maxRows = 256; - double firstFileMinRandProp = 1.0; - int debugOnlyWriteEvery = 1; int inputsVersion = 7; SearchParams params = SearchParams::forTestsV2(); @@ -1272,9 +1272,9 @@ xxxxxxxx. botSpec.nnEval = nnEval; botSpec.baseParams = params; - string seed = "seed-testing-temperature"; - { + string seed = "seed-testing-temperature"; + int debugOnlyWriteEvery = 1; cout << "Turn number initial 0 selfplay with high temperatures" << endl; nnEval->clearCache(); nnEval->clearStats(); @@ -1284,7 +1284,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board::parseBoard(9,9,R"%%( @@ -2732,9 +2732,6 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { cout << "Running test for how a seki gets recorded" << endl; NeuralNet::globalInitialize(); - int nnXLen = 13; - int nnYLen = 13; - const bool logToStdout = true; const bool logToStderr = false; const bool logTime = false; @@ -2744,10 +2741,8 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { auto run = [&](const string& sgfStr, const string& seedBase, const Rules& rules) { int inputsVersion = 6; - int maxRows = 256; - double firstFileMinRandProp = 1.0; int debugOnlyWriteEvery = 1000; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, nnXLen, nnYLen, debugOnlyWriteEvery, seedBase+"dwriter"); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 13, 13, seedBase + "dwriter", debugOnlyWriteEvery); nnEval->clearCache(); nnEval->clearStats();