Need help selecting neural network design for Snake game AI

I’m just starting out with AI and reinforcement learning. I want to build an agent that can learn to play Snake, but I’m stuck on what network structure to use.

My game state is a 10x10x3 array of 0s and 1s: the first channel marks the snake body, the second marks the food, and the third marks the snake head.
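For concreteness, here’s roughly how I build that array (the snake and food positions below are just example values):

```python
import numpy as np

# Minimal sketch of the state encoding; positions are made-up example values.
snake_body = [(5, 5), (5, 4), (5, 3)]   # head first, then body segments
food_pos = (2, 7)

state = np.zeros((10, 10, 3), dtype=np.float32)
for r, c in snake_body:
    state[r, c, 0] = 1.0                            # channel 0: body cells
state[food_pos[0], food_pos[1], 1] = 1.0            # channel 1: food
state[snake_body[0][0], snake_body[0][1], 2] = 1.0  # channel 2: head
```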

The output needs to be 4 numbers representing the possible actions the snake can take (move up, down, left, or right).

Can anyone suggest what kind of network architecture would work well for this setup? Also wondering if my way of representing the game state makes sense or if there’s a better approach. Thanks for any advice!

Your state representation looks solid for starting out. I’d skip convolutions for now and just use a basic feedforward network. Feed the 300 flattened inputs through something like 256 → 128 → 64 → 4 layers with ReLU. Here’s what I learned building my Snake AI: architecture doesn’t matter as much as stable training. Use experience replay with around 10k transitions and update your target network every few hundred steps. One trick that really helped: add the Manhattan distance to the food as an extra input alongside your grid. It’s way faster than making the network learn that spatial relationship from scratch.
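To make that concrete, here’s a rough PyTorch sketch of the kind of network, replay buffer, and target network I mean. The layer sizes, buffer size, and the extra distance feature are my suggestions, not anything canonical:

```python
import random
from collections import deque

import torch
import torch.nn as nn

# Feedforward Q-network: 300 flattened grid cells plus 1 Manhattan-distance
# feature in, one Q-value per action out (256 -> 128 -> 64 -> 4 as above).
class SnakeMLP(nn.Module):
    def __init__(self, grid_inputs=300, extra_features=1, num_actions=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(grid_inputs + extra_features, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, num_actions),
        )

    def forward(self, grid, distance):
        # grid: (batch, 10, 10, 3) binary state; distance: (batch, 1)
        x = torch.cat([grid.flatten(start_dim=1), distance], dim=1)
        return self.net(x)

# Experience replay: keep the last ~10k transitions, sample random batches.
replay_buffer = deque(maxlen=10_000)

def sample_batch(batch_size=64):
    return random.sample(replay_buffer, batch_size)

# Target network: a frozen copy of the online net, re-synced every few
# hundred steps by copying the weights over.
online_net = SnakeMLP()
target_net = SnakeMLP()
target_net.load_state_dict(online_net.state_dict())
```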

I worked on a Snake AI last year, and a simple fully connected network worked great. Your three-channel state setup is solid; it gives the network clean info separation, which helps learning. For architecture, flatten your 10x10x3 input to 300 neurons, then use two hidden layers of 128 and 64 neurons, ending in a 4-unit output layer for the actions. ReLU works fine for the hidden layers, and consider dropout around 0.2 if you’re overfitting. Here’s what I learned the hard way: reward structure beats architecture every time. Give small positive rewards for each survival step and for moving toward the food, rather than only rewarding eating, and add a tiny penalty for moving away from the food to encourage more direct paths. Your state representation is way better than my starting point; I tried raw pixels first, which was needlessly complex.
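As a rough illustration of that reward shaping (the exact magnitudes here are made-up starting points; tune them for your game):

```python
def compute_reward(ate_food, died, old_dist, new_dist):
    """old_dist / new_dist: Manhattan distance from head to food,
    before and after the move."""
    if died:
        return -1.0          # terminal penalty for hitting a wall or the body
    if ate_food:
        return 1.0           # the main objective
    reward = 0.01            # small bonus for surviving another step
    if new_dist < old_dist:
        reward += 0.005      # nudge toward the food
    else:
        reward -= 0.005      # tiny penalty for moving away
    return reward
```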

A CNN’s probably overkill, but it’ll work with your grid. Try a Conv2d layer (16 filters, 3x3 kernel), then flatten and add dense layers. Your state representation looks solid, though; way cleaner than stuffing everything into one channel like I did. Just double-check you’re using the right loss function for Q-learning: a regression loss like MSE or Huber on the TD targets, not a classification loss.
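Something like this PyTorch sketch is what I have in mind (layer sizes are assumptions, and note PyTorch wants channels first, so the 10x10x3 grid becomes 3x10x10):

```python
import torch
import torch.nn as nn

class SnakeCNN(nn.Module):
    def __init__(self, num_actions=4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 16 filters, 3x3 kernel
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 10 * 10, 64),  # padding=1 keeps the 10x10 spatial size
            nn.ReLU(),
            nn.Linear(64, num_actions),   # one Q-value per action
        )

    def forward(self, x):
        # x: (batch, 3, 10, 10), i.e. the 10x10x3 grid permuted channels-first
        return self.head(self.conv(x))

# Usual DQN choice: Huber (SmoothL1) or MSE on the TD targets.
loss_fn = nn.SmoothL1Loss()
```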
