Seq2Seq整理

一开始计划在国庆节结束的时候完成seq2seq的第一阶段的工作,但是没想到一直拖到今天,而且过程很曲折。原因在于seq2seq模型比想象中复杂很多,而且至今依然有两个最关键的问题没有解决:
1、隐藏状态h和c的初始值如何界定?在最早的rnn实现中就抛出了这个问题,但是rnn的模型可以通过循环训练来降低h初始值的影响。但是谷歌模型中并没有强调h和c的初始值的影响问题。实际中可以有三种方案:a)以一个固定的随机值作为每一次训练的起点;b)h和c全0启动;c)将最后一个和第一个h掉头接到一起。直观讲应该方案c会好一些,但是这样会很取巧,也不好看。方案b最公平,但是状态或开关哪一个全0启动都是很麻烦的。(直角坐标系的0点是强极性的很难处理)最后我选择方案a,这一方案公平来说不如b,效果直观说又不如c。




2、decoder向encoder传递隐藏层残差的问题,老实说这个问题在谷歌和keras的代码实现中均没有正面遇到,因为TensorFlow还是Torch这类框架都能自动计算残差,而我希望能够手动计算过去,翻了git上其它几个工程也都没有类似的实现,因为只要是依据框架设计的那么就不会直面这一问题。但限于数学水平欠缺,我在模型中传递的dh和真实值可能有很大出入,而且h和c的残差如何同时传到encoder我也没有思路,所以这一块相当于没怎么做,具体对结果的影响有多大我也不好说。

在上面两个我怀疑的问题的下,第一版按部就班实现的模型几乎没有效果。花了很长时间检查和调试之后,我决定放弃这一方案,因为我对自己推导的数学式子没有信心,并且lstm的参数实在太多按照工程的方式查找异常状态也很困难。主要是很难分辨什么是异常,比如h的梯度一开始特别小(单个维度小于0.01),后面又变得非常大(1万个epouch后达到80+),然后又开始减小,参数可读性太差,想了一些招,但是最终也没有定位到问题。

一筹莫展中,这个工作搁置了两天,昨天晚上睡觉前我突然想到要么干脆不用lstm回归rnn来检查效果,rnn的话只有一个状态h,其残差的传递很清晰,而且rnn的h可读性较高,比较容易观察到参数的可解释性变化。并且转念又想seq2seq其实和rnn差别并不大,rnn的话是一个n对n的网络,seq2seq只不过是一个n对m的网络,即输入输出并不一对一,那么其实我们可以先保留这个目的,然后弱化手段,再从0开始构造。这样一变化,其实我们根本不需要对ABC阶段进行学习,只要有某种手段将ABC的信息翻译为h就行,至于h是什么样子其实是无所谓的,因为运行良好的残差学习能够对h的任何组合方式进行适应。这样一来,ABC这边其实就变成了一个dnn网络(这里有趣的一点是google的训练过程使用了混淆输入词序的trick,这在概念上和弱化rnn一致)。不过当你真正考虑消除掉rnn的序列的时候才会记得rnn的好处——理论上对无限长序列的信息都能建模(虽然工程中几个词以前的东西就很难起作用了),但是这就好比炼金,不试试咋知道行不行呢?我在实际的实现中没有真的就用dnn来压缩ABC,而是将h的初始值设定方案设为c(先保证有效果),然后只在WXYZ阶段进行训练,这样设计下效果立刻就出来了。




实验中我设置xyz是干扰项,固定模式有三种:”abc-de”, “bcd-ea”, “cde-ab”,eos符号是”@”
生成了输入序列是:

1
a b c y z z @ d e @ y y b c d x z @ e a @ z c d e y z y @ a b @ z a b c x @ d e @ x x b c d y @ e a @ x x c d e y y z @ a b @ y y z a b c y y @ d e @ y b c d y @ e a @ c d e x x @ a b @ z x a b c x z @ d e @ x x y b c d @ e a @ y c d e @ a b @ y y x a b c x y @ d e @ z y b c d @ e a @ c d e y z @ a b @ y a b c z z @ d e @ b c d @ e a @ c d e z @ a b @ y z a b c z @ d e @ y y y b c d x @ e a @

测试序列是:

1
a b c @ d e @ b c d @ e a @ c d e @ a b @

另外值得一提的是,训练的中间阶段h的值几乎各个维度都成了1:




这意味着h整个是没起啥作用的。而且在这几天的debug过程中我发现各个参数时不时就会变得特别小或者特别大,一般来说这些参数要么起状态作用,要么起开关作用,无论是哪个极大或极小都不好。

另外当工作到这个阶段之后,我突然体会到早期做流事件处理时候的体会了,都是差不多的场景,采用了差不多的思路,只不过rnn可读性几乎为0。

后面计划开展两个方向:
1、用框架来做一部分对比工作
2、想办法解决向量的极值问题

【github】https://github.com/wangyaomail/word2vec-demo/blob/master/src/main/java/cn/dendarii/Seq2SeqDemo2.java