6700万参数比肩万亿巨兽GPT-4!微软MIT等联手破解Transformer推理密码

  新智元报道

  编辑:桃子乔杨

  来自微软、MIT 等机构的学者提出了一种创新的训练范式,攻破了大模型的推理缺陷。他们通过因果模型构建数据集,直接教模型学习公理,结果只有 67M 参数的微型 Transformer 竟能媲美 GPT-4 的推理能力。

  「因果推理」绝对是当前 GenAI 热潮下的小众领域,但是它有一个大佬级的坚定支持者——Yann LeCun。

  他在推特上的日常操作之一,就是炮轰 Sora 等生成模型,并为自己坚信的因果推理领域摇旗呐喊。

  甚至,早在 2019 年 VentureBeat 的采访中,他就表达过这一观点:我们需要在深度学习模型中引入事件的因果关系,才能增强泛化能力,减少训练数据使用。

  对于当前最流行的模型架构 Transformer,我们能教它因果推理吗?

  最近,来自微软 MIT 等机构的研究人员提出了一种训练大模型新范式——公理框架(Axiomatic Framework)。

  论文中,作者从头开始训练了 6700 万参数的模型,仅使用了简单的因果链作为训练数据。

  令人惊讶的是,在推断复杂图表中的因果关系时,67M 模型的表现超越了十亿级参数 LLM,甚至可以与 GPT-4 相媲美。

  论文地址:https://arxiv.org/abs/2407.07612v1

  微软 MIT 等团队最新方法的提出,是受到了图灵奖得主 Judea Pearl 启发。

  Pearl 曾提出了结构化因果规则中的因果无关性公理,即直接通过符号化公理示例来教 Transformer 模型学习被动数据(passive data)。

  这种方法不同于传统机器学习模型,使用由公理推导出的数据。

  正如结果所示,通过公理训练,研究证明了 Transformer 模型可以学习因果,从而推断因果关系,并从相关性中识别因果性。

  这暗示了,像 GPT-4 等大模型的训练,可以通过网络数据中的带噪声的公理化示例学习因果知识,而无需进行干预实验。

  网友称赞道,「研究者的观点非常耐人寻味,因果推理一直是 LLM 的致命弱点,进一步发展这一领域,势在必行」。

  「这类研究可能是通向半 AGI 的一条途径」。

  研究背景

  因果推理(causal reasoning)是一种推理过程,遵守有特定因果性的预定义公理或规则。

  图灵奖得主 Judea Pearl 曾通过如下的「因果关系阶梯」(ladder of causation)定义了可能的因果推理类型。

  通常因果推理所用的公理或规则并不会被直接引入,模型学习的只是数据。公理或规则作为归纳偏差被纳入模型,比如通过正则化、模型架构或变量选择等方式。

  而这篇论文想要探讨的,就是模型能否从被动的符号演示中直接学习公理或规则。作者将这种方法称为「公理化训练」(axiomatic training)。

  假设因果公理都可以以如下形式表示: <前提,假设,结果> ,其中结果只有「是」和「否」两种形式。

  这基本类似于亚里士多德提出的「三段论」格式,比如 Judeal Pearl 书中提出的「碰撞公理」(collider axiom)就可以表示为:

