Added optimal first guess and better code readability

This commit is contained in:
Matúš Púll 2024-12-30 23:45:31 +01:00
parent 213209243e
commit c55dbee834

View file

@ -12,28 +12,28 @@ void Solver::generate_set(vector<int> carry) {
carry.pop_back(); carry.pop_back();
} }
} }
Solver::Solver(int _N, int _M) : N(_N), M(_M) { Solver::Solver(int _N, int _M) : N(_N), M(_M) {
generate_set({}); generate_set({});
} }
vector<int> Solver::guess() { vector<int> Solver::guess() {
// First pick // Optimal first pick is always the same (at least for 5x8)
if(first_pick) { if(first_pick) {
first_pick = false; first_pick = false;
vector<int> pick(0); vector<int> pick = {0};
int times = N / M + 1; int times = (N-1) / M + 1;
for(int i = 0; i < times; i++) for(int i = 0; i < times; i++)
for(int j = 0; j < M && pick.size() < N; j++) for(int j = 0; j < M && pick.size() < N; j++)
pick.push_back(j); pick.push_back(j);
return pick; return pick;
} }
return choose_possible().guess; return choose_possible().guess;
} }
void Solver::learn(vector<int> guess, Response response) { void Solver::learn(vector<int> guess, Response response) {
// Eliminating impossible sequences
set<vector<int>> next_possible; set<vector<int>> next_possible;
for(auto sequence : possible) for(auto sequence : possible)
if(validate(sequence, guess) == response) if(validate(sequence, guess) == response)
@ -43,24 +43,27 @@ void Solver::learn(vector<int> guess, Response response) {
} }
int Solver::get_weight(vector<int> guess) { int Solver::get_weight(vector<int> guess) {
// Indexing by N*somewhere + correct and holding how many sequences got that response // Bucketing possible sequences by responses
vector<int> response_count(N*N+1, 0); vector<int> response_count(N*N+N+1, 0);
for(auto sequence : possible) { for(auto sequence : possible) {
Response response = validate(sequence, guess); Response response = validate(sequence, guess);
response_count[N*response.somewhere + response.correct]++; response_count[(N+1)*response.somewhere + response.correct]++;
} }
// Get highest possible number of sequences left // Get size of the fullest bucket
int max = 0; int max = 0;
for(int count : response_count) for(int count : response_count)
if(count > max) if(count > max)
max = count; max = count;
return max - possible.count(guess); // Possible guesses have higher priority
return 2*max - possible.count(guess);
} }
// Choosing next guess
Weighed_guess Solver::minimax(vector<int> carry) { Weighed_guess Solver::minimax(vector<int> carry) {
if(carry.size() == N) if(carry.size() == N)
return {get_weight(carry), carry}; return {get_weight(carry), carry};
// Pick the best of next picks
Weighed_guess best = {-1, {}}; Weighed_guess best = {-1, {}};
for(int col = 0; col < M; col++) { for(int col = 0; col < M; col++) {
carry.push_back(col); carry.push_back(col);