Day 23: zobrist hash, still wrong collisions (works only without hash)

This commit is contained in:
2022-03-30 21:39:16 +02:00
parent 4fa4a5d366
commit ea429a99a7

View File

@@ -37,28 +37,33 @@ static const u64 cost[] = {
#define ROOM_C 0xF0000 /* 0x11110000000000000000*/
#define ROOM_D 0xF00000 /* 0x111100000000000000000000*/
#define ROOM 0xFFFF00
#define BIT(c) (1 << (c))
#define RAND_SEED 1337 /* seed for random generator */
static u32 rooms[4] = { ROOM_A, ROOM_B, ROOM_C, ROOM_D };
static u64 result = -1;
typedef struct pos {
u32 amp[4]; /* bitboards */
int moves;
u64 cost;
u32 occupied;
u32 available; /* ALLOWED & ~occupied */
u32 final; /* 1 if final destination */
int ok; /* amphipods in correct place */
u64 zobrist; /* for zobrist_2() */
struct {
u32 from, to;
} move_list[64];
struct list_head list;
} pos_t;
typedef struct hash {
u32 amp[4]; /* bitboards */
u64 cost;
//u32 count; /* collisions */
u64 zobrist; /* zobrist hash */
u32 amp[4];
struct list_head list;
} hash_t;
#define HASH_SIZE 4096
#define HASH_SIZE 131071
struct {
u32 count;
struct list_head list;
@@ -74,7 +79,50 @@ pool_t *pool_hash;
LIST_HEAD(pos_queue);
static void init_hash()
static u32 zobrist_table[24][4];
static void zobrist_init()
{
log_f(1, "zobrist init. RAND_MAX=%d seed=%d\n", RAND_MAX, RAND_SEED);
srand(RAND_SEED);
for (int i = 0; i < 24; ++i) {
for (int j = 0; j < 4; ++j) {
zobrist_table[i][j] = rand();
log(10, "%d ", zobrist_table[i][j]);
}
}
}
static inline u64 zobrist_1(pos_t *pos)
{
u32 tmp;
int bit;
u64 zobrist = 0;
for (int amp = 0; amp < 4; ++amp) {
bit_for_each32_2(bit, tmp, pos->amp[amp]) {
//log_f(2, "amp=%d/%c bit=%d\n", amp, amp+'A', bit);
zobrist ^= zobrist_table[bit][amp];
}
}
log_f(1, "zobrist=%lu -> %lu\n", zobrist, zobrist % HASH_SIZE);
return zobrist;
}
/* calculate zobrist hash from previous zobrist value
*/
static inline u64 zobrist_2(pos_t *pos, int amp, u32 from, u32 to)
{
u64 zobrist = pos->zobrist;
zobrist ^= zobrist_table[from][amp];
zobrist ^= zobrist_table[to][amp];
log_f(1, "zobrist=%lu -> %lu (amp=%d from=%u to=%u)\n",
zobrist, zobrist % HASH_SIZE, amp, from, to);
return zobrist;
}
static void hash_init()
{
for (int i = 0; i < HASH_SIZE; ++i) {
hasht[i].count = 0;
@@ -91,7 +139,6 @@ static hash_t *get_hash(pos_t *pos)
return NULL;
for (int i = 0; i < 4; ++i)
new->amp[i] = pos->amp[i];
new->cost = pos->cost;
INIT_LIST_HEAD(&new->list);
return new;
}
@@ -112,51 +159,43 @@ static void hash_stats()
min, max, ncollisions, ncollisions / HASH_SIZE);
}
/* hash function from:
* http://www.isthe.com/chongo/tech/comp/fnv/
*/
#define FNV_offset_basis 2166136261
#define FNV_prime 16777619
static hash_t *hash(pos_t *pos)
static u64 hash(pos_t *pos, int amp, u32 from, u32 to)
{
hash_t *cur;
u32 val = FNV_offset_basis;
u64 zobrist;
u32 val;
/* we use the 2 first chars of the amps 32, this should be enough
* to avoid most collisions
*/
for (int i = 0; i < 4; ++i) {
uchar *input = (uchar *) &pos->amp[i];
for (int c = 0; c < 2; ++c) {
val ^= input[c];
val *= FNV_prime;
}
}
//val ^= pos->cost;
log_f(1, "hash=%u", val);
val %= HASH_SIZE;
log_f(1, " ->%u, count=%d\n", val, hasht[val].count);
zobrist = zobrist_2(pos, amp, from, to);
val = zobrist % HASH_SIZE;
log_f(1, "zobrist=%lu->%u, count=%d\n", zobrist, val, hasht[val].count);
list_for_each_entry(cur, &hasht[val].list, list) {
if (pos->amp[0] == cur->amp[0] &&
pos->amp[1] == cur->amp[1] &&
pos->amp[2] == cur->amp[2] &&
pos->amp[3] == cur->amp[3]) {
if (pos->cost >= cur->cost) {
log(1, "collision, worse cost.\n");
return NULL;
} else {
log(1, "collision, better cost.\n");
cur->cost = pos->cost; /* adjust better solution */
return cur;
if (zobrist == cur->zobrist) {
u32 amp_tmp[4];
for (int i = 0; i < 4; ++i) {
amp_tmp[i] = pos->amp[i];
}
amp_tmp[amp] ^= BIT(from);
amp_tmp[amp] |= BIT(to);
if (amp_tmp[0] == cur->amp[0] &&
amp_tmp[1] == cur->amp[1] &&
amp_tmp[2] == cur->amp[2] &&
amp_tmp[3] == cur->amp[3])
return 0;
log(1, "zobrist collision for different positions\n");
}
}
hasht[val].count++;
log(1, "adding hash count=%u\n", hasht[val].count);
cur = get_hash(pos);
cur->zobrist = zobrist;
cur->amp[amp] ^= BIT(from);
cur->amp[amp] |= BIT(to);
list_add(&cur->list, &hasht[val].list);
return cur;
return zobrist;
}
/*
@@ -201,8 +240,6 @@ static s32 room_exit[4][2][6] = {
}
};
#define BIT(c) (1 << (c))
static char *int2bin(u32 mask, char *ret)
{
for (int i = 0; i < 32; ++i) {
@@ -305,7 +342,7 @@ static pos_t *get_pos(pos_t *from)
if (!from) {
new->amp[0] = new->amp[1] = new->amp[2] = new->amp[3] = 0;
new->moves = new->ok = 0;
new->cost = new->occupied = new->available = 0;
new->cost = new->occupied = 0;
} else {
*new = *from;
}
@@ -501,23 +538,30 @@ static pos_t *newmove(pos_t *pos, amphipod_t amp, u32 from, u32 to)
int rows = popcount32(pos->amp[0]);
move_t *move = &moves[from][to];
pos_t *newpos;
hash_t *collision;
u64 collision;
log_f(1, "rows=%d amp=%c from=%s to=%s dist=%u ok=%d cost=%lu\n",
rows, amp + 'A',
cells[from], cells[to],
move->dist, pos->ok, move->dist * cost[amp]);
collision = hash(pos);
if (!collision) {
log(1, "collision, skipping move :\n");
burrow_print(pos);
// free_pos(newpos);
return NULL;
if (pos->ok < 0) {
collision = hash(pos, amp, from, to);
if (!collision) {
log(1, "collision, skipping move :\n");
pos->amp[amp] ^= BIT(from);
pos->amp[amp] |= BIT(to);
burrow_print(pos);
pos->amp[amp] ^= BIT(to);
pos->amp[amp] |= BIT(from);
return NULL;
}
}
if (!(newpos = get_pos(pos)))
return NULL;
newpos->move_list[newpos->moves].from = from;
newpos->move_list[newpos->moves].to = to;
newpos->amp[amp] ^= BIT(from);
newpos->amp[amp] |= BIT(to);
newpos->occupied ^= BIT(from);
@@ -525,7 +569,6 @@ static pos_t *newmove(pos_t *pos, amphipod_t amp, u32 from, u32 to)
newpos->moves++;
newpos->cost += move->dist * cost[amp];
if (to >= A1) { /* final destination */
newpos->ok++;
newpos->final |= BIT(to);
@@ -535,6 +578,15 @@ static pos_t *newmove(pos_t *pos, amphipod_t amp, u32 from, u32 to)
log(1, "found solution! cost=%lu\n", newpos->cost);
burrow_print(newpos);
free_pos(newpos);
if (newpos->cost < result) {
result = newpos->cost;
log(1, "New best=%lu moves=%u List:", result, newpos->moves);
for (int i = 0; i < newpos->moves; ++i) {
log(1, " %s-%s", cells[newpos->move_list[i].from],
cells[newpos->move_list[i].to]);
}
log(1, "\n");
}
return NULL;
}
@@ -661,7 +713,6 @@ static pos_t *read_input(int part)
bit = 8 + adjline - 2;
for (int i = 0; i < 4; ++i) {
int amp = buf[i * 2 + 3] - 'A';
//printf("bit = %lu char = %c\n", bit, amp + 'A');
pos->amp[amp] |= BIT(bit);
log(3, "setting bit %d to %c\n", bit, amp + 'A');
bit += 4;
@@ -671,7 +722,6 @@ static pos_t *read_input(int part)
adjline++;
}
pos->occupied = get_occupancy(pos);
pos->available = ALLOWED & ~pos->occupied;
free(buf);
return pos;
}
@@ -716,11 +766,14 @@ int main(int ac, char **av)
pool_pos = pool_create("pos", 1024, sizeof(pos_t));
pool_hash = pool_create("hash", 1024, sizeof(hash_t));
init_hash();
zobrist_init();
hash_init();
init_moves();
print_moves(moves);
pos = read_input(part);
zobrist_1(pos);
push_pos(pos);
/*
for (int i = 0; i < 4; ++ i) {