前提:∐, ⟂̸⟂, ⟂̸⟂ 假设:A是否导致C? 结论:是

  这只是单个公理的表示,那么如何表达一个复杂系统中多个公理的组合呢?甚至,我们能用有限数量的公理表达任意因果模型吗?

  此处,论文引用了 Judea Pearl 和 David Galles 在 1997 年发表的一项研究,他们证明了,对于给定的稳定概率因果模型,都存在一组有限公理,可以充分表征对应的有向因果图。

  因果模型M=(X,U,F)被定义为内部变量X、外部变量U和一组结构方程F的集合,结构方程描述了变量X和U之间的因果关系。

  模型M的另一种等效表示方式就是有向图G,用有向边 Vi⭢Vj 表示两个节点 Vi 和 Vj 之间的因果关系。

  所谓的「稳定概率」(stable probabilistic)因果模型,是指他们对模型作出的稳定性假设,指M中所有的不相关性(X ↛ YZ)都是稳定的,写作:

  在稳定性假设下,Galles 和 Pearl 共描述了 6 个公理,而这篇论文主要关注传递性公理。对于稳定概率的因果模型,给定系统中的变量X、Y、Z,传递性公理可以写作:

  将上述表达式通过取反进一步简化,可以写出其含有因果相关性的版本:

  其中表达式左侧即为前提,右侧即为假设。

  这样的公理可以派生出数千个合成的符号表达式,从而用于向 Transformer 模型「教授」特定公理。

  公理化训练

  训练数据

  上述含有前提和假设的公理能映射到「是」或「否」的标签,一条训练数据就可以表示为{(P,H,L)}的元组形式。

  给定一个真实的因果图,就可以通过应用传递性公理(一次或多次),枚举出所有可能的N个元组{(P,H,L)},从而构建出数据集D。

  比如,因果图中包含 X1⭢X2⭢X3⭢…⭢Xn 这样的链拓扑时,一个可能的前提是 X1⭢X2∧X2⭢X3,相应的假设 X1⭢X3 的标签为「是」,而另一个假设 X3⭢X1 标签就为「否」。

  值得注意的是,论文中为了表达的清晰性,使用了数学语言进行描述,但实际上用于训练的数据集只包含自然语言。

  比如,上面例子中的前提应该表达为「X1 导致 X2,且 X2 导致 X3」。

  数据扰动:泛化的关键

  之前有研究表明,以「扰动」(perturbation)形式增加训练数据的可变性与多样性,有助于提升模型的泛化能力。

  因此,作者在不同层次上对训练数据引入结构化扰动,以最大化数据集分布的多样性。

  1)节点名称:传递链上每个节点的名称都由1~3 个字母/数字组成,长度和使用的特定字符是随机生成的。

  2)因果图拓扑结构:主要包含两种类型

  - 顺序结构(sequential):所有的因果边方向都是从后向前,共同形成一个典型的「传递链」,比如X⭢Y⭢Z这种形式

  - 随机翻转(random flipping):给定一个顺序结构的传递链,对其中一些边进行随机翻转,从而引入复杂性。比如X⭢Y⭢Z可以被修改为X⭢Y⭠Z。

  随机翻转可以在单一方向的链中添加分叉结构(X⭠Y⭢Z,fork)和碰撞结构(X⭢Y⭠Z,collider),它们是任何有向因果图的基本构建块,有助于提升模型进行跨结构泛化的能力。

  3)链长度:训练集中加入了长度不等的链,包含3~6 节点。

  损失函数

  论文没有采用训练 Transformer 模型常用的 next token 预测损失,而是根据给定数据集中每个元组的真实标签进行定义,表示为:

  位置编码

  除了训练数据和损失函数之外,另一个重要因素是位置编码的选择。

  之前有研究表明,位置编码机制对 Transformer 的序列长度泛化能力有明显影响,但不同的研究似乎得出了互相矛盾的结果。

  因此,作者在研究中分别尝试了不同的方法,包括可学习位置编码(LPE)、正弦位置编码(SPE)和无位置编码(NoPE)。

  训练和评估的整体流程如图 1 所示,Transformer 模型在顺序链和带有随机翻转的链上训练,长度为3~6 个节点。

  之后,训练过的模型在具有>6 个节点的更复杂结构上进行评估,其中节点平均的出度(out-degree)和入度(in-degree)都更大,序列更长,且引入了分支、反转(reversal)等复杂变化。

  实现细节:架构、分词器和训练过程

  具体来说,研究人员基于 GPT-2 的架构,训练了一个拥有 6700 万参数的解码器模型。

  该模型有 12 个注意力层、8 个注意力头,以及 512 个嵌入维度。

  值得一提的是,67M 模型是在各种训练数据集上,从头开始训练的。为了理解位置编码(PE)的影响,他们考虑了正弦位置编码(SPE)、可学习位置编码(LPE)以及不使用位置编码(NoPE)三种情况。

  所有模型都使用 AdamW 优化器进行训练,学习率为 1e-4,训练 100 个 epoch。

  由于训练数据集遵循特定结构,研究人员还开发了一个自定义分词器(custom tokenizer)。

  字母数字节点名称在字符级别进行分词,而像「causes」、「cause」、「Does」、「Yes」「No」这样的特殊术语则在词级别进行分词。

  简言之,字符级分词用于字母数字节点名称,词级分词用于特殊术语。

  这种方法可以避免在测试时,出现词汇表外(OOV)token,因为测试集中的字母数字节点名称可能与训练集中的不同。

  采用这种方法后,6700 万参数 Transformer 模型的词汇表大小为 69。

  实验结果

  复杂因果场景的泛化

  研究人员首先展示了,通过公理化训练的 Transformer 模型在泛化到更大、更复杂的因果图方面的表现,并将其与预训练的大模型进行了比较。

  序列长度泛化

  表 1 展示了不同模型在评估训练过程中,未见过的更长因果链时的准确率。

  在基线预训练语言模型中,GPT-4 在标准和随机翻转的因果链上都取得了最高的准确率。

  令人惊讶的是,尽管 TS2(NoPE)模型在训练过程中从未见过更长的序列,但它的表现能够与万亿参数规模的 GPT-4 模型相媲美。

  虽然训练时只用到了长度为3~6 个节点的因果链,但序列长度为7~13 时,TS2(NoPE)在标准和随机翻转的链上,获得了比 GPT-4 更高或相当的准确率。

  对于序列长度为 14-15 的情况下,其准确率有所下降(标准链为 0.85,随机翻转链为 0.78),但仍然显著高于 Gemini-Pro 、Phi-3 模型。

  需要注意的是,随机预测会得到 50% 的准确率,这表明通过公理化训练的 TS2(NoPE)模型,能够将其推理能力泛化到更长的序列上。

  节点名称转变

  对于在 TS2 数据集上训练的模型,研究人员还评估了其对变量名称变化的泛化能力(图 3)。

  结果发现,TS2(NoPE)对节点名称的变化很稳健,在引入新的、更长的名称时仍能保持较高的准确率。它还保持了对新节点名称较长序列的通用性,其表现与 GPT-4 相似。

  因果序列顺序

  与长度和节点名称的变化不同,反转(reversal)以及分支(branching)操作改变了因果结构,因此能更好地评估模型是否学习到了对结构的准确表示。

  在表 2b 中,TS2(NoPE)在长度不超过 8 的因果链上,获得的准确率高于 Gemini Pro、Phi-3。长度为 9 时,TS2(NoPE)的准确率为 0.73,与 Gemini Pro(0.74)相当。

  在表 2a 中,研究者还观察到对完全反转序列进行评估的类似模式。

  在这项任务中,公理训练模型 TS2(NoPE)在限制链长度为3-6 时,表现优于 GPT-4。特别是,其准确率(长度为 6 的链为 0.94)大大高于 Gemini Pro 和 Phi-3(分别为 0.62 和 0.69)。

  分支(Branching)

  分支可能是最有挑战性的任务,因为它引入了在训练期间未见的新结构。

  虽然 GPT-4 在图大小不断增大的情况下获得了最佳准确率,但 TS2(NoPE)模型在除一个节点外的所有图大小上,都比 Gemini Pro 获得了更高的准确率。

  即使在有 12 个节点和 1.4 个分支因子的图形上进行评估,TS2(NoPE)模型也能获得 70% 的准确率,明显优于随机模型(50%)。

  总结

  在所有评估设置中,公理化训练模型 TS2(NoPE)的性能明显优于随机基线,即使因果链的长度超过其训练数据。

  特别是,模型没有在完全反转的链上进行训练,它的表现也与规模更大的 GPT-4 模型相当(图 2)。

  在其他任务中,它的准确性往往优于或与 Gemini Pro、Phi-3 等十亿参数规模的模型相当。

  这些结果表明,经过公理训练的模型可以从简单因果序列的演示中,学会推理更复杂的因果结构。这表明公理训练在因果图推理方面的潜力。

  其他结果:数据多样性和位置编码的作用

  位置编码的作用

  比较不同位置编码选择的模型性能,研究人员发现没有位置编码的模型在更长的序列(最长到 15 个节点的链)和复杂的、未见过的图结构上都能很好地泛化,尽管它们仅在3-6 个节点的链上进行训练。

  使用正弦位置编码(SPE)和可学习位置编码(LPE)的模型在更长的链上表现也不错,但当节点名称长度增加时表现较差,即使是在节点数较少的链上也是如此(图3)。

  这种使用 SPE 和 LPE 的泛化失败,突出了模型无法处理训练集中序列的微小扰动。

  此外,SPE 在不同的结构维度上表现不佳(如分支)以及基于顺序的设置(shuffling 和反转)。

  可学习的位置编码在长度达 9 的线性链上表现良好,但之后急剧下降。

  总的来说,研究结果扩展了早期关于不使用位置编码(NoPE)有效性的研究,将其应用于理解因果序列的任务,并在测试时泛化到更长的长度和复杂的结构。

  数据扰动的重要性

  除了位置编码外,训练数据中序列的多样性也起着重要作用。

  仅在因果链上,训练的模型可以泛化到较长的链(表 1),但不能泛化到其他 DAG 结构(见图 4 中的翻转,图 2 中的反转,表 3 中的分支)。

  在 TS1 或 TS1 上训练的模型在所有情况下都具有通用性,包括随机翻转、顺序排列和分支;因此突出了通过随机翻转在边水平上纳入可变性的影响。

  不过,在不同任务中,研究发现 TS2 的准确率高于 TS1,即使 TS1 因随机翻转而产生了更多变化。

  这表明,虽然扰动有助于结构泛化,但过度的扰动可能会阻碍结构泛化。

  使用公理训练从相关性推断因果关系

  接下来,作者研究这种能力是否可以转移到其他因果任务上。

  为此,研究人员将公理化训练应用于一个任务,该任务是从观察数据中的相关性陈述推断因果关系。

  如图 5 所示,每个数据实例包括用自然语言描述的 3 到 6 个节点图的相关关系;目标是推断假设的真值,判断任何给定节点之间是否存在直接或间接关系,以及可能存在的碰撞节点和混杂因素。

  这个任务比应用传递性公理要困难得多。

  由于任务的复杂性,结果发现像 Gemini Pro、Phi-3 这样的预训练模型的表现与随机猜测相似(准确率为 52%)。

  虽然 GPT-4 的表现稍好一些,但其性能仍然较低(准确率为 58%)。

  值得注意的是,研究者的小型 Transformer 模型表现优于所有基线模型,准确率达到 64%,比 GPT-4 高出6%。

  通过进一步探索不同的训练设置,公理化训练的 Transformer 模型可能会在这类因果推理任务上得到进一步的优化。

  总的来说,研究人员认为公理化训练是教 Transformer 模型学习因果关系的一种很有前景的方法。

  受 Judea Pearl 愿景的启发,这项工作代表着一个潜在的新科学前沿——因果关系研究和语言模型的交叉点上。

  参考资料:

  https://arxiv.org/abs/2407.07612v1

  https://x.com/AniketVashisht8/status/1811752011399877014