破解ChatGPT惊人耗电!DeepMind新算法训练提效13倍,能耗暴降10倍

  新智元报道

  编辑:桃子乔杨

  ChatGPT 能耗惊人,该怎么解?谷歌 DeepMind 新算法 JEST 问世,让 LLM 训练的迭代次数降低 13 倍,计算量减少 10 倍,或将重塑 AI 未来。

  ChatGPT 早已成为世界耗能大户:一天用掉超 50 万度电,相当于 1.7 万个美国家庭的用电量!

  然而,大模型对能源的吞噬,远不仅如此。

  国际能源署(IEA)预测,从 2022 年到 2026 年,数据中心的用电量将翻一番。

  随着 AI 计算需求的膨胀,还需要用水来冷却计算系统。研究称,微软用水量从 2021 年到 22 年飙升了 34%,ChatGPT 每处理5-50 个提示就会消耗接近半升水。

  针对这种现状,我们有更好的解决策略吗?

  最近,谷歌 DeepMind 研究团队提出了一种加快 AI 训练的新方法——多模态对比学习与联合示例选择(JEST),大大减少了所需的计算资源和时间。

  JEST 以 13 倍更少的迭代次数,以及 10 倍更少的计算量,超越了最先进的模型!

  论文地址:https://arxiv.org/pdf/2406.17711

  预训练的参考模型,已经学习了什么样的数据是有「优质的」或「有用的」。然后通过模型,来引导数据选择那些精心筛选过的小型数据集。

  这一发现揭示了,数据筛选水平可以作为评判 Scaling Law 的一个新维度。

  网友激动表示,「我没想到这么快就会发生。模型能够自主选择训练数据的能力是巨大的,因为它使训练变得显著更容易,你不再需要猜测什么是高质量的训练数据,你有一个能够『理解』什么样的数据对自身学习最有价值的模型」。

  前谷歌、苹果软件工程师称赞道,这项研究非常令人印象深刻。

  从「超级 batch」中筛选数据

  无论是语言、视觉还是多模态模型,数据质量是预训练性能的重要驱动因素。比如 Phi-3、Gemma 2 等模型的成功让我们看到了,更少、更高质量的数据有可能实现更强大的性能。

  要筛选出高质量的数据,数据管道的建立就成为重要的工作。现有的方法大体可以分为两种:1)手动管理 2)基于模型的数据管理,用正在训练模型的特征选择高质量数据。

  前者成本高昂且难以扩展,后者则有望为多模态 LLM 实现 Scaling Law。

  然而,现有方法忽略了一个事实。

  如果仅在单个数据点的层面进行筛选,就没有考虑到数据集以及 batch 的总体组成。毕竟,训练数据是以 batch 为单位,数据点之间的依赖性不可忽视。

  许多计算机视觉的研究都曾表明,hard negatives(表达空间中相近但标签不同的样本)相比可被平凡解的数据簇,能提供更有效的学习信号。

  那么如何让模型以 batch 为单位筛选数据呢?

  论文提出的 JEST 算法正是要解决这个问题,原理很好理解:就是直接从「超级 batch」中筛选出「子 batch」。

  技术介绍

  用数学语言来描述这个问题,就是从大小为B的「超级 batch」中提取出与学习最相关的子 batch ℬ={,∈[1,…,]}⊂,过滤比率可以写作=1−/。

  之前的优先采样(prioritized sampling)会使用基于模型的评分函数对每个数据点打分,再按比例采样。JEST 则直接对整个子 batch 评分,再按照 batch 级别的分数采样。

  一种最直观的启发式方法就是在现有模型参数 : hard⁢(ℬ)=ℓ⁢(ℬ) 中,直接选择损失值最高的 batch,这种方法可被称之为「硬学习」(hard learner)。

  这种方法具有丢弃琐碎数据的理想属性,已被证明适用于小型、干净的数据集;然而对于较大、较少管理的数据集往往弊大于利,因为它依旧会采样到噪声数据。

  另一种方法常用于多模态,使用具有参数 ∗:^easy⁢(ℬ∗)=−ℓ⁢(ℬ∗) 的参考模型为预训练模型采样数据。但作者依旧否定了这个方案,因为它无法直接反映模型当前的状态,可能过度依赖参考模型的选择,而且不易于扩展。

  最后,论文选择借鉴 ICML 2022 年的一篇论文中提到的方法,将上述两方面的评分结合起来:^learn⁢(ℬ,∗)=hard⁢(ℬ)+^easy⁢(ℬ∗)=ℓ⁢(ℬ)−ℓ⁢(ℬ∗),并将这种启发式方法称为「可学习性评分」(learnability score)。

  其中,batch 上的损失值ℓ⁢(ℬ)是各数据点之和,使用 sigmoid 对比损失函数计算(sigmoid-contrastive loss),因为相比 softmax 对比损失而言,它的扩展性更强。

  由于 batch 上的对比损失可以分解为每个样本的条件损失之和,因此可学习性评分可被分解为单个样本可学习性评分⁢(,∗,ℬ)之和,写作:

  使用的顺序采样方法则受到了 block Gibbs 采样的启发。在第n次迭代、对第B_n个 batch 进行采样时,依据如下概率公式对块{X_k}进行无替换采样:

  将X_k块添加到B_n中来更新当前采样的 batch,直至迭代数n=N时终止。算法的总体流程如下图所示:

  实验中发现,使用迭代数N=16 且每次迭代时独立采样b/N=2048 个样本时,就足以恢复出学习性非常高的 batch。

  可学习性评分中涉及到使用参考模型为数据点打分,之前的方法惯常使用额外的小型模型,但这会增加每次迭代的计算成本,降低总体 FLOP 效率增益。

  因此论文使用了在线模型近似的方法以及效率较高的 FlexiViT 架构,只使用降低分辨率的 32×32 的 patch 来评估「超级 batch」,与全分辨率、patch 大小为 16×16 的方法相比减少了 72% 的 FLOP,以及 67% 的挂钟时间(wall-clock time)。

  此外,论文还提出了进行多分辨率训练的技巧。将每个 batch 随机分成两半,使用不同分辨率编码后再拼接起来,提升了评分过程和训练的效率。

  下图详细描述了全分辨率 JEST 和多分辨率 Flexi-JEST 方法的伪代码实现。

  所有 JEST 实验都在 WebLI 数据集上运行,包含经过宽松过滤的十亿规模的英语图像-文本对,参考模型的训练则使用其中经过高质量过滤 100M 大小的子集(被称为 WebLI-curated)。

  在 WebLI 的基础上,作者还额外从网络上抓取了 6 亿个文本-图像对并经过同样强度的过滤,组成 WebLI-curated++数据集训练参考模型,拓展出 JEST++/FlexiJEST++方法,来探索对数据管理的扩展。

  论文所报告的平均性能包括 4 个多模态规范基准:ImageNet 0-Shot 和 10-Shot 分类以及 COCO 图像到文本和文本到图像的 top-1 检索。

  实验结果

  图 1 中可以看到,使用 JEST 或 FlexiJEST 方法的最明显优势就是效率提升。

  左图中,相比原有的 SigLIP 基线模型,JEST++可以在训练数据量减少 13.1×的情况下达到相同准确率。即使考虑到额外引入的打分成本,也有近 10×的 FLOP 效率提升(中图)。

  右图展现了 JEST++/FlexiJEST++(绿色)与先前方法(灰色)的比较,相比 CLIP、EVA-CLIP 经典模型实现了计算成本和性能的双重提升。

  左图和中图的平均准确率由 8 个下游任务得出,右图性能由 ImageNet 和 COCO 基准测试得出

  产生可学习 batch

  研究人员首先评估了 JEST 在选择可学习 batch 方面的效果。

  为了直观地理解这一方法,作者们先将可学习性矩阵进行可视化,即学习模型和参考模型之间,对 batch 中所有示例对的损失差异。

  JEST 就是按照示例子矩阵的可学习性总和比例进行采样。

  由于矩阵明显非对角关系(图2,左),独立选择显然是次优的。

  经过少量迭代(对应于用N=16 个块填充 batch),作者发现子 batch 的可学习性快速增加,达到了需要数千次迭代的暴力吉布斯采样(Gibbs sampling )所提取 batch 的可学习性(图2,中)。

  对于 0.5、0.8 和 0.9 的过滤比例,他们从大小分别为 65,536、163,840 和 327,680 的超级 batch 中选择 32,768 个示例的子 batch。

  在图 2 右侧,研究者还发现子 batch 的可学习性随着更大的过滤比例而增加。

  总之,JEST 算法是在训练过程中选择高度可学习 batch 的有效,且高效的方法。

  加速多模态学习

  接下来,研究人员使用 JEST 算法选择的可学习 batch,检验训练模型的效果。

  所有实验都使用在 WebLI-curated 上训练的参考模型,这是一个 ViT-B/16 和 Bert-B 图像-文本双编码器,30 亿训练样本,采用 sigmoid 对比损失函数。

  图3(左)显示了在训练过程中多个下游任务(ImageNet 0-Shot/10-Shot 准确率和 COCO 图像到文本/文本到图像检索)的平均性能。

  结果还发现,JEST 显著加速了学习过程。

  在使用 50%、80% 和 90% 的过滤比例时,分别只需 20 亿、10 亿和 6.7 亿训练样本就达到了 30 亿均匀基准的最终性能。

  在更大的过滤比例下,坐着观察到类似于更大 batch size 时的训练不稳定性,需要修改 Adam 优化器(β2 = 0.95)以稳定训练,这表明 JEST 的数据筛选可以被视为增加了有效 batch size。

  在最终性能方面,当过滤 90% 的数据时,JEST 也带来了高达6% 的显著提升(图3,中间,蓝色曲线)。

  值得注意的是,这种 scaling 行为这种性能提升在独立样本选择方法中,并没有观察到。(图3,中间,橙色曲线)。

  最后,研究者还评估 JEST 是否也改善了,除可学习性之外的其他优先标准。

  图 3 右侧显示了使用 easy-reference 优先选择的模型在不同过滤比例下的性能。

  与基于可学习性的优先选择一致,JEST 仍优于独立样本选择,特别是在高过滤比例下(在这种情况下,独立样本选择导致性能下降)。

  优先选择具有最高损失的数据产生了较小的收益,并且随着过滤更多数据而更快地退化(图 10)。

  由于基于可学习性的 JEST 产生了最佳的 scaling 行为,研究人员在后续实验中保留了这一标准。

  多分辨率训练和在线 batch 选择之间的协同效应

  随着数据 batch 中被过滤的比例增加,基于可学习性评分的 JEST 变得更加高效。

  然而,评分的成本会带来显著的提升:过滤超级 batch 80% 的数据会导致每次迭代的浮点运算量是 IID 训练的 4 倍,或者在缓存参考模型得分时是 2.3 倍。

  尽管 JEST 在训练迭代次数方面(以下简称「训练效率」)显著提高了效率,但额外的评分浮点运算降低了其相对于 IID 基准的计算效率(图1,左 vs 右)。

  因此,作者还研究了一种计算效率更高的变体,称为 Flexi-JEST,它使用多分辨率训练和低分辨率评分,将总开销降低到仅比基准高 10%(图4,左)。

  这些近似方法对性能有什么影响?

  正如预期的那样,Flexi-JEST 的每次迭代性能相对于 JEST 有所下降,但仍然比 IID 有显著的加速(图1,左;图4,中)。

  然而,考虑到总浮点运算量的减少,每次迭代性能的下降是非常有利的:最好的 Flexi-JEST 模型与 40B Siglip 运行产生相同的平均性能,但浮点运算量减少了 9.9 倍,比全分辨率 JEST 少 2 倍(图1,右;图4,中)。

  这些实验表明了多分辨率训练和联合示例选择之间的协同效应,前者为加速后者提供了高效和准确的评分能力。

  实验结果,还指出了数据策划策略的帕累托前沿(pareto front)。

  如果以计算为代价来最大化训练速度或训练效率,全分辨率 JEST 方法相对于可比的 IID 训练运行,可以产生高达 13 倍的加速。

  实现强大数据质量引导

  可学习性评分的核心是,一个在人类选择的小型、精心筛选的数据集上,训练的参考模型。

  JEST 的性能如何随不同的筛选策略(在质量和数量之间权衡)而变化?

  此外,JEST 训练的改进是否与参考模型的性能相关,还是这些指标是分离的?

  理解质量与数量的权衡

  研究人员探索了三种规模的数据筛选,每种都是原始 WebLI 数据集的一个子集:

  - 弱筛选(十亿级规模):使用图像-文本对齐(ITA)过滤器。

  - 中度筛选(3 亿级规模):使用 ITA 过滤器或文本质量(TQ)过滤器。

  - 强筛选(1 亿级规模):结合使用 TQ、ITA 和额外的图像质量(aesthetic)过滤器。

  在整个过程中,作者将这个强筛选子集称为「WebLI-curated」。

  然后,他们在这四个 WebLI 子集上,各训练 10 个 epoch 的标准 SigLIP 编码器,并将它们用作在全 WebLI 数据集上进行 JEST 训练的参考模型。

  在不同的数据筛选方法中,参考模型的性能和 JEST 的性能似乎是解耦的(甚至可能是反相关的;图5,左)。

  虽然增加筛选(和减少数据集大小)会产生较弱的模型,但当它们被用作 JEST 预训练的参考模型时,却产生了相反的效果:

