用 Python + NEAT 算法让 AI 从零进化玩转贪吃蛇

在这里插入图片描述

引言

传统贪吃蛇 AI 多半基于 A* 或 BFS 寻路,让蛇硬算最短路径去吃东西。这种方法虽然有效,但写出来的策略是固定的,缺乏“智能感”。一旦蛇身变长,很容易把自己绕进死胡同。

换一种思路:不让 AI 学会怎么玩,而是让它自己从零开始“进化”出玩游戏的策略。这就是本文要介绍的方法——基于 NEAT(NeuroEvolution of Augmenting Topologies)算法,结合 Pygame 模拟环境,让一群蛇通过遗传算法不断迭代,最终涌现出能长期存活的“蛇王”。

本文将详细拆解 AI 的感知方式、决策机制、环境设计,并提供经过最新版 neat-python 踩坑后整理好的完整代码,直接复制即可运行。


一、核心原理:NEAT 与进化策略

1. 什么是 NEAT

NEAT 是一种进化神经网络的方法。与常规深度学习不同,NEAT 不依赖梯度反向传播,而是通过遗传算法同时优化网络的权重拓扑结构(层数、神经元数量、连接方式)。每一代中,表现好的个体会被保留,它们的“基因”(网络结构和权重)经过交叉和变异产生下一代。这样经过多代繁衍,网络结构会逐渐复杂,解决问题的能力也逐步提升。

2. 本项目的进化框架

  • 种群大小:每代 50 条蛇(可在配置文件中修改)。
  • 输入:24 维环境特征(下文详述)。
  • 输出:4 个方向(上、下、左、右)的得分,选最高分方向移动。
  • 适应度函数:综合存活时间、进食数量、死亡惩罚,作为自然选择的依据。

二、游戏环境设计:为 AI 进化扫清障碍

为了让 AI 的得分真实反映其策略水平,游戏环境必须排除一切随机干扰。食物刷新策略是其中关键。

1. 平行宇宙:每条蛇吃自己的食物

画面中同时跑 50 条蛇,但它们不在同一个棋盘上抢食物。初始化时,每条蛇对应一个独立的 Food 实例,通过数组索引一一绑定:

snakes.append(Snake())
foods.append(Food())

这样做保证了 控制变量法:一条蛇的得分完全取决于它自己的寻路能力,不会因为抢到别人面前的食物而“作弊”。

2. 网格化生成:严格对齐坐标系

游戏窗口 400×400,蛇每次移动 20 像素(BLOCK_SIZE)。食物如果随机在任意坐标,蛇头永远无法正好踩中。因此食物生成必须“卡”在 20 的整数倍上:

