forked from Fossana/cplusplus-cfr-poker-solver
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrainer.cpp
More file actions
66 lines (54 loc) · 1.8 KB
/
Trainer.cpp
File metadata and controls
66 lines (54 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "Trainer.h"
#include <iostream>
#include "card_utility.h"
#include "ChanceNodeTypeEnum.h"
#include "TerminalNodeTypeEnum.h"
#include "CfrTask.h"
#include <chrono>
#include <cstring>
#include <thread>
#include "tbb/task.h"
#include <functional>
#include <utility>
using tbb::task;
using std::memset;
using std::thread;
using std::cout;
using chronoClock = std::chrono::system_clock;
using sec = std::chrono::duration<double>;
using std::make_unique;
using std::vector;
/// @brief Stores the configuration a training run needs.
/// @param rangeManager     Shared range manager. Taken by value and moved into the member.
/// @param initialBoard     Board cards. Array parameters decay to pointers, so the caller must
///                         pass storage holding at least 5 values; exactly 5 are copied.
/// @param initialPot       Starting pot, stored as-is.
/// @param inPositionPlayer Player index stored for later use (e.g. by BestResponse in train()).
Trainer::Trainer(shared_ptr<RangeManager> rangeManager, uint8_t initialBoard[5], int initialPot, int inPositionPlayer)
{
    // The sink parameter already owns a reference; moving it avoids an extra
    // atomic refcount increment/decrement that a copy would cost.
    this->rangeManager = std::move(rangeManager);

    // Copy the board by value so the Trainer does not alias the caller's buffer.
    for (int i = 0; i < 5; i++)
    {
        this->initialBoard[i] = initialBoard[i];
    }

    this->initialPot = initialPot;
    this->inPositionPlayer = inPositionPlayer;
}
/// @brief Runs `numIterations` rounds of CFR on the tree rooted at `root`.
///
/// Each round calls cfr() twice, once with (hero=1, villain=2) and once with
/// (hero=2, villain=1). Exploitability is printed before training, every 25
/// rounds, and after the last round, together with the accumulated time spent
/// inside the cfr() calls.
///
/// @param root          Root of the game tree to train on (not owned).
/// @param numIterations Number of rounds to run; values <= 0 run none.
void Trainer::train(Node* root, int numIterations)
{
    // The BestResponse object is (re)created per training run and kept in the member `br`.
    br = make_unique<BestResponse>(rangeManager, root, initialBoard, initialPot, inPositionPlayer);
    br->print_exploitability();
    cout << '\n';

    // Accumulate only CFR time. steady_clock is monotonic, so interval measurements
    // are not disturbed by wall-clock adjustments, and the exploitability reports
    // below are deliberately excluded from the reported duration.
    sec cfrTime{0};
    for (int i = 1; i <= numIterations; i++)
    {
        const auto iterationStart = std::chrono::steady_clock::now();
        cfr(1, 2, root, i);
        cfr(2, 1, root, i);
        cfrTime += std::chrono::steady_clock::now() - iterationStart;

        // Report periodically, and always on the last round so the final
        // strategy's exploitability is shown even when numIterations % 25 != 0.
        if (i % 25 == 0 || i == numIterations)
        {
            br->print_exploitability();
            cout << i << " cfr iterations took: " << cfrTime.count() << "s\n";
            cout << '\n';
        }
    }
}
/// @brief Runs one CFR pass for `hero` over the tree at `root`, via a TBB root task.
/// @param hero           Player whose values are computed in this pass.
/// @param villain        Opponent; its initial reach probabilities seed the traversal.
/// @param root           Root of the game tree (not owned).
/// @param iterationCount Current iteration number, forwarded to CfrTask.
/// @return The vector CfrTask writes through its output pointer
///         (presumably per-hand values for `hero` — confirm against CfrTask).
vector<float> Trainer::cfr(int hero, int villain, Node *root, int iterationCount)
{
    vector<float> opponentReach = rangeManager->get_initial_reach_probs(villain);
    vector<float> output;

    // CfrTask receives raw pointers to both locals. That is safe only because
    // spawn_root_and_wait blocks until the whole task tree has finished.
    // In the legacy tbb::task API the scheduler reclaims the root task after it
    // completes, so there is no matching delete here.
    // NOTE(review): tbb/task.h is the deprecated pre-oneTBB API and was removed in
    // oneTBB; migrating would mean porting CfrTask (e.g. to tbb::task_group).
    CfrTask* rootTask = new (task::allocate_root()) CfrTask(rangeManager, &output, root, hero, villain, &opponentReach, initialBoard, iterationCount);
    task::spawn_root_and_wait(*rootTask);

    return output;
}