Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ common_files += [
'src/neural/register.cc',
'src/neural/shared_params.cc',
'src/neural/wrapper.cc',
'src/search/common/temperature.cc',
'src/search/classic/node.cc',
'src/syzygy/syzygy.cc',
'src/trainingdata/reader.cc',
Expand Down Expand Up @@ -225,6 +226,7 @@ files += [

files += [
'src/search/instamove/instamove.cc',
'src/search/policyhead/policyhead.cc',
]

includes += include_directories('src')
Expand Down Expand Up @@ -900,6 +902,11 @@ if get_option('gtest')
include_directories: includes, link_with: lc0_lib, dependencies: gtest
), args: '--gtest_output=xml:syzygy.xml', timeout: 90)

# Unit test binary built from src/search/common/temperature_test.cc and linked
# against lc0_lib and gtest. Results are written as gtest XML to
# temperature.xml; the test is given a 90 second timeout.
test('TemperatureTest',
executable('temperature_test', 'src/search/common/temperature_test.cc',
include_directories: includes, link_with: lc0_lib, dependencies: gtest
), args: '--gtest_output=xml:temperature.xml', timeout: 90)

test('EncodePositionForNN',
executable('encoder_test', 'src/neural/encoder_test.cc', pb_files,
include_directories: includes, link_with: lc0_lib,
Expand Down
33 changes: 12 additions & 21 deletions src/search/classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "utils/fastmath.h"
#include "utils/random.h"
#include "utils/spinhelper.h"
#include "search/common/temperature.h"

namespace lczero {
namespace classic {
Expand Down Expand Up @@ -688,27 +689,17 @@ void Search::EnsureBestMoveKnown() REQUIRES(nodes_mutex_)
if (root_node_->GetN() == 0) return;
if (!root_node_->HasChildren()) return;

float temperature = params_.GetTemperature();
const int cutoff_move = params_.GetTemperatureCutoffMove();
const int decay_delay_moves = params_.GetTempDecayDelayMoves();
const int decay_moves = params_.GetTempDecayMoves();
const int moves = played_history_.Last().GetGamePly() / 2;

if (cutoff_move && (moves + 1) >= cutoff_move) {
temperature = params_.GetTemperatureEndgame();
} else if (temperature && decay_moves) {
if (moves >= decay_delay_moves + decay_moves) {
temperature = 0.0;
} else if (moves >= decay_delay_moves) {
temperature *=
static_cast<float>(decay_delay_moves + decay_moves - moves) /
decay_moves;
}
// don't allow temperature to decay below endgame temperature
if (temperature < params_.GetTemperatureEndgame()) {
temperature = params_.GetTemperatureEndgame();
}
}
TemperatureParams tp{
.temperature = params_.GetTemperature(),
.temp_decay_moves = params_.GetTempDecayMoves(),
.temp_cutoff_move = params_.GetTemperatureCutoffMove(),
.temp_decay_delay_moves = params_.GetTempDecayDelayMoves(),
.temp_endgame = params_.GetTemperatureEndgame(),
.value_cutoff = params_.GetTemperatureWinpctCutoff(),
.visit_offset = params_.GetTemperatureVisitOffset(),
};
const int ply = played_history_.Last().GetGamePly();
const float temperature = EffectiveTau(tp, ply);

auto bestmove_edge = temperature
? GetBestRootChildWithTemperature(temperature)
Expand Down
76 changes: 76 additions & 0 deletions src/search/common/temperature.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include "search/common/temperature.h"

#include <algorithm>
#include <cmath>
#include <limits>

#include "utils/random.h"

namespace lczero {

// Computes the temperature to use at the given game ply.
//
// The number of full moves already played is ply / 2. Once the fullmove
// number (moves played + 1) reaches a positive temp_cutoff_move, the endgame
// temperature is used directly. Otherwise, when both the base temperature and
// temp_decay_moves are positive, the temperature shrinks linearly to zero
// over temp_decay_moves moves, starting after temp_decay_delay_moves, and is
// never allowed to drop below the endgame temperature. The returned value is
// never negative.
float EffectiveTau(const TemperatureParams& p, int ply) {
  const int moves_played = ply / 2;

  const bool past_cutoff =
      p.temp_cutoff_move > 0 && (moves_played + 1) >= p.temp_cutoff_move;
  float result = past_cutoff ? p.temp_endgame : p.temperature;

  if (!past_cutoff && result > 0.0 && p.temp_decay_moves > 0) {
    const int decay_start = p.temp_decay_delay_moves;
    const int decay_end = decay_start + p.temp_decay_moves;
    if (moves_played >= decay_end) {
      // The decay window is over.
      result = 0.0f;
    } else if (moves_played >= decay_start) {
      // Inside the decay window: scale by the fraction of it still remaining.
      const double remaining =
          static_cast<double>(decay_end - moves_played) /
          static_cast<double>(p.temp_decay_moves);
      result = static_cast<float>(result * remaining);
    }
    // Decay must not take the temperature below the endgame temperature.
    result = std::max(result, p.temp_endgame);
  }

  // Negative temperatures are reported as zero.
  return std::max(result, 0.0f);
}

// Samples an index into base_weights with probability proportional to
// weight^(1 / tau), where each weight is first filtered and shifted:
//   * if winprob has an entry for the candidate and that entry trails the
//     largest winprob by more than p.value_cutoff, the weight becomes zero;
//   * if p.visit_offset is non-zero it is subtracted from the weight and the
//     result is floored at zero.
//
// Parameters:
//   base_weights   - one weight per candidate; non-positive weights are never
//                    selected.
//   winprob        - per-candidate win probabilities indexed in parallel with
//                    base_weights; may be empty to skip the value cutoff.
//                    Candidates without a matching entry are not cut off.
//   p              - supplies value_cutoff and visit_offset.
//   tau            - sampling temperature; must be positive to sample.
//   rng            - random source; GetDouble() is called at most once.
//   fallback_index - returned when nothing can be sampled.
//
// Returns the sampled index, or fallback_index when tau is not positive or
// every weight ends up as zero.
int SampleWithTemperature(std::span<const double> base_weights,
                          std::span<const double> winprob,
                          const TemperatureParams& p,
                          float tau,
                          Random& rng,
                          int fallback_index) {
  // Without a positive temperature every weight would be zeroed anyway.
  if (!(tau > 0.0f)) return fallback_index;

  const size_t n = base_weights.size();

  double max_winprob = -std::numeric_limits<double>::infinity();
  for (const double w : winprob) max_winprob = std::max(max_winprob, w);

  // Pass 1: apply the value cutoff and visit offset, and remember the largest
  // surviving weight so the weights can be normalized before exponentiation.
  std::vector<double> weights(n);
  double max_weight = 0.0;
  for (size_t i = 0; i < n; ++i) {
    double w = base_weights[i];
    // The index check keeps a winprob span shorter than base_weights from
    // being read out of bounds.
    if (i < winprob.size() && max_winprob - winprob[i] > p.value_cutoff) {
      w = 0.0;
    }
    if (p.visit_offset != 0.0) {
      w = std::max(w - p.visit_offset, 0.0);
    }
    if (!(w > 0.0)) w = 0.0;
    weights[i] = w;
    max_weight = std::max(max_weight, w);
  }
  if (!(max_weight > 0.0)) return fallback_index;

  // Pass 2: raise the normalized weights to 1 / tau. Dividing by the largest
  // weight first keeps every term within [0, 1], so a small tau cannot
  // overflow the sum to infinity; the relative probabilities are unchanged
  // because every term is scaled by the same factor.
  const double inv_tau = 1.0 / static_cast<double>(tau);
  double sum = 0.0;
  for (double& w : weights) {
    if (w > 0.0) w = std::pow(w / max_weight, inv_tau);
    sum += w;
  }
  // Also rejects a NaN sum (e.g. produced by an infinite input weight).
  if (!(sum > 0.0)) return fallback_index;

  // Pick the first index whose cumulative weight exceeds the toss.
  const double toss = rng.GetDouble(sum);
  double cumulative = 0.0;
  for (size_t i = 0; i < n; ++i) {
    cumulative += weights[i];
    if (toss < cumulative) return static_cast<int>(i);
  }
  return fallback_index;
}

} // namespace lczero

37 changes: 37 additions & 0 deletions src/search/common/temperature.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <span>
#include <vector>

namespace lczero {

// Settings consumed by EffectiveTau() and SampleWithTemperature().
// Every field defaults to zero so that a default-initialized instance
// (`TemperatureParams p;`) is well defined. The struct remains an aggregate,
// so designated and positional initialization keep working.
struct TemperatureParams {
  // Base temperature before any decay or cutoff is applied.
  float temperature = 0.0f;
  // Number of moves over which the temperature decays linearly to zero;
  // decay is disabled when this is not positive.
  int temp_decay_moves = 0;
  // Fullmove number from which temp_endgame is used instead of the decayed
  // temperature; the cutoff is disabled when this is not positive.
  int temp_cutoff_move = 0;
  // Number of moves played before the linear decay starts.
  int temp_decay_delay_moves = 0;
  // Temperature used after the cutoff; also the floor for the decay.
  float temp_endgame = 0.0f;
  // Largest allowed gap between the best win probability and a candidate's;
  // candidates trailing by more than this are never sampled.
  float value_cutoff = 0.0f;
  // Amount subtracted from each base weight (floored at zero) before sampling.
  float visit_offset = 0.0f;
};

// Returns the effective temperature tau for the given game ply.
//
// The number of full moves played is ply / 2, and the fullmove number is
// (ply / 2) + 1.
//  * If temp_cutoff_move is positive and the fullmove number has reached it,
//    temp_endgame is used.
//  * Otherwise, if both temperature and temp_decay_moves are positive, the
//    temperature decays linearly to zero over temp_decay_moves moves, starting
//    once temp_decay_delay_moves moves have been played, and is never allowed
//    to fall below temp_endgame.
//  * Otherwise temperature is used unchanged.
// The result is clamped to [0, +inf).
float EffectiveTau(const TemperatureParams& p, int ply);

class Random; // Forward declaration from utils/random.h.

// Samples an index from base_weights using temperature tau: index i is chosen
// with probability proportional to w_i^(1 / tau), where w_i is base_weights[i]
// after the value cutoff and visit offset have been applied.
//  * Value cutoff: a candidate whose winprob trails the largest winprob by
//    more than p.value_cutoff gets zero weight. winprob is indexed in parallel
//    with base_weights and is expected to have one entry per base weight; it
//    may be empty to skip the value cutoff.
//  * Visit offset: a non-zero p.visit_offset is subtracted from each weight,
//    and the result is floored at zero.
// Returns fallback_index if tau is not positive or if all weights are
// filtered to zero.
int SampleWithTemperature(std::span<const double> base_weights,
std::span<const double> winprob,
const TemperatureParams& p,
float tau,
Random& rng,
int fallback_index);

} // namespace lczero

86 changes: 86 additions & 0 deletions src/search/common/temperature_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "search/common/temperature.h"

#include <gtest/gtest.h>

#include "utils/random.h"

namespace lczero {

// Checks EffectiveTau's linear decay and its switch to the endgame
// temperature once the cutoff move is reached.
TEST(TemperatureTest, EffectiveTauDecayAndCutoff) {
  // Start at tau = 1.0 and decay over 10 moves. The fields not named here
  // (cutoff, decay delay, endgame temperature, value cutoff, visit offset)
  // are zero-initialized.
  TemperatureParams params{.temperature = 1.0f, .temp_decay_moves = 10};

  // Ply 0 is fullmove 1: nothing has decayed yet.
  EXPECT_FLOAT_EQ(EffectiveTau(params, 0), 1.0f);
  // Ply 8 is fullmove 5 (4 moves played): 6 of the 10 decay moves remain.
  EXPECT_NEAR(EffectiveTau(params, 8), 0.6f, 1e-6f);
  // Ply 20 is fullmove 11 (10 moves played): the decay has finished.
  EXPECT_FLOAT_EQ(EffectiveTau(params, 20), 0.0f);

  // The cutoff fires when moves played + 1 >= 5, so ply 8 (4 moves played)
  // must now report the endgame temperature.
  params.temp_cutoff_move = 5;
  params.temp_endgame = 0.3f;
  EXPECT_FLOAT_EQ(EffectiveTau(params, 8), 0.3f);
}

// Checks that sampling frequencies follow weight^(1 / tau): lowering tau must
// shift probability mass towards the heavier candidate.
TEST(TemperatureTest, SampleProbabilityShift) {
  // All parameters are zero; tau is passed to the sampler directly.
  const TemperatureParams params{};
  const std::vector<double> weights{1.0, 4.0};
  const std::vector<double> win_probs{0.5, 0.5};

  // Draws 5000 samples at the given tau and returns how often index 1 won.
  const auto second_candidate_frequency = [&](float tau) {
    int hits = 0;
    for (int trial = 0; trial < 5000; ++trial) {
      const int picked = SampleWithTemperature(weights, win_probs, params, tau,
                                               Random::Get(), 0);
      if (picked == 1) ++hits;
    }
    return static_cast<double>(hits) / 5000.0;
  };

  // tau = 1: probabilities are proportional to the weights themselves (1:4).
  EXPECT_NEAR(second_candidate_frequency(1.0f), 4.0 / 5.0, 0.05);
  // tau = 0.5: the weights are squared (1:16).
  EXPECT_NEAR(second_candidate_frequency(0.5f), 16.0 / 17.0, 0.05);
}

// Checks that the value cutoff removes trailing candidates and that a visit
// offset larger than every weight makes the sampler return the fallback.
TEST(TemperatureTest, CutoffAndVisitOffset) {
  // Only the value cutoff is set; all other fields are zero-initialized.
  TemperatureParams params{.value_cutoff = 0.1f};

  // The second candidate trails the best win probability by 0.2 > 0.1, so
  // only index 0 can ever be sampled.
  const std::vector<double> equal_weights{1.0, 1.0};
  const std::vector<double> win_probs{1.0, 0.8};
  for (int attempt = 0; attempt < 10; ++attempt) {
    const int picked = SampleWithTemperature(equal_weights, win_probs, params,
                                             1.0f, Random::Get(), 0);
    EXPECT_EQ(0, picked);
  }

  // An offset of 2 wipes out the single weight of 1, leaving nothing to
  // sample; the fallback index (0) is expected. The value cutoff is skipped
  // because no win probabilities are supplied.
  params.visit_offset = 2.0f;
  const std::vector<double> single_weight{1.0};
  const int fallback = SampleWithTemperature(
      single_weight, std::span<const double>{}, params, 1.0f, Random::Get(), 0);
  EXPECT_EQ(0, fallback);
}

} // namespace lczero

// Test entry point: hands the command line to GoogleTest (so flags such as
// --gtest_output are honoured) and reports the overall test result as the
// process exit code.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}

33 changes: 12 additions & 21 deletions src/search/dag_classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "utils/fastmath.h"
#include "utils/random.h"
#include "utils/spinhelper.h"
#include "search/common/temperature.h"

namespace lczero {
namespace dag_classic {
Expand Down Expand Up @@ -691,27 +692,17 @@ void Search::EnsureBestMoveKnown() REQUIRES(nodes_mutex_)
if (root_node_->GetN() == 0) return;
if (!root_node_->HasChildren()) return;

float temperature = params_.GetTemperature();
const int cutoff_move = params_.GetTemperatureCutoffMove();
const int decay_delay_moves = params_.GetTempDecayDelayMoves();
const int decay_moves = params_.GetTempDecayMoves();
const int moves = played_history_.Last().GetGamePly() / 2;

if (cutoff_move && (moves + 1) >= cutoff_move) {
temperature = params_.GetTemperatureEndgame();
} else if (temperature && decay_moves) {
if (moves >= decay_delay_moves + decay_moves) {
temperature = 0.0;
} else if (moves >= decay_delay_moves) {
temperature *=
static_cast<float>(decay_delay_moves + decay_moves - moves) /
decay_moves;
}
// don't allow temperature to decay below endgame temperature
if (temperature < params_.GetTemperatureEndgame()) {
temperature = params_.GetTemperatureEndgame();
}
}
TemperatureParams tp{
.temperature = params_.GetTemperature(),
.temp_decay_moves = params_.GetTempDecayMoves(),
.temp_cutoff_move = params_.GetTemperatureCutoffMove(),
.temp_decay_delay_moves = params_.GetTempDecayDelayMoves(),
.temp_endgame = params_.GetTemperatureEndgame(),
.value_cutoff = params_.GetTemperatureWinpctCutoff(),
.visit_offset = params_.GetTemperatureVisitOffset(),
};
const int ply = played_history_.Last().GetGamePly();
const float temperature = EffectiveTau(tp, ply);

auto bestmove_edge = temperature
? GetBestRootChildWithTemperature(temperature)
Expand Down
Loading