机器学习之深度Q网络

目录

  1. 简介
  2. 背景知识
  3. DQN核心原理
  4. 关键技术创新
  5. 数学公式推导
  6. 算法实现
  7. 应用场景
  8. 局限性与改进
  9. 代码示例
  10. 总结

简介

深度Q网络(Deep Q-Network,简称DQN)是一种结合了深度学习和强化学习的算法,由DeepMind团队在2013年提出,并于2015年在《Nature》期刊上发表。DQN通过使用神经网络来近似Q函数,成功地将Q学习算法扩展到高维状态空间,使得智能体能够直接从原始输入(如图像像素)中学习策略。

主要贡献

  • 首个成功将深度学习与强化学习结合的算法
  • 在Atari游戏上达到人类玩家水平
  • 引入经验回放和目标网络解决训练不稳定问题

背景知识

Q学习 (Q-Learning)

Q学习是一种基于值的强化学习算法,通过学习状态-动作值函数(Q函数)来指导智能体的决策。

Q函数定义

Q函数 Q ( s , a ) Q(s, a) Q(s,a) 表示在状态 s s s 下采取动作 a a a 后,遵循最优策略所能获得的期望累积回报:

Q ( s , a ) = E [ ∑ t = 0 ∞ γ t r t + 1 ∣ s 0 = s , a 0 = a ] Q(s, a) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t r_{t+1} \bigg| s_0 = s, a_0 = a\right] Q(s,a)=E[t=0γtrt+1 s0=s,a0=a]

其中:

  • γ ∈ [ 0 , 1 ] \gamma \in [0, 1] γ[0,1] 是折扣因子
  • r t + 1 r_{t+1} rt+1 t + 1 t+1 t+1 时刻的奖励
Bellman方程

Q函数满足以下Bellman最优方程:

Q ∗ ( s , a ) = E s ′ ∼ P ( ⋅ ∣ s , a ) [ r + γ max ⁡ a ′ Q ∗ ( s ′ , a ′ ) ] Q^*(s, a) = \mathbb{E}_{s' \sim P(\cdot|s,a)}\left[r + \gamma \max_{a'} Q^*(s', a')\right] Q(s,a)=EsP(s,a)[r+γamaxQ(s,a)]

Q学习更新规则