self.pos = (random.randint(0, (WIDTH - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE,
            random.randint(0, (HEIGHT - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE)

这样食物只会出现在 0,20,40,…,380 这些格点上,保证精准碰撞。

3. 防重叠机制:食物绝不刷在蛇肚子里

当蛇吃到一个食物后,新食物必须出现在空地上。如果直接随机,很可能刷在蛇身上(尤其蛇很长时),导致蛇一动就死,这不公平。因此采用 while True 循环反复尝试,直到位置不在 snake.body 中:

while True:
    new_food = Food()
    if new_food.pos not in snake.body:
        foods[x] = new_food
        break

这个机制让 AI 的死亡只归因于策略失误,而非环境 Bug。

4. 蛇的生命周期与死亡判定

每条蛇有三个死法:

  • 撞墙:头部坐标超出边界。
  • 咬自己:头部与身体任何一节重合。
  • 饿死:每走一步 steps_left 减 1,吃到食物加 100,初始 150 步。归零则饿死。

饿死机制防止蛇在原地转圈刷存活分,逼迫它必须不断觅食。


三、AI 的感知与决策

1. 八向雷达(24 维输入)

AI 没有任何全局视野,只能通过雷达感知周围。从蛇头向八个方向(上、下、左、右、四个对角线)发射射线,每条射线返回三个值:

  • 到墙壁的距离(归一化:1.0 / step,越近警报值越大)
  • 该方向上是否有食物(有=1.0,无=0.0)
  • 到自身身体的距离(同样 1.0 / step,越近越危险)

8 个方向 × 3 = 24 维输入向量,构成 AI 的“眼睛”。

2. 决策网络

24 个值输入神经网络(由 NEAT 自动生成拓扑),输出层 4 个节点分别代表上、下、左、右的得分。采用 action = output.index(max(output)) 选取最高分方向作为移动方向。

3. 适应度函数(奖惩规则)

  • 每存活一步:fitness += 0.1(鼓励多动)
  • 吃到食物:fitness += 10,同时 steps_left += 100(强烈激励觅食)
  • 死亡:fitness -= 2(严厉惩罚)

这个奖惩体系引导 AI 学会主动找食物、避开障碍,并且活得久。


四、进化流程

1. 世代交替

每代开始时同时运行 50 条蛇。主循环条件:while run and len(snakes) > 0。当所有蛇死亡(len(snakes)==0)时,退出当前代,NEAT 接管进行遗传操作。

2. 遗传操作

NEAT 根据每条蛇的最终适应度,淘汰低分个体,保留高分个体,然后进行交叉(基因重组)和变异(权重扰动、增加节点/连接等),生成全新的 50 条蛇进入下一代。


五、代码实现(开箱即用)

环境准备

pip install pygame neat-python

1. 配置文件 config-feedforward.txt

[NEAT]
fitness_criterion     = max
fitness_threshold     = 3000
pop_size              = 50
reset_on_extinction   = False
no_fitness_termination = False

[DefaultGenome]
single_structural_mutation = False
structural_mutation_surer  = default

feed_forward            = True
num_inputs              = 24
num_hidden              = 0
num_outputs             = 4
initial_connection      = full

activation_default      = relu
activation_mutate_rate  = 0.1
activation_options      = relu

aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum

bias_init_type          = gaussian
bias_init_mean          = 0.0
bias_init_stdev         = 1.0
bias_max_value          = 30.0
bias_min_value          = -30.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.7
bias_replace_rate       = 0.1

response_init_type      = gaussian
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 30.0
response_min_value      = -30.0
response_mutate_power   = 0.0
response_mutate_rate    = 0.0
response_replace_rate   = 0.0

weight_init_type        = gaussian
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 30
weight_min_value        = -30
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

enabled_default         = True
enabled_mutate_rate     = 0.01
enabled_rate_to_false_add = 0.0
enabled_rate_to_true_add  = 0.0

node_add_prob           = 0.2
node_delete_prob        = 0.2
conn_add_prob           = 0.5
conn_delete_prob        = 0.5

compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5

[DefaultSpeciesSet]
compatibility_threshold = 3.0

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 20
species_elitism      = 2

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2
min_species_size   = 1

2. 主程序 main.py

import pygame
import random
import os
import neat

# --- 游戏参数配置 ---
WIDTH, HEIGHT = 400, 400
BLOCK_SIZE = 20
GENERATION = 0


class Snake:
    def __init__(self):
        self.body = [(WIDTH // 2, HEIGHT // 2)]
        self.direction = (0, -BLOCK_SIZE)
        self.alive = True
        self.score = 0
        self.steps_left = 150  # 初始步数

    def move(self):
        if not self.alive: return
        head_x, head_y = self.body[0]
        dir_x, dir_y = self.direction
        new_head = (head_x + dir_x, head_y + dir_y)

        # 撞墙检测
        if new_head[0] < 0 or new_head[0] >= WIDTH or new_head[1] < 0 or new_head[1] >= HEIGHT:
            self.alive = False
            return

        # 撞自己检测
        if new_head in self.body:
            self.alive = False
            return

        self.body.insert(0, new_head)
        self.steps_left -= 1

        if self.steps_left <= 0:
            self.alive = False


class Food:
    def __init__(self):
        self.pos = (random.randint(0, (WIDTH - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE,
                    random.randint(0, (HEIGHT - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE)


# --- 核心:8向视线雷达探测 (共 8*3=24 维输入) ---
def get_state(snake, food):
    state = []
    head_x, head_y = snake.body[0]

    # 8个探测方向:上、下、左、右、左上、左下、右上、右下
    directions = [
        (0, -BLOCK_SIZE), (0, BLOCK_SIZE), (-BLOCK_SIZE, 0), (BLOCK_SIZE, 0),
        (-BLOCK_SIZE, -BLOCK_SIZE), (-BLOCK_SIZE, BLOCK_SIZE), (BLOCK_SIZE, -BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE)
    ]

    for dir_x, dir_y in directions:
        dist_to_wall = 0
        dist_to_food = 0
        dist_to_body = 0

        food_found = False
        body_found = False
        step = 1

        curr_x, curr_y = head_x, head_y
        while True:
            curr_x += dir_x
            curr_y += dir_y

            # 1. 探测墙壁
            if curr_x < 0 or curr_x >= WIDTH or curr_y < 0 or curr_y >= HEIGHT:
                dist_to_wall = 1.0 / step  # 距离越近,激活值越大
                break

            # 2. 探测食物 (该射线上首次遇到)
            if not food_found and food.pos == (curr_x, curr_y):
                dist_to_food = 1.0
                food_found = True

            # 3. 探测自己身体 (该射线上首次遇到)
            if not body_found and (curr_x, curr_y) in snake.body:
                dist_to_body = 1.0 / step
                body_found = True

            step += 1

        state.extend([dist_to_wall, dist_to_food, dist_to_body])

    return state


# --- NEAT 算法评估函数 ---
def eval_genomes(genomes, config):
    global GENERATION
    GENERATION += 1

    pygame.init()
    screen = pygame.display.set_mode((WIDTH, HEIGHT))
    pygame.display.set_caption(f"AI Snake Evolution - Generation {GENERATION}")
    clock = pygame.time.Clock()

    nets = []
    snakes = []
    foods = []
    ge = []

    for genome_id, genome in genomes:
        genome.fitness = 0
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        nets.append(net)
        snakes.append(Snake())
        foods.append(Food())
        ge.append(genome)

    run = True
    while run and len(snakes) > 0:
        clock.tick(60)  # 如果想飞速进化,改成 clock.tick(0)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False
                pygame.quit()
                quit()

        screen.fill((20, 20, 20))

        # 倒序遍历:防止在循环中执行 pop 导致索引越界
        for x in range(len(snakes) - 1, -1, -1):
            snake = snakes[x]

            state = get_state(snake, foods[x])
            output = nets[x].activate(state)

            action = output.index(max(output))
            if action == 0 and snake.direction != (0, BLOCK_SIZE): snake.direction = (0, -BLOCK_SIZE)
            if action == 1 and snake.direction != (0, -BLOCK_SIZE): snake.direction = (0, BLOCK_SIZE)
            if action == 2 and snake.direction != (BLOCK_SIZE, 0): snake.direction = (-BLOCK_SIZE, 0)
            if action == 3 and snake.direction != (-BLOCK_SIZE, 0): snake.direction = (BLOCK_SIZE, 0)

            snake.move()

            if snake.alive:
                ge[x].fitness += 0.1  # 每走一步没死,给微小奖励

                # 吃食物检测
                if snake.body[0] == foods[x].pos:
                    ge[x].fitness += 10
                    snake.score += 1
                    snake.steps_left += 100  # 吃到食物补充步数寿命

                    # 确保食物不会生成在蛇的身体上
                    while True:
                        new_food = Food()
                        if new_food.pos not in snake.body:
                            foods[x] = new_food
                            break
                else:
                    snake.body.pop()

            # 绘制食物
            pygame.draw.rect(screen, (255, 50, 50), (foods[x].pos[0], foods[x].pos[1], BLOCK_SIZE, BLOCK_SIZE))

            # 绘制蛇 (蛇头用不同颜色标识)
            for i, part in enumerate(snake.body):
                color = (100, 255, 100) if i == 0 else (50, 200, 50)
                pygame.draw.rect(screen, color, (part[0], part[1], BLOCK_SIZE, BLOCK_SIZE))

            # 死亡结算与清理
            if not snake.alive:
                ge[x].fitness -= 2  # 死亡惩罚
                snakes.pop(x)
                nets.pop(x)
                ge.pop(x)
                foods.pop(x)

        pygame.display.update()


def run_neat(config_file):
    config = neat.config.Config(neat.DefaultGenome, neat.DefaultReproduction,
                                neat.DefaultSpeciesSet, neat.DefaultStagnation,
                                config_file)

    p = neat.Population(config)
    p.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    p.add_reporter(stats)

    winner = p.run(eval_genomes, 1000)  # 最多进化 1000 代
    print('\nBest genome:\n{!s}'.format(winner))


if __name__ == "__main__":
    local_dir = os.path.dirname(__file__)
    config_path = os.path.join(local_dir, 'config-feedforward.txt')
    run_neat(config_path)

六、避坑指南:新版 neat-python 配置文件

如果你在网上找过 NEAT 贪吃蛇的旧代码,很可能运行就报错。因为 neat-python 新版对配置文件的检查变得极其严格,许多以前可以省略的参数现在必须显式声明。上面给出的配置文件已补全所有必填项(如 bias_init_typeresponse_init_typesingle_structural_mutation 等),直接复制使用不会出错。


七、运行观察与进化阶段

clock.tick(60) 保持为 60 可以看清蛇的行为;改为 clock.tick(0) 则飞速进化,适合快速迭代。

控制台会输出每一代的统计信息(平均适应度、最大适应度、物种数等)。如果盯着屏幕观察,会看到明显的阶段性进步:

  • 第 1~10 代:随机行为,蛇迅速撞墙或自转而死。
  • 第 10~30 代:少数蛇开始“趋光”——朝着有食物的方向直冲,但仍容易缠住自己。
  • 第 50~100 代:学会 S 形走位、贴墙走,懂得在长身体时预留空间。
  • 100 代以后:可能出现能长期存活的个体,几乎不犯低级错误。

这些策略并非人工编码,而是自然选择塑造出来的。


八、程序终止条件

程序会在以下情况结束:

  1. 达到最高代数(此处设为 1000 代),输出最佳基因组后退出。
  2. 适应度达到阈值(配置文件中的 fitness_threshold = 3000),提前宣布胜利。
  3. 物种停滞:如果连续 20 代没有进步,且 reset_on_extinction = False,所有物种灭绝,程序报错终止。
  4. 手动关闭窗口

九、进阶想法

  • 食物保质期:给食物加倒计时,超时消失并扣分,逼迫 AI 更主动觅食。
  • 增加输入维度:加入当前蛇长度、头部坐标等,让网络感知更丰富。
  • 调整变异概率:修改配置文件中的 conn_add_prob 等参数,观察进化速度变化。

十、总结

用 NEAT 让贪吃蛇从零进化,比传统寻路算法更有意思。你不需要告诉 AI 怎么躲墙、怎么绕圈,它自己会在无数次生死中“悟”出来。如果你也想体验当造物主的感觉,不妨复制代码跑一跑,看看多少代后会出现第一条“蛇王”。

欢迎在评论区分享你的进化成果。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