stable-baselines3中的SAC

现象

本来自己写了一个SAC模型用于测试parking环境(http://highway-env.farama.org/environments/parking/),该环境模拟自动停车过程,小车需要停到停车场随机的一个目标车位(见下视频)

无奈模型怎么也无法达到预期效果,经过多次测试发现,仅仅将DDPG修改为不确定性策略是可行的,但一旦加上最大熵(SAC的核心)模型就不不行了,然后使用stable-baselines3(https://github.com/DLR-RM/stable-baselines3/tree/master/stable_baselines3/sac)就可以,遂研究了一下两者代码的区别,发现,SB3默认会使用一个ent_coef的可学习参数,作为最大熵的系数,该参数可设置为固定值,以下分别为设置为固定值和可学习参数得到的效果:

使用固定值为1的效果,见日志中的 train/ent_coef,从最终效果中可以看到,好像小车有那么一点点趋势会向目标点靠近,但不多

注意到 train/ent_coef 的值经过学习后变得很小(初始值为1),仅为0.02,从视频也能看到模型训练达到了预期的效果

随后我在我自己的SAC代码中也加入该可学习参数,发现确实也work了,该参数经过一段时间的学习也变得非常小:

分析

对于SAC来说,貌似有两个理论上很少提到但是实际上又不可或缺的东西

log_ent_coef 最大熵的温度控制

该参数原始代码大概长这样:

# 定义一个可学习参数
log_ent_coef = torch.log(torch.ones(1, device=device)).requires_grad_(True)
ent_coef_optimizer = torch.optim.Adam([log_ent_coef], lr=1e-3)
...
# 将上述参数作为 -log_prob 的系数,即最大熵的系数
ent_coef = torch.exp(log_ent_coef.detach())
target_Q = self.critic_target(next_states, act_next) - ent_coef * log_prob_next.detach().sum(dim=1).reshape([-1, 1])
...

# 参数学习
act, act_log_prob = actor(current_states)

ent_coef_loss = -(log_ent_coef * (act_log_prob - np.prod(env.action_space.shape)).detach()).mean()
ent_coef_loss.backward()

如何理解这个loss函数呢?以下是我由果推因的想法:

将loss单独拿出来看:

ent_coef_loss = -(log_ent_coef * (act_log_prob - np.prod(env.action_space.shape)).detach()).mean()  # np.prod(env.action_space.shape)=2

需要先说明的是,act_log_prob 这个变量是当前policy在当前state下得到的动作的log probability,由于模型输出的是一个高斯分布,其均值和方差可能为任意数,所以高斯密度函数输出的值也可能为任意数,故而 log probability 也可能为任意值,有正有负。

再者action可能是一个集合,而一般计算 log_prob 时会将每个action对应的 log_prob 加起来一起返回,我觉得这是理解 act_log_prob - np.prod(env.action_space.shape) 的关键。这个np.prod(env.action_space.shape)其实可以理解为 log_prob 的期望,事实上,在parking这个游戏中,env.action_space.shape=(2,),所以SB3中,默认 log_prob 的期望为2,即对单个action期望为1(因为是两个action的log_prob相加,所以期望要乘以2)。

而log_prob是和对应action的概率成正相关的,log_ent_coef 则可以看成是熵对reward的增益,直觉上来说,当模型的输出愈发趋近于确定性策略,则更应该鼓励其探索

由此,可以得出结论:

  • 当预测的 log_prob 即对某个action的输出概率越大,则说明可能探索力度还不够,我就应该增大熵对reward的增益。反之,
  • 当输出action的概率较小,则说明探索的比较好,应该减小熵的增益

所以,当你希望SAC偏向确定性策略时,可以适当设置一个大一点的 log_prob 期望值(即上面代码中的 np.prod(env.action_space.shape)=2,可以将其设置为2.5或3等更大的值)

Squashed Gaussian Trick

policy模型输出一个action的高斯分布,然后从该分布中sample一个action,问题就在于sample出的action值域是负无穷到正无穷,而实际环境中的action是存在边界的,所以需要使用有界函数对输出的action进行限定,例如 sigmoid、tanh 等函数

然而,对于SAC来说,你不仅要输出一个action,你还得输出一个action的log值,即log_prob,用于作为最大熵。那么既然输出的是tanh(act),那么实际应该是通过计算 tanh(act) 的概率来计算最终需要的log_prob值。

这里有点绕,一开始我觉得只需要将输出的action进行tanh应该足够了,而log_prob本身可以作为模型的学习结果。但再细想一下其实是不对的,因为模型的输出就是一个高斯分布,对于这个分布来说,它有无穷个action,并且每个action都有一个一一对应的log_prob,而你现在强行将无穷的action压缩到 [-1, 1] 的有限区间,也就意味着可能一个新的action对应了很多的log_prob(多个action都被映射到了同一个值),但是你policy输出的却只是某一个action对应的log_prob,这是不对的。

当然,由于log_ent_coef这个可学习参数的存在,即使是不对log_prob进行转换,模型仍然是可以训练出来的,只不过它的log_ent_coef参数会学得接近0,进而将log_prob短路,使得最大熵丧失效果:

总的来说,上面从直觉方面简单说明了log_prob也需要被转换的必要性,下面说下怎么转换:

SAC论文中有提到具体的细节:

这里对这两个公式做进一步详细的解释:

首先需要知道密度函数的换底公式:

设一只随机变量X的分布函数为 \( F_X(x) \) 和密度函数 \( p_X(x) \),设 \( Y=g(X) \),其中函数 \( g(·) \) 是严格单调可导的函数,其反函数为 \( h(·) \),则Y的密度函数为:

\[ P_Y(y) = p_X(h(y))|h'(y)| \]

以下为公式的证明过程:

注:\( F_Y(y) \) 是概率分布函数,\( p_Y(y) \) 是概率密度函数,概率分布是概率密度的积分,或者说概率密度是概率分布的导数。

现在将 \( h(·) = tanh(·) \) 代入上式就可以得到论文中的公式(20)了。

由于action可能不是一个数而是一个数组,所以论文中以雅可比行列式的方式简化计算,但这并不是关键。公式(21)就是对公式(20)的两边同时求log,并将右边使用log的性质进行展开。式中的 \(1-tanh^2(u_i)\)就是tanh函数的导数,也就是 \( \frac{da}{du} \)

我开始有个疑问,就是既然log_prob是依据 tanh(act) 计算出来的,那么为什么不能直接使用分布去计算 tanh(act) 的log_prob 呢,即 dist.log_prob(tanh(act)),其中dist是policy返回的概率密度对象,而要使用论文中的公式(21)进行计算呢?实际去计算的话你会发现两者并不相等,还是之前那个想法,tanh(act)中的act才是从dist这个分布中sample出来的,而tanh(act)是对这个分布下多个act的合并,它应该对应多个log_prob,但如果你直接将tanh(act)带入dist计算log_prob则只是计算的一个action对应的log_prob,直觉上这是有问题的。

这部分参考:

https://blog.csdn.net/bingfeiqiji/article/details/81908948

https://stats.stackexchange.com/questions/239588/derivation-of-change-of-variables-of-a-probability-density-function

http://rail.eecs.berkeley.edu/deeprlcourse-fa18/static/homeworks/hw5b.pdf

https://zhuanlan.zhihu.com/p/138021330

https://blog.csdn.net/qq_35200479/article/details/84502844

其他

研究stable-baselines3的代码的过程怎么说呢,五味陈杂。

我不知道是不是我功底还不够,还是说他代码写得本身就很烂,只能说,如果直接拿来用,真的非常好用,但是如果你想去看懂他的逻辑甚至于自己修改代码调试,是真的非常痛苦。代码的抽象程度非常高,导致逻辑非常错综复杂,追踪源码的时候,跳来跳去,代码之间的联系也非常紧密,你也说不上他是高耦合的实现还是高内聚的实现。

我个人比较推崇纯函数的概念,但SB3貌似完全没这方面的想法,经常就是,你追踪一个方法,进去之后突然来一大堆 self.xxx 的,然后你才意识到你可能错过了某些东西,比如下面这段代码:

调了两个实例方法,可以看到第二个被调用函数的第一个参数是第一个函数的返回值,但你知道第二个参数是哪来的吗,然后通过调试你发现,第二个参数貌似也是第一个函数调用设置的,只是没有作为返回值返回,继续追踪代码,你会发现他调用了一个 get_actions() 的方法,该方法调用了一个 sample() 的方法,这个方法是一个抽象方法,所以你必须在代码运行起来后进行断点追踪

进入到实现类中才终于发现端倪,哦,原来在这,然后你就会想,如果我子类没有给这个成员变量赋值,是不是后面的代码就行不通了。

但这个过程中,你已经在多个代码文件中反复横跳了多次,以至于你都不确定此时的self和之前那个self,到底是不是同一个self。于是你陷入了沉思,你会后悔当初应该多花点时间学习设计模式,你会怀疑自己到底是不是一个合格的软件开发,进而反思自己到底适不适合做这一行

很多时候我就在想,在一个新的模型被构思起来之后,研究员实现之后发现与自己设想的差太多了,然后他就想方设法加上各种各样的trick,最终模型生效了,皆大欢喜。但是这些trick是不是模型的一部分呢,说是吧,从理论上讲它们只是起到了催化剂的作用,说不是吧,没有它们又不行。

我还想,到底还有多少好的idea都是因为一些trick没解决导致被丢弃了。

从上面的SB3的SAC实现来看,他的这个log_ent_coef其实对熵是有短路效果的,也就是说,可能模型能够被训练,这个熵的作用并没有理论上的那么大,但它却是SAC的核心。

Leave a Comment