使用强筛选参考模型的 JEST 获得了 2.7% 的改进,中度筛选获得了 1.5% 的改进,弱筛选获得了 0.3% 的改进。

  扩展数据筛选

  假设参考模型性能与 JEST 性能之间的普遍解耦,可能仅仅是由数据筛选所施加的数据集大小限制造成的。

  为了理解这种效果,研究人员在 WebLI-curated 上训练了 5 个参考模型,同时改变所见的总样本数(从 2.5 亿到 30 亿)。

  在这种情况下,图5(右)显示了改进的参考模型与更好的 JEST 预训练之间存在着显著的相关性。

  这表明「解耦」现象主要可以归因于参考模型因筛选后数据集大小减少而导致的饱和。

  此外,研究人员还注意到,当数据集达到饱和时,图5(右)中的相关性开始崩解,即在 10 个 epoch 或者看到 10 亿个样本之后。

  这些结果表明,JEST 可能会从进一步扩大参考数据集的数据筛选中获益。

  鉴于使用 WebLI-curated++对数据进行扩展整理能显著提高参考模型的性能,作者提出了是否有必要在原始 WebLI 数据集上进行预训练的问题。

  然而,在评估参考模型在不同数据集上的性能时,却发现:虽然它在 2 个下游任务上的性能优于 WebLI 预训练,但在其他 6 个任务上的性能,以及平均性能都明显低于 WebLI 预训练(表 5)。

  与现有数据比较

  最后,论文应用 JEST++ 在公开的 LAION-2B 数据集上进行预训练,删除了其中不安全的图像-文本对,但没有进行其他的预先过滤。

  这个数据规模相比的 SOTA 方法 DBP 减少了4×,但 JEST++ 依旧远远超过了所有之前的离线数据管理方法。

  简化数据管理

  之前提到过,用于预训练的 WebLI-curated 是原始数据集 WebLI 过滤后得到的,以求筛选出高质量的图像-文本对齐的数据。

  如表 3 所示,这种离线数据管理流程对 IID(独立同分布)训练方法的性能至关重要,但 JEST++ 则表现出了对预过滤流程的鲁棒性。即使没有过滤,JEST++的性能也没有出现明显下滑,降低了模型对基础数据集的要求。

  结论和局限性

  总体来说,JEST 方法展现出了「数据质量引导」(data quality bootstrapping)方法的巨大潜力,即使用小规模精选数据集来指导对更大的、未经管理的数据集的学习。

  最近的研究表明,在下游任务未知时,静态数据集的过滤会限制模型性能。这篇论文的结果则表明,相比单独选择样本的方法,在线构建 batch 能提高预训练的效率。

  无论是使用 JEST 参考模型对数据集进行预评分,还是通过可学习性评分来根据模型需求进行动态调整,都可以成为通用基础数据集的更有效率的替代方案。

  论文的最后,作者也提出了该方法的局限性。虽然 JEST 同时实现了性能增益和训练成本降低,但依旧依赖于小型、精心管理的参考数据集,它指定了未经管理的更大数据集中优先考虑的分布。

  因此,未来的工作可以探索一种方法,从指定的下游任务中如何推断出参考数据集的组成和分布。

  参考资料:

  1. https://www.reddit.com/r/singularity/comments/1dw7xnf/google_deepminds_jest_method_can_reduce_ai/
  2. https://decrypt.co/238730/new-ai-training-technique-is-drastically-faster-says-google