传统Q学习使用表格存储Q值,更新规则为:

Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s, a) \leftarrow Q(s, a) + \alpha \left[r + \gamma \max_{a'} Q(s', a') - Q(s, a)\right] Q(s,a)Q(s,a)+α[r+γamaxQ(s,a)Q(s,a)]

其中 α \alpha α 是学习率。


DQN核心原理

问题动机

传统Q学习面临的主要挑战:

  1. 维数灾难:当状态空间很大时(如图像输入),表格存储不可行
  2. 样本相关性:连续采样的样本高度相关,导致训练不稳定
  3. 非平稳目标:目标值随着Q函数的变化而不断变化

DQN解决方案

DQN使用深度神经网络来近似Q函数:

Q ( s , a ; θ ) ≈ Q ∗ ( s , a ) Q(s, a; \theta) \approx Q^*(s, a) Q(s,a;θ)Q(s,a)

其中 θ \theta θ 是神经网络的参数。

网络结构

典型的DQN网络结构包括:

  • 输入层:接收状态表示(如图像像素)
  • 隐藏层:多个卷积层(处理图像)或全连接层
  • 输出层:输出每个动作的Q值

关键技术创新

1. 经验回放 (Experience Replay)

原理

将智能体的经验 ( s t , a t , r t , s t + 1 ) (s_t, a_t, r_t, s_{t+1}) (st,at,rt,st+1) 存储在经验池中,训练时从中随机采样。

优势
  • 打破样本相关性:随机采样使得训练样本独立同分布
  • 提高样本利用率:每个样本可以被多次使用
  • 稳定训练过程:减少数据分布的剧烈变化
实现

经验池通常使用循环缓冲区实现:

experience_pool = [(s_1, a_1, r_1, s'_1), (s_2, a_2, r_2, s'_2), ...]
minibatch = random.sample(experience_pool, batch_size)

2. 目标网络 (Target Network)

原理

使用一个独立的网络来计算目标值,该网络的参数定期从主网络复制。

目标值计算

使用目标网络计算TD目标:

y i = r i + γ max ⁡ a ′ Q ( s i ′ , a ′ ; θ − ) y_i = r_i + \gamma \max_{a'} Q(s'_i, a'; \theta^-) yi=ri+γamaxQ(si,a;θ)

其中 θ − \theta^- θ 是目标网络的参数。

优势
  • 稳定训练目标:目标值在一段时间内保持不变
  • 减少震荡:避免自举(bootstrapping)导致的发散
更新策略

目标网络参数的更新方式:

θ − ← θ (每C步更新一次) \theta^- \leftarrow \theta \quad \text{(每C步更新一次)} θθ(每C步更新一次)

或者使用软更新:

θ − ← τ θ + ( 1 − τ ) θ − \theta^- \leftarrow \tau \theta + (1 - \tau) \theta^- θτθ+(1τ)θ

其中 τ ≪ 1 \tau \ll 1 τ1 是软更新系数。


数学公式推导

损失函数

DQN使用均方误差作为损失函数:

L ( θ ) = E ( s , a , r , s ′ ) ∼ U ( D ) [ ( y − Q ( s , a ; θ ) ) 2 ] L(\theta) = \mathbb{E}_{(s,a,r,s') \sim U(D)}\left[\left(y - Q(s, a; \theta)\right)^2\right] L(θ)=E(s,a,r,s)U(D)[(yQ(s,a;θ))2]

其中:

  • D D D 是经验池
  • U ( D ) U(D) U(D) 表示从经验池中均匀采样
  • y y y 是TD目标值

梯度下降

对损失函数求梯度:

∇ θ L ( θ ) = E [ ( y − Q ( s , a ; θ ) ) ∇ θ Q ( s , a ; θ ) ] \nabla_\theta L(\theta) = \mathbb{E}\left[\left(y - Q(s, a; \theta)\right) \nabla_\theta Q(s, a; \theta)\right] θL(θ)=E[(yQ(s,a;θ))θQ(s,a;θ)]

使用随机梯度下降(SGD)更新参数:

θ ← θ + α ( y − Q ( s , a ; θ ) ) ∇ θ Q ( s , a ; θ ) \theta \leftarrow \theta + \alpha \left(y - Q(s, a; \theta)\right) \nabla_\theta Q(s, a; \theta) θθ+α(yQ(s,a;θ))θQ(s,a;θ)

TD误差

TD误差定义为:

δ = r + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ − ) − Q ( s , a ; θ ) \delta = r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) δ=r+γamaxQ(s,a;θ)Q(s,a;θ)

TD误差反映了当前估计与目标值之间的差异。


算法实现

DQN算法伪代码

初始化经验池 D,容量为 N
初始化Q网络 Q(s,a;θ),参数为 θ
初始化目标Q网络 Q(s,a;θ⁻),参数为 θ⁻ = θ
设置学习率 α,折扣因子 γ,探索率 ε,批次大小 b,目标网络更新频率 C

for episode = 1 to M:
    初始化状态 s₁
    for t = 1 to T:
        # ε-贪婪策略选择动作
        以ε概率随机选择动作aₜ,否则选择 aₜ = argmaxₐ Q(sₜ, a; θ)
        
        # 执行动作并观察结果
        执行动作aₜ,获得奖励rₜ和下一状态sₜ₊₁
        
        # 存储经验到经验池
        将(sₜ, aₜ, rₜ, sₜ₊₁)存入经验池D
        
        # 更新当前状态
        sₜ ← sₜ₊₁
        
        # 当经验池中有足够样本时开始训练
        if 经验池D中有足够样本:
            # 从经验池中随机采样一个批次
            从D中随机采样一个minibatch {(sᵢ, aᵢ, rᵢ, s'ᵢ)}ᵢ₌₁ᵇ
            
            # 对每个样本计算目标值
            对每个样本 i = 1 到 b:
                如果 s'ᵢ 是终止状态:
                    yᵢ = rᵢ
                否则:
                    yᵢ = rᵢ + γ maxₐ' Q(s'ᵢ, a'; θ⁻)
            
            # 计算损失函数(均方误差)
            L(θ) = (1/b) Σᵢ₌₁ᵇ(yᵢ - Q(sᵢ, aᵢ; θ))²
            
            # 使用梯度下降更新网络参数
            θ ← θ - α∇θL(θ)
            
            # 定期更新目标网络参数
            if t mod C == 0:
                θ⁻ ← θ
                
        # 衰减探索率(可选)
        ε ← max(ε_min, ε * decay_rate)

数学表达式详解

1. 网络初始化
  • 主网络 Q ( s , a ; θ ) Q(s,a;\theta) Q(s,a;θ),参数为 θ \theta θ
  • 目标网络 Q ( s , a ; θ − ) Q(s,a;\theta^-) Q(s,a;θ),参数为 θ − = θ \theta^- = \theta θ=θ
2. 动作选择策略

使用 ε \varepsilon ε-贪婪策略:

a t = { 随机动作 以概率  ε arg ⁡ max ⁡ a Q ( s t , a ; θ ) 以概率  1 − ε a_t = \begin{cases} \text{随机动作} & \text{以概率 } \varepsilon \\ \arg\max_a Q(s_t, a; \theta) & \text{以概率 } 1-\varepsilon \end{cases} at={随机动作argmaxaQ(st,a;θ)以概率 ε以概率 1ε

3. 经验存储

经验四元组: ( s t , a t , r t , s t + 1 ) (s_t, a_t, r_t, s_{t+1}) (st,at,rt,st+1)

4. 目标值计算

对于每个样本 i i i,目标值 y i y_i yi 的计算:

y i = { r i 如果  s i ′  是终止状态 r i + γ max ⁡ a ′ Q ( s i ′ , a ′ ; θ − ) 否则 y_i = \begin{cases} r_i & \text{如果 } s'_i \text{ 是终止状态} \\ r_i + \gamma \max_{a'} Q(s'_i, a'; \theta^-) & \text{否则} \end{cases} yi={riri+γmaxaQ(si,a;θ)如果 si 是终止状态否则

其中 γ ∈ [ 0 , 1 ] \gamma \in [0,1] γ[0,1] 是折扣因子。

5. 损失函数

使用均方误差损失函数:

L ( θ ) = 1 b ∑ i = 1 b ( y i − Q ( s i , a i ; θ ) ) 2 L(\theta) = \frac{1}{b} \sum_{i=1}^{b} (y_i - Q(s_i, a_i; \theta))^2 L(θ)=b1i=1b(yiQ(si,ai;θ))2

其中 b b b 是批次大小。

6. 梯度更新

使用梯度下降更新网络参数:

θ ← θ − α ∇ θ L ( θ ) \theta \leftarrow \theta - \alpha \nabla_\theta L(\theta) θθαθL(θ)

其中 α \alpha α 是学习率, ∇ θ L ( θ ) \nabla_\theta L(\theta) θL(θ) 是损失函数对参数 θ \theta θ 的梯度:

∇ θ L ( θ ) = − 2 b ∑ i = 1 b ( y i − Q ( s i , a i ; θ ) ) ∇ θ Q ( s i , a i ; θ ) \nabla_\theta L(\theta) = -\frac{2}{b} \sum_{i=1}^{b} (y_i - Q(s_i, a_i; \theta)) \nabla_\theta Q(s_i, a_i; \theta) θL(θ)=b2i=1b(yiQ(si,ai;θ))θQ(si,ai;θ)

7. 目标网络更新

C C C 步更新一次目标网络:

θ − ← θ \theta^- \leftarrow \theta θθ

或者使用软更新策略:

θ − ← τ θ + ( 1 − τ ) θ − \theta^- \leftarrow \tau \theta + (1-\tau) \theta^- θτθ+(1τ)θ

其中 τ ≪ 1 \tau \ll 1 τ1 是软更新系数。

超参数

参数 典型值 说明
学习率 α \alpha α 0.0001 神经网络学习率
折扣因子 γ \gamma γ 0.99 未来奖励的权重
经验池大小 1,000,000 经验回放缓冲区容量
批次大小 32 每次训练使用的样本数
目标网络更新频率 C 1000 主网络参数复制到目标网络的步数
ε-贪婪初始值 1.0 初始探索概率
ε-贪婪最终值 0.1 最终探索概率
ε衰减步数 1,000,000 从初始到最终值的衰减步数

应用场景

1. 游戏AI

DQN在Atari 2600游戏上取得了突破性成果:

  • 49款游戏:在49款Atari游戏中达到人类玩家水平
  • 超越人类:在部分游戏(如Breakout)上超越人类玩家
  • 通用智能:同一算法架构适用于不同游戏

2. 机器人控制

  • 机械臂操作:学习抓取和放置物体
  • 移动机器人:导航和避障
  • 无人机控制:飞行姿态调整

3. 资源调度

  • 数据中心冷却:Google使用深度强化学习优化数据中心冷却系统
  • 网络流量调度:动态调整带宽分配
  • 云资源管理:自动扩缩容和负载均衡

4. 金融交易

  • 投资组合优化:动态调整资产配置
  • 高频交易:毫秒级决策
  • 风险管理:实时风险控制

局限性与改进

DQN的局限性

  1. 高估问题:Q值的最大值操作导致Q值被系统性高估

    Q ( s , a ) ≥ Q ∗ ( s , a ) Q(s, a) \geq Q^*(s, a) Q(s,a)Q(s,a)

  2. 样本效率低:需要大量训练样本才能收敛

  3. 探索不足:ε-贪婪策略在高维空间中探索效率低

  4. 不适用于连续动作空间:需要离散化动作空间

改进算法

1. Double DQN

解决Q值高估问题,使用主网络选择动作,目标网络评估:

y i = r i + γ Q ( s i ′ , arg ⁡ max ⁡ a ′ Q ( s i ′ , a ′ ; θ ) ; θ − ) y_i = r_i + \gamma Q(s'_i, \arg\max_{a'} Q(s'_i, a'; \theta); \theta^-) yi=ri+γQ(si,argamaxQ(si,a;θ);θ)

2. Dueling DQN

将Q值分解为状态价值和优势函数:

Q ( s , a ) = V ( s ) + A ( s , a ) Q(s, a) = V(s) + A(s, a) Q(s,a)=V(s)+A(s,a)

其中:

  • V ( s ) V(s) V(s) 是状态价值,与动作无关
  • A ( s , a ) A(s, a) A(s,a) 是优势函数,表示动作相对于平均水平的优劣
3. Prioritized Experience Replay

根据TD误差的绝对值对经验进行优先级采样:

p i = ∣ δ i ∣ + ϵ ∑ j ( ∣ δ j ∣ + ϵ ) p_i = \frac{|\delta_i| + \epsilon}{\sum_j (|\delta_j| + \epsilon)} pi=j(δj+ϵ)δi+ϵ

4. Rainbow DQN

结合多种改进技术:

  • Double Q-learning
  • Prioritized Experience Replay
  • Dueling networks
  • Multi-step learning
  • Distributional RL
  • Noisy Nets
5. DDPG (Deep Deterministic Policy Gradient)

扩展到连续动作空间,结合策略梯度和值函数近似。


代码示例

PyTorch实现

import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np

class DQNNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super(DQNNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(next_states),
            np.array(dones)
        )
    
    def __len__(self):
        return len(self.buffer)

class DQNAgent:
    def __init__(self, state_size, action_size, lr=0.001, gamma=0.99, 
                 buffer_size=10000, batch_size=32, target_update=100):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update = target_update
        self.update_count = 0
        
        # 主网络和目标网络
        self.q_network = DQNNetwork(state_size, action_size)
        self.target_network = DQNNetwork(state_size, action_size)
        self.target_network.load_state_dict(self.q_network.state_dict())
        
        # 优化器
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        
        # 经验回放
        self.replay_buffer = ReplayBuffer(buffer_size)
        
        # ε-贪婪参数
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
    
    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)
        else:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_network(state_tensor)
            return q_values.argmax().item()
    
    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        # 采样minibatch
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        
        # 转换为tensor
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)
        
        # 计算当前Q值
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1))
        
        # 计算目标Q值
        next_q_values = self.target_network(next_states).max(1)[0].detach()
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        
        # 计算损失
        loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
        
        # 优化
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 更新目标网络
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        
        # 衰减epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

训练循环示例

def train_dqn(env, agent, episodes=1000):
    scores = []
    
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False
        
        while not done:
            # 选择动作
            action = agent.select_action(state)
            
            # 执行动作
            next_state, reward, done, _ = env.step(action)
            
            # 存储经验
            agent.replay_buffer.push(state, action, reward, next_state, done)
            
            # 训练网络
            agent.train()
            
            state = next_state
            total_reward += reward
        
        scores.append(total_reward)
        
        if (episode + 1) % 100 == 0:
            avg_score = np.mean(scores[-100:])
            print(f"Episode {episode + 1}, Average Score: {avg_score:.2f}, Epsilon: {agent.epsilon:.3f}")
    
    return scores

总结

深度Q网络(DQN)是强化学习领域的重要里程碑,它成功地将深度学习的感知能力与强化学习的决策能力相结合。通过引入经验回放和目标网络两项关键技术,DQN解决了高维状态空间下的训练不稳定问题。

核心要点回顾

  1. 神经网络近似Q函数:使用深度神经网络处理高维状态输入
  2. 经验回放:打破样本相关性,提高样本利用率
  3. 目标网络:稳定训练目标,减少震荡
  4. ε-贪婪策略:平衡探索与利用

DQN的成功为后续的深度强化学习研究奠定了基础,催生了众多改进算法和应用场景,成为人工智能领域的重要技术之一。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