新智元报道
编辑:LRS
研究人员提出了一种新的大型语言模型训练方法,通过一次性预测多个未来 tokens 来提高样本效率和模型性能,在代码和自然语言生成任务上均表现出显著优势,且不会增加训练时间,推理速度还能提升至三倍。
当前,大型语言模型,例如 GPT 和 Llama,主要是根据「前文的单词序列」对「下一个 token」进行预测的方式来训练。
但你有没有想过一个问题,为什么不对后文的 tokens 同时进行预测呢?
最近,Meta、巴黎高科路桥大学、巴黎萨克雷大学的研究人员就联合提出了一种新的训练方法,即一次性预测多个未来 tokens,可以提高模型的样本效率。
论文链接:https://arxiv.org/pdf/2404.19737
具体来说,在训练语料库的每一个位置,要求模型使用n个独立的输出头网络来预测紧随其后的n个 token,其中所有输出头都基于同一个模型主干。
研究人员将多 token 预测视作是一种辅助训练任务,实验发现该方法不仅能够提升模型在各种下游任务上的表现,而且不会增加训练时间,对代码生成和自然语言生成任务都是有益的。
随着模型尺寸的增大,该方法的优势变得更加明显,尤其是在进行多 epochs 训练时。
在编程等生成性任务的基准测试中,使用多 token 预测训练的模型的性能提升尤为显著,能够稳定地超过传统单 token 预测模型。
例如,13B 参数的模型在 HumanEval 基准测试中解决问题的能力比同等规模的单 token 模型高出 12%,在 MBPP 基准测试中高出 17%
此外,通过在小型算法任务上的实验,研究人员发现多 token 预测对于提升模型的归纳头(induction heads)和算法推理能力是有益的。
而且,使用多 token 预测训练的模型在推理时速度更快,最高可达三倍,即便是在处理大规模数据批次时也是如此。
多 token 预测
标准语言模型通过执行一个「下一个 token 预测」任务来对大型文本语料库进行学习,任务目标是最小化交叉熵损失,其中模型需要最大化「在给定之前 token 序列历史的条件下,预测下一个 token」的概率。
研究人员将「单 token 预测」任务泛化为「多 token 预测」,在训练预料上的每个位置,模型需要一次性预测未来n个 tokens,交叉熵损失改写为:
为了使问题可解,假设大型语言模型使用一个共享的主干网络来生成观察到的上下文的潜表征z,然后再把该表征送入到n个独立的头网络,以并行的方式预测每一个未来 token
多 token 预测的交叉熵损失可以分解为两部分:在给定 token 序列下的潜表征,以及在该潜表征条件下,预测n个未来 token
在实践中,该架构包括一个共享 Transformer 主干模型,根据上下文词序列来生成潜表征,n个独立的、基于 Transformer 层的输出头,以及一个共享的 unembedding 矩阵。
节省内存
在训练多 token 预测器时,一个关键问题是 GPU 显存占用过多。
在当前的大型语言模型(LLMs)中,词汇表的大小V通常远远大于潜在表示的维度d,因此 logit vectors 就成了 GPU 内存使用的瓶颈。
如果简单地实现多 token 预测器,将所有的 logit vectors 及其梯度都存储在内存中,会导致内存使用量迅速增加,因为每个向量的形状都是 (n, V),这种方式会极大地限制模型可同时处理的批次大小,并增加 GPU 显存的平均使用量。
研究人员提出了一种内存高效的实现方法,通过调整前向传播和反向传播操作的顺序来减少内存使用。
具体来说,在通过共享主干网络 fs 完成前向传播后,模型会按顺序对每个独立的输出头部 fi 执行前向和反向传播,并在主干网络处累积梯度,每个输出头部 fi 的 logit 向量和梯度在计算后就会被释放,无需一直占用内存,直到所有头部的计算完成。
这意味着,除了主干网络的梯度外,不需要长期存储其他任何梯度,从而显著降低了 GPU 内存的使用。
通过这种方式,模型的内存复杂度从O(nV+d)降低到了O(V+d),在不牺牲运行时间的情况下,显著减少了 GPU 的峰值内存使用。
推理阶段 Inference
在推理时,该模型的最基础用法是使用「下一个 token 预测头」(next-token prediction head)进行「基本 next-token 自回归预测」,同时丢弃所有其他头网络。
也可以利用额外的输出头网络进行自推理解码,对从下一个 token 预测头网络的解码进行加速:
1. 区块并行解码(blockwise parallel decoding),一种推理解码的变体方法,可以并行地预测多个 token,而不需要额外的草稿模型;
2. 使用类似美杜莎(Medusa)树注意力机制的推测解码,可以提高解码速度和效率。
实验结果
研究人员总共进行了七个大规模实验来证明多 token 预测损失的有效性。
为了公平对比 next-token 预测器和n-token 预测器,实验中的模型参数量均相同,也就是说,在预测未来头网络中添加n-1 层时,同时也会从共享模型主干中移除n-1 层。
1. 性能随模型尺寸增大而提升
为了研究模型尺寸的影响,研究人员从零开始训练了「六个」模型,尺寸范围覆盖了从 300M 到 13B 参数量,至少使用了 91B tokens 的代码。
从评估结果中可以看到,在 MBPP 和 HumanEval 上的实验表明,在相同的计算量下,使用多 token 预测,可以在固定数据集上获得更好的性能。
研究人员认为,该特性只有在大规模数据、大尺寸模型上才能体现出来,这也可能是多 token 预测一直没有在大型语言模型训练上广泛应用的原因。
2. 更快的推理速度
研究人员使用异构批量大小的 xFormers 实现贪婪自我推测解码(self-speculative decoding),并测量了最佳的4-tokens 预测模型(7B 参数)在补全代码和自然语言数据时的解码速度。
可以看到,该方法在代码生成任务上速度提升了 3.0 倍,文本生成的速度提升了 2.7 倍,在 8 字节预测模型上,推理速度提升了 6.4 倍。
使用多 token 预测进行预训练时,额外的头网络可以比单个 next-token 预测模型的微调更准确,从而让模型充分发挥自推测解码的全部潜力。
3. 用多字节预测来学习全局 pattern
为了展示 next-token 预测任务能够捕捉到局部模式,研究人员采取了极端情况,即字节级分词(byte-level tokenization),通过训练一个 7B 参数的字节级 Transformer 模型来处理 314B 个 byte,大约相当于 116B 个 tokens
8-byte 预测模型与 next-byte 预测相比取得了显著的性能提升,在 MBPP pass@1 上解决了超过 67% 的问题,在 HumanEval pass@1 上解决了 20% 的问题。
因此,多字节预测是一个非常有前景的方法,可以让字节级模型的训练更高效。
自推测解码可以实现 8 字节预测模型的 6 倍速度提升,完全弥补了在推理时「更长字节序列」的成本,甚至比 next-token 预测模型快近两倍。
尽管训练所用的数据量少了 1.7 倍,但 8 字节预测模型的性能仍然能接近基于 token 的模型。
4. 寻找最优的n值
为了更好地理解预测 token 数量的影响,研究人员在 7B 尺寸的模型(训练数据包含了 200B 个代码 token)上进行了全面的消融实验,在不同实验设置中尝试了 n = 1, 2, 4, 6 和8
实验结果显示,使用 4 个未来 token 进行训练时,在 HumanEval 和 MBPP 的所有 pass at 1, 10 和 100 指标上均超越了其他对比模型:MBPP 的改进分别为 +3.8%, +2.1% 和 +3.2%,HumanEval 的改进分别为 +1.2%, +3.7% 和 +4.1%
有趣的是,在 APPS/Intro 上,n = 6 时的性能提升分别为 +0.7%, +3.0% 和 +5.3%
最佳的窗口尺寸很可能取决于输入数据的分布。至于字节级模型,最佳窗口大小在基准测试中更为一致(8 字节)。
5. 多 epochs 训练
在进行机器学习模型训练时,多 tokens 训练方法在处理相同数据集的多个训练周期时,对于预测下一个 token 的任务仍然显示出了优势。
虽然随着训练周期的增加,优势略有下降,但在 MBPP 数据集上的 pass@1 指标上,仍然观察到了 2.4% 的提升;在 HumanEval 数据集上的 pass@100 指标上,提升更是达到了 3.2%
结果表明,即使在多次训练后,多 tokens 训练方法仍然能够带来一定的性能提升。
但对于 APPS/Intro 数据集来说,当训练 token 数量达到 200B 时,使用窗口大小为 4 的训练方法已经不再是最优的选择,可能需要调整窗口大小或采用其他策略来进一步提高模型性能。
6. 微调多 token 预测器
在机器学习领域,预训练模型通过多 token 预测损失函数进行训练,相较于传统的单 token 预测模型,该方法在后续的微调阶段展现出了更好的性能。
研究人员在 CodeContests 数据集上对具有 7B 参数的模型进行了微调测试,将一个能够预测接下来 4 个 token 的模型与基础的单 token 预测模型进行了比较,并尝试了一种将 4 tokens 预测模型去除额外预测头后,按照传统的单 token 预测目标进行微调的设置。
实验结果显示,在 pass@k指标上,无论采用哪种微调方式,4-tokens 预测模型的表现都超过了单 token 预测模型,也表明4-tokens 预测模型在理解任务、解决问题以及产生多样化答案的能力上更为出色。
实验结果还表明,在4-tokens 预测预训练的基础上进行单 token 预测微调,可能是一个综合性能最佳的策略,与先使用辅助任务进行预训练,然后进行特定任务微调的经典机器学习范式相吻合。
7. 在自然语言上的多 token 预测
研究人员训练了参数量为 7B 的模型,并使用了三种不同的预测损失方法:预测 4token、2-token 以及单个 token,并在 6 个标准的自然语言处理(NLP)基准测试中进行了性能评估。
在摘要任务中,使用了 8 个不同的基准测试,并通过 ROUGE 指标来自动评估生成文本的质量,结果显示,2-token 和4-token 的性能都比单 token 预测基线的表现更好。
参考资料: