I’m just getting started with AI and trying to build a bot that can play the Snake game. I’m stuck on figuring out what kind of network design would work best for my setup.
Right now I have my game data set up as a 4x12x12 grid with three different sections. The first section shows where all the snake body pieces are located using 1s and 0s. The second section marks where the food item appears on the board. The third section indicates the current head position of the snake.
What I need as output is basically 4 numbers that represent the possible actions the snake can take - moving up, down, left, or right.
Can anyone give me advice on what network structure might work well here? Also wondering if my way of organizing the game information makes sense or if there’s a better approach I should consider.
The structure you’ve chosen is actually quite good for reinforcement learning approaches. I’d recommend trying a dueling DQN architecture, where the network splits into separate value and advantage streams after the feature extraction layers. Start with two conv layers using 32 filters each, then branch: one stream estimates the state value V(s), the other estimates an advantage A(s, a) for each of the 4 actions, and the final Q-values are the state value plus the mean-centered advantages. This worked particularly well for me because snake involves both understanding how good your current position is and comparing specific moves. Your three-channel input is smart since it separates different types of game information clearly. One improvement I found helpful was normalizing the head position channel differently from the binary body/food channels since position data has different characteristics. Also consider adding some dropout between your dense layers to prevent overfitting, especially if you’re training on limited game variations.
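To make the "combine them" step concrete, here is a minimal sketch of the usual dueling aggregation in plain Python. The function name and the sample numbers are made up for illustration; in a real network the value and advantages would come from the two output streams.

```python
def dueling_q_values(state_value, advantages):
    """Combine a scalar state value V(s) with per-action advantages A(s, a).

    Uses the mean-subtracted form: Q(s, a) = V(s) + A(s, a) - mean(A(s, .)).
    """
    mean_advantage = sum(advantages) / len(advantages)
    return [state_value + a - mean_advantage for a in advantages]


# Hypothetical stream outputs for the four actions: up, down, left, right.
q = dueling_q_values(1.0, [0.5, -0.5, 1.0, -1.0])

# Greedy action = index of the largest Q-value.
best_action = max(range(len(q)), key=q.__getitem__)
```

Subtracting the mean advantage keeps the two streams identifiable: without it, a constant could be shifted between V and A without changing Q.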
I’ve built a few game AIs and ran into similar decisions. Your input representation is actually pretty clean.
One thing I’d try is a hybrid approach - start with a couple of conv layers to extract spatial features from your 4x12x12 grid, then flatten and feed into dense layers. The conv layers will catch patterns like “food is two squares away” or “body is forming a trap” which are crucial for snake.
But here’s what really helped me: add distance calculations as extra inputs. Compute Manhattan distance to food, distance to nearest wall, distance to nearest body part. Feed these as additional inputs alongside your flattened conv features. This gives the network both spatial awareness AND explicit numerical relationships.
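A rough sketch of those three distance features, assuming a 12x12 board, (row, col) coordinates, and a body list that excludes the head. The function names and the normalization step are my own choices for illustration, not something the game dictates.

```python
def distance_features(head, food, body, size=12):
    """Return (food_dist, wall_dist, body_dist) for one board state.

    head and food are (row, col) tuples; body is a list of (row, col)
    segments that excludes the head. All distances are Manhattan distances.
    """
    hr, hc = head
    food_dist = abs(hr - food[0]) + abs(hc - food[1])
    # Number of steps until the head would leave the board on the closest side.
    wall_dist = min(hr + 1, hc + 1, size - hr, size - hc)
    # With no body segments, fall back to the largest possible board distance.
    body_dist = min(
        (abs(hr - r) + abs(hc - c) for r, c in body),
        default=2 * (size - 1),
    )
    return food_dist, wall_dist, body_dist


def normalized(features, size=12):
    """Scale distances into [0, 1] so they are on the same scale as the 0/1 grid."""
    max_dist = 2 * (size - 1)
    return [f / max_dist for f in features]


features = distance_features(head=(5, 5), food=(2, 9), body=[(5, 4), (5, 3)])
extra_inputs = normalized(features)
```

The `extra_inputs` list would be concatenated with the flattened conv features before the first dense layer.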
For the network size, I found that 32 and 64 filters work well for the two conv layers, followed by dense layers of 128 and 64 units before your final 4 outputs.
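To get a feel for how big that network is, here is a back-of-the-envelope parameter count. It assumes 3x3 kernels with padding that keeps the 12x12 size, no pooling, 4 input channels, and three extra distance features appended after flattening; all of those are assumptions on my part, so adjust them to your setup.

```python
def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of a 2D conv layer with a square kernel."""
    return in_channels * out_channels * kernel * kernel + out_channels


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


channels, size = 4, 12      # input planes and board side length
extra_features = 3          # assumed: food, wall, and body distances

# Padding is assumed to preserve 12x12, so flattening gives 64 * 12 * 12 values.
flat = 64 * size * size

layers = {
    "conv1 (4 -> 32)": conv_params(channels, 32),
    "conv2 (32 -> 64)": conv_params(32, 64),
    "dense1 (-> 128)": dense_params(flat + extra_features, 128),
    "dense2 (128 -> 64)": dense_params(128, 64),
    "output (64 -> 4)": dense_params(64, 4),
}
total = sum(layers.values())

for name, count in layers.items():
    print(f"{name:20s} {count:>10,d}")
print(f"{'total':20s} {total:>10,d}")
```

Under these assumptions almost all of the parameters sit in the first dense layer, so pooling or a strided conv before flattening is the cheapest way to shrink the model.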
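Also consider using DQN instead of basic supervised learning if you haven’t already. Experience replay stores past transitions and trains on random samples of them, which breaks up the correlation between consecutive frames and really helps with the sequential decision making that snake requires.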
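A minimal replay buffer sketch using only the standard library. The class name, method names, and transition layout are illustrative choices rather than any particular library's API.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity):
        # A deque with maxlen silently drops the oldest transition when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement decorrelates consecutive frames.
        batch = random.sample(list(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)
```

In a training loop you would push one transition per game step and start sampling batches once the buffer holds more transitions than the batch size.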
Your current setup will definitely work though - I’d just add those distance features and see how much it improves performance.
Your data representation looks reasonable for a convolutional approach. I’d suggest starting with a simple CNN architecture - maybe 2-3 convolutional layers with small kernels like 3x3, followed by max pooling, then flatten and connect to a couple of dense layers before your final 4-output layer. The spatial nature of your 4x12x12 input should work well with convolutions since they can learn to recognize patterns like food proximity and body collision risks. One thing I learned when working on similar projects is that the reward function matters just as much as network architecture. Make sure you’re giving proper rewards for eating food and strong penalties for collisions, maybe even small rewards for surviving longer. You might also want to experiment with adding a fourth channel showing walls or boundaries if your snake game has them, but your current setup should definitely work as a starting point.
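As a starting point, the reward scheme described above could look something like this. The specific values (+1, -1, and a 0.01 survival bonus) are placeholder assumptions to tune, not recommended constants.

```python
def step_reward(ate_food, died, survive_bonus=0.01):
    """Reward for a single game step.

    died:      the snake hit a wall or its own body on this step
    ate_food:  the snake's head moved onto the food on this step
    """
    if died:
        return -1.0          # strong penalty for collisions
    if ate_food:
        return 1.0           # main positive signal
    return survive_bonus     # small reward for staying alive
```

Keep the survival bonus much smaller than the food reward; otherwise the snake can learn to circle forever instead of going for the food.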
honestly your input setup seems pretty solid already. i’d go with a basic feedforward network first before jumping into cnns - just flatten that 4x12x12 into 576 inputs, add maybe 2 hidden layers with like 256 and 128 neurons, then your 4 outputs with softmax. sometimes simple works better than complex for snake ai and it’s way easier to debug when things go wrong.
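here's a tiny sketch of the two ends of that pipeline in plain python: flattening the 4x12x12 grid into 576 inputs, and turning 4 raw outputs into action probabilities with softmax. the helper names are made up, and the hidden layers are left out. note that softmax fits a policy-style or supervised setup; with DQN you'd normally keep the 4 outputs as raw Q-values.

```python
import math


def flatten(planes):
    """Turn a nested [channel][row][col] grid into one flat list."""
    return [cell for plane in planes for row in plane for cell in row]


def softmax(logits):
    """Convert raw network outputs into probabilities that sum to 1."""
    peak = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# An empty 4x12x12 state, just to check the input size.
state = [[[0] * 12 for _ in range(12)] for _ in range(4)]
inputs = flatten(state)                 # 4 * 12 * 12 = 576 values

# Equal raw outputs give equal probability to up, down, left, right.
probs = softmax([0.0, 0.0, 0.0, 0.0])
```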