关于人工智能:浅谈TD3从算法原理到代码实现

6次阅读

共计 4327 个字符,预计需要花费 11 分钟才能阅读完成。

本文首发于:行者 AI

家喻户晓,在基于价值学习的强化学习算法中,如 DQN,函数近似误差是导致 Q 值高估和次优策略的起因。咱们表明这个问题仍然在 AC 框架中存在,并提出了新的机制去最小化它对演员(策略函数)和评论家(估值函数)的影响。咱们的算法建设在双 Q 学习的根底上,通过选取两个估值函数中的较小值,从而限度它对 Q 值的过高估计。(出自 TD3 论文摘要)

1. 什么是 TD3

TD3 是 Twin Delayed Deep Deterministic policy gradient algorithm 的全称。TD3 全称中 Deep Deterministic policy gradient algorithm 就是 DDPG 的全称。那么 DDPG 和 TD3 有何渊源呢?其实简略的说,TD3 是 DDPG 的一个优化版本。

1.1 TD3 为什么被提出

在强化学习中,对于离散化的动作的学习,都是以 DQN 为根底的,DQN 则是通过的 $argMaxQ_{table}$ 的形式去抉择动作,往往都会过大的预计价值函数,从而造成误差。在间断的动作管制的 AC 框架中,如果每一步都采纳这种形式去预计,导致误差一步一步的累加,导致不能找到最优策略,最终使算法不能失去收敛。

1.2 TD3 在 DDPG 的根底上都做了些什么

  • 应用两个 Critic 网络。应用两个网络对动作价值函数进行预计,(这 Double DQN 的思维差不多)。在训练的时候抉择 $min(Q^{\theta1}(s,a),Q^{\theta2}(s,a))$ 作为估计值。
  • 应用软更新的形式。不再采纳间接复制,而是应用 $\theta = \tau\theta^′ + (1 – \tau)\theta$ 的形式更新网络参数。
  • 应用策略乐音。应用 Epsilon-Greedy 在摸索的时候应用了摸索乐音。(还是用了策略噪声,在更新参数的时候,用于平滑策略冀望)
  • 应用提早学习。Critic 网络更新的频率要比 Actor 网络更新的频率要大。
  • 应用梯度截取。将 Actor 的参数更新的梯度截取到某个范畴内。

2. TD3 算法思路

TD3 算法的大抵思路,首先初始化 3 个网络,别离为 $Q_{\theta1},Q_{\theta2},\pi_\phi$,参数为 $\theta_1,\theta_2,\phi$,在初始化 3 个 Target 网络,别离将开始初始化的 3 个网络参数别离对应的复制给 target 网络。$\theta{_1^′}\leftarrow\theta_1,\theta{_2^′}\leftarrow\theta_2,\phi_′\leftarrow\phi$。初始化 Replay Buffer $\beta$。而后通过循环迭代,一次次找到最优策略。每次迭代,在抉择 action 的值的时候退出了乐音,使 $a~\pi_\phi(s) + \epsilon$,$\epsilon \sim N(0,\sigma)$,而后将 $(s,a,r,s^′)$ 放入 $\beta$,当 $\beta$ 达到肯定的值时候。而后随机从 $\beta$ 中 Sample 出 Mini-Batch 个数据,通过 $\tilde{a} \sim\pi_{\phi^′}(s^′) + \epsilon$,$\epsilon \sim clip(N(0,\tilde\sigma),-c,c)$,计算出 $s^′$ 状态下对应的 Action 的值 $\tilde a$,通过 $s^′,\tilde a$,计算出 $targetQ1,targetQ2$,获取 $min(targetQ1,targetQ)$,为 $s^′$ 的 $targetQ$ 值。

通过贝尔曼方程计算 $s$ 的 $targetQ$ 值,通过两个 Current 网络依据 $s,a$ 别离计算出以后的 $Q$ 值,在将两个以后网络的 $Q$ 值和 $targetQ$ 值通过 MSE 计算 Loss,更新参数。Critic 网络更新之后,Actor 网络则采纳了延时更新,(个别采纳 Critic 更新 2 次,Actor 更新 1 次)。通过梯度回升的形式更新 Actor 网络。通过软更新的形式,更新 target 网络。

  • 为什么在更新 Critic 网络时,在计算 Action 值的时候退出乐音,是为了平滑后面退出的乐音。
  • 贝尔曼方程:针对一个间断的 MRP(Markov Reward Process)的过程(间断的状态处分过程),状态 $s$ 转移到下一个状态 $s^′$ 的概率的固定的,与后面的几轮状态无关。其中,$v$ 示意一个对以后状态 state 进行估值的函数。$\gamma$ 个别为趋近于 1,然而小于 1。

3. 代码实现

代码次要是依据 DDPG 的代码以及 TD3 的论文复现的,应用的是 Pytorch1.7 实现的。

3.1 搭建网络结构

Q1 网络结构次要是用于更新 Actor 网络

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.f1 = nn.Linear(state_dim, 256)
        self.f2 = nn.Linear(256, 128)
        self.f3 = nn.Linear(128, action_dim)
        self.max_action = max_action
    def forward(self,x):
        x = self.f1(x)
        x = F.relu(x)
        x = self.f2(x)
        x = F.relu(x)
        x = self.f3(x)
        return torch.tanh(x) * self.max_action
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic,self).__init__()
        self.f11 = nn.Linear(state_dim+action_dim, 256)
        self.f12 = nn.Linear(256, 128)
        self.f13 = nn.Linear(128, 1)

        self.f21 = nn.Linear(state_dim + action_dim, 256)
        self.f22 = nn.Linear(256, 128)
        self.f23 = nn.Linear(128, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)

        x = self.f11(sa)
        x = F.relu(x)
        x = self.f12(x)
        x = F.relu(x)
        Q1 = self.f13(x)

        x = self.f21(sa)
        x = F.relu(x)
        x = self.f22(x)
        x = F.relu(x)
        Q2 = self.f23(x)

        return Q1, Q2

3.2 定义网络

 self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
        self.target_actor = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        #定义 critic 网络
        self.critic = Critic(self.state_dim, self.action_dim)
        self.target_critic = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

3.3 更新网络

更新网络采纳 软更新 提早更新 等形式

 def learn(self):
        self.total_it += 1
        data = self.buffer.smaple(size=128)
        state, action, done, state_next, reward = data
        with torch.no_grad:
            noise = (torch.rand_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.target_actor(state_next) + noise).clamp(-self.max_action, self.max_action)
            target_Q1,target_Q2 = self.target_critic(state_next, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + done * self.discount * target_Q
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        critic_loss.backward()
        self.critic_optimizer.step()

        if self.total_it % self.policy_freq == 0:

            q1,q2 = self.critic(state, self.actor(state))
            actor_loss = -torch.min(q1, q2).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.target_actor.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

4. 总结

TD3 是 DDPG 的一个升级版,在解决很多的问题上,成果要比 DDPG 的成果好的多,无论是训练速度,还是后果都有显著的进步。

5. 材料

  1. http://proceedings.mlr.press/…
正文完
 0