Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,15 @@ if get_option('dag_classic')
'src/search/dag_classic/search.cc',
'src/search/dag_classic/wrapper.cc',
]

## ~~~~~~~~~~
## Transposition table
## ~~~~~~~~~~
fix_tt = get_option('fixed_tt')
if fix_tt
# Change transposition table to a fixed size one.
add_project_arguments('-DFIX_TT', language : 'cpp')
endif
endif

#############################################################################
Expand Down
6 changes: 6 additions & 0 deletions meson_options.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,9 @@ option('dag_classic',
type: 'boolean',
value: true,
description: 'Enable dag-classic search algorithm')

option('fixed_tt',
type: 'boolean',
value: true,
description: 'Build dag using fixed size transposition table.')

5 changes: 5 additions & 0 deletions src/search/dag_classic/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "chess/gamestate.h"
#include "chess/position.h"
#include "neural/backend.h"
#include "utils/cache.h"
#include "utils/mutex.h"

namespace lczero {
Expand Down Expand Up @@ -955,8 +956,12 @@ inline VisitedNode_Iterator<false> Node::VisitedNodes() {
}

// Transposition Table type for holding references to all low nodes in DAG.
#ifndef FIX_TT
typedef absl::flat_hash_map<uint64_t, std::weak_ptr<LowNode>>
TranspositionTable;
#else
typedef HashKeyedCache<std::weak_ptr<LowNode>> TranspositionTable;
#endif

class NodeTree {
public:
Expand Down
46 changes: 46 additions & 0 deletions src/search/dag_classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,14 @@ Search::Search(const NodeTree& tree, Backend* backend,
searchmoves_, syzygy_tb_, played_history_,
params_.GetSyzygyFastPlay(), &tb_hits_, &root_is_in_dtz_)),
uci_responder_(std::move(uci_responder)) {
#ifndef FIX_TT
// Evict expired entries from the transposition table.
// Garbage collection may lead to expiration at any time so this is not
// enough to prevent expired entries later during the search.
absl::erase_if(*tt_, [](const auto& item) { return item.second.expired(); });

LOGFILE << "Transposition table garbage collection done.";
#endif

if (params_.GetMaxConcurrentSearchers() != 0) {
pending_searchers_.store(params_.GetMaxConcurrentSearchers(),
Expand Down Expand Up @@ -332,6 +334,9 @@ void Search::SendUciInfo(const classic::IterationStats& stats)
}
}
common_info.tb_hits = tb_hits_.load(std::memory_order_acquire);
#ifdef FIX_TT
common_info.hashfull = tt_->GetSize() * 1000.0f / tt_->GetCapacity();
#endif

int multipv = 0;
const auto default_q = -root_node_->GetQ(-draw_score);
Expand Down Expand Up @@ -2085,13 +2090,23 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node) {
// Check the transposition table first and NN cache second before asking for
// NN evaluation.
picked_node.hash = history.HashLast(params_.GetCacheHistoryLength() + 1);
#ifndef FIX_TT
auto tt_iter = search_->tt_->find(picked_node.hash);
// Transposition table entry might be expired.
if (tt_iter != search_->tt_->end()) {
picked_node.tt_low_node = tt_iter->second.lock();
}
#else
auto entry = search_->tt_->LookupAndPin(picked_node.hash);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could insert nodes here if insertions avoid random slowdowns of the dynamic version. For the testing purposes it is better to delay insert. That makes it easier to compare how the fixed implementation compares to the dynamic implementation.

if (entry) {
picked_node.tt_low_node = entry->lock();
search_->tt_->Unpin(picked_node.hash, entry);
}
#endif
if (picked_node.tt_low_node) {
#ifndef FIX_TT
assert(!tt_iter->second.expired());
#endif
picked_node.is_tt_hit = true;
} else {
picked_node.tt_low_node = std::make_shared<LowNode>(legal_moves);
Expand Down Expand Up @@ -2238,18 +2253,49 @@ void SearchWorker::DoBackupUpdateSingleNode(
auto path = node_to_process.path;

if (node_to_process.nn_queried) {
#ifdef FIX_TT
auto entry = search_->tt_->LookupAndPin(node_to_process.hash);
if (!entry) {
bool insert_ok = search_->tt_->Insert(
node_to_process.hash, std::make_unique<std::weak_ptr<LowNode>>(
node_to_process.tt_low_node));
if (!insert_ok) {
// The insert may fail if another thread added the same hash.
// In the unlikely case it fails, the search will still work OK.
entry = search_->tt_->LookupAndPin(node_to_process.hash);
}
}
bool is_tt_miss = !entry;
#else
auto [tt_iter, is_tt_miss] = search_->tt_->try_emplace(
node_to_process.hash, node_to_process.tt_low_node);
#endif
if (is_tt_miss) {
#ifndef FIX_TT
assert(!tt_iter->second.expired());
#endif
node_to_process.node->SetLowNode(node_to_process.tt_low_node);
} else {
#ifdef FIX_TT
auto tt_low_node = entry->lock();
#else
auto tt_low_node = tt_iter->second.lock();
#endif
if (!tt_low_node) {
#ifdef FIX_TT
// An insert would fail, so update the (expired) entry directly.
*entry = node_to_process.tt_low_node;
search_->tt_->Unpin(node_to_process.hash, entry);
#else
tt_iter->second = node_to_process.tt_low_node;
#endif
node_to_process.node->SetLowNode(node_to_process.tt_low_node);
} else {
#ifdef FIX_TT
search_->tt_->Unpin(node_to_process.hash, entry);
#else
assert(!tt_iter->second.expired());
#endif
node_to_process.node->SetLowNode(tt_low_node);
}
}
Expand Down
28 changes: 27 additions & 1 deletion src/search/dag_classic/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ const OptionId kClearTree{
.help_text = "Clear the tree before the next search.",
.visibility = OptionId::kProOnly}};

#ifdef FIX_TT
const OptionId kHashId{{.long_flag = "hash",
.uci_option = "Hash",
.help_text = "Size of the transposition table in MB.",
.visibility = OptionId::kAlwaysVisible}};
#endif

class DagClassicSearch : public SearchBase {
public:
DagClassicSearch(UciResponder* responder, const OptionsDict* options)
Expand Down Expand Up @@ -102,7 +109,11 @@ MoveList StringsToMovelist(const std::vector<std::string>& moves,
void DagClassicSearch::NewGame() {
LOGFILE << "New game.";
search_.reset();
#ifndef FIX_TT
tt_.clear();
#else
tt_.Clear();
#endif
tree_.reset();
time_manager_ = classic::MakeTimeManager(*options_);
}
Expand All @@ -112,6 +123,11 @@ void DagClassicSearch::SetPosition(const GameState& pos) {
const bool is_same_game = tree_->ResetToPosition(pos);
LOGFILE << "Tree reset to a new position.";
if (!is_same_game) time_manager_ = classic::MakeTimeManager(*options_);
#ifdef FIX_TT
// Transposition table size.
tt_.SetCapacity(options_->Get<int>(kHashId) * 1000000 /
tt_.GetItemStructSize());
#endif
}

void DagClassicSearch::StartSearch(const GoParams& params) {
Expand All @@ -133,7 +149,11 @@ void DagClassicSearch::StartSearch(const GoParams& params) {
sizeof(float[classic::MemoryWatchingStopper::kAvgMovesPerPosition]);
size_t total_memory =
tree_.get()->GetCurrentHead()->GetN() * kAvgNodeSize +
#ifdef FIX_TT
tt_.GetCapacity() * tt_.GetItemStructSize() +
#else
(sizeof(TranspositionTable::value_type) + 1) * tt_.bucket_count() +
#endif
cache_size * kAvgCacheItemSize;
auto stopper = time_manager_->GetStopper(
params, tree_.get()->HeadPosition(), total_memory, kAvgNodeSize,
Expand All @@ -143,7 +163,10 @@ void DagClassicSearch::StartSearch(const GoParams& params) {
StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()),
*move_start_time_, std::move(stopper), params.infinite, params.ponder,
*options_, &tt_, syzygy_tb_);

#ifdef FIX_TT
LOGFILE << "Transposition table load factor is "
<< tt_.GetSize() / static_cast<float>(tt_.GetCapacity());
#endif
LOGFILE << "Timer started at "
<< FormatTime(SteadyClockToSystemClock(*move_start_time_));
search_->StartThreads(options_->Get<int>(kThreadsOptionId));
Expand All @@ -158,6 +181,9 @@ class DagClassicSearchFactory : public SearchFactory {

void PopulateParams(OptionsParser* parser) const override {
parser->Add<IntOption>(kThreadsOptionId, 0, 128) = 0;
#ifdef FIX_TT
parser->Add<IntOption>(kHashId, 0, 2000) = 50;
#endif
SearchParams::Populate(parser);
classic::PopulateTimeManagementOptions(classic::RunType::kUci, parser);

Expand Down
8 changes: 5 additions & 3 deletions src/utils/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class HashKeyedCache {

// Inserts the element under key @key with value @val. Unless the key is
// already in the cache.
void Insert(uint64_t key, std::unique_ptr<V> val) {
if (capacity_.load(std::memory_order_relaxed) == 0) return;
// Returns false if the hash is found in the cache, blocking insertion.
bool Insert(uint64_t key, std::unique_ptr<V> val) {
if (capacity_.load(std::memory_order_relaxed) == 0) return true;

SpinMutex::Lock lock(mutex_);

Expand All @@ -73,7 +74,7 @@ class HashKeyedCache {
if (!hash_[idx].in_use) break;
if (hash_[idx].key == key) {
// Already exists.
return;
return false;
}
++idx;
if (idx >= hash_.size()) idx -= hash_.size();
Expand All @@ -87,6 +88,7 @@ class HashKeyedCache {
++allocated_;

EvictToCapacity(capacity_);
return true;
}

// Checks whether a key exists. Doesn't pin. Of course the next moment the
Expand Down