大模型“取长补短”新思路入选NeurIPS'24,显著优于现有路由方法

  RouterDC 团队投稿

  量子位公众号 QbitAI

  高效组合多个大模型“取长补短”新思路,被顶会 NeurIPS 2024 接收。

  名为 RouterDC,是一种基于双重对比学习的路由架构,具有参数高效性(小于 100M 的参数)和计算高效性(不需要对于 LLM 进行梯度回传)的优势。

  在具有挑战性语言理解、代码生成和数学推理等推理任务实验中,RouterDC 在分布内(+2.76%)和分布外(+1.90%)设定下,都远超于现有的 routing 方法。

  众所周知,LLM 通常在不同数据集上预训练和微调,导致它们在不同任务上的性能强弱不同。

  LLM 路由则是一种组合多个 LLM 的新思路,它通过学习一个路由器(Router)来为每一个请求(query)选择最合适的 LLM。在推理时,LLM 路由只需要调用所选的 LLM 进行推理,使其在保持计算高效性的同时利用多个 LLM 的互补能力。

  RouterDC 这种新方法,包括一个较小的语言模型作为编码器和一系列与候选 LLM 对应的可学习的LLM embeddings

  对于训练数据中的每个 query,首先将候选 LLM 的预测与真实标签进行比较获得表现最好和最差的 LLM,然后构造两个对比损失:

  • sample-LLM 对比损失:使得 query embedding(由编码器提取)与表现最佳的 LLM embeddings 相似,同时与表现最差的 LLM embeddings 不相似。
  • sample-sample 对比损失:提高训练的稳定性,将所有训练 query 聚类成多个组,最大化同组 query 之间的相似性的同时最小化不同组 query 之间的相似性。

  这项研究由来自南方科技大学,香港科技大学的研究团队提出,以下是更为详细的介绍。

  双对比学习实现 Router 训练

  Router 架构

  如图 1 所示,RouterDC 包括一个较小的语言模型(mDeBERTaV3-base)作为编码器ε,和一系列的与候选 LLM 对应的可学习 LLM 嵌入 kT。对于每个 query xi,RouterDC 生成对于T个 LLMs 的选择概率如下:

  其中,sim (·,·)表示 cosine 相似度。

  △图1:RouterDC 方法示意图

  sample-LLM 对比损失

  为了训练 router,研究者将 query 的样本嵌入和在其上表现最好的K+ 个 LLM 对应嵌入拉进,和在其上表现最差的K-个 LLM 对应嵌入拉远。因此,样本-LLM 对比损失可以表示为:

  sample-sample 对比损失

  研究者通过实验发现,在 routing 问题中只使用样本-LLM 对比损失并不稳定,使得相似的 query 可能具有不相似的嵌入。

  为了提升训练的鲁棒性,训练样本被聚类成不同的组,从而在训练中拉近同一个组内的样本,拉远不同组的样本。和样本-LLM 对比损失类似,样本-样本对比损失可以公式化为:

  训练及推理

  最终的优化目标为最小化样本-LLM 对比损失和样本-样本对比损失的结合:

  推理时,每个测试 query 只需要通过训练好的 router 选取概率最大的 LLM,并使用选择的 LLM 对 query 进行回答。

  RouterDC 在训练时不需要任何经过 LLM 的梯度回传,并且在推理时只需要调用进行一次 LLM,同时具有训练和推理的高效性。

  实验效果如何?

  主要结果

  RouterDC 在分布内数据集的测试准确率结果如表 1 所示。可以发现:

  RouterDC 显著好于最优的单个模型,平均具有 3.98% 性能提升。在单个任务的层面,RouterDC 在三个任务上相比表现最优的单个模型取得了准确率的提升,其中 GSM8K 提升了 0.51%,ARC-C 提升了 0.57%,HumanEval 提升了 1.63%。

  和现有路由方法 CosineClassifier 以及 ZOOTER 对比,RouterDC 在所有任务上都具有更好的表现。和 LoraRetriever 对比,RouterDC 具有平均 2.77% 的准确率提升。

  △表1:分布内任务的测试准确率(%)

  为了评估 RouterDC 的泛化能力,表 2 展示了 RouterDC 在三个分布外数据集(PreAlgebra,MBPP,C-EVAL)的测试准确率。

  可以看出,RouterDC 再次达到最高的测试准确率,显著超过表现最佳的单个 LLM(dolphin-2.9-llama3-8b)1.9%。

  △表2:分布外任务的测试准确率(%)

  sample-sample 损失的作用

  为了探究样本-样本损失的作用,图 3 展示了在是否有样本-样本损失的条件下训练和测试准确率曲线。可以看出,RouterDC(w/o Lsample-sample)有明显的震荡现象,而 RouterDC 则稳定得多。

  △图2:RouterDC 在 GSM8K 任务上的训练和测试准确率曲线

  图3(a)可视化了使用 RouterDC(w/o Lsample-sample)提取的训练样本的 TSNE 特征,可以看到,属于不同任务的训练样本粗略地混合在一起。而在结合 Lsample-sample 之后,训练样本有了清晰的聚类结构(如图3(b)所示)。

  △图3:学习到的 router 所提取出训练样本 embedding 的t-SNE 可视化

  RouterDC 具有成本高效性

  由于价格(cost)同样是一个评估 LLM 的重要指标,研究者通过 RouterBench 上的两个任务的实验来格外考虑 cost 的影响。如图 16 所示,RouterDC 相比于 CosineClassifier 和 ZOOTER 更加的成本高效。

  △图4:在 RouterBench 上使用不同的 Cost 获取的测试准确率

  论文地址:https://arxiv.org/abs/2409.19886

  代码地址:https://github.com/shuhao02/RouterDC