0

0

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100

碧海醫心

碧海醫心

发布时间:2024-10-12 14:17:07

|

825人浏览过

|

来源于机器之心

转载

随着 AI 模型的参数量越来越大,对算力的需求也水涨船高。

比如最近,Llama-3.1 登上了最强开源大模型的宝座,但超大杯 405B 版本的内存就高达 900 多 GB,这对算力构成了更加苛刻的挑战。

如何降低算力的使用成本和使用门槛,已经成为许多公司寻求突破的关键。Felafax 就是其中的一家创业公司,致力于简化 AI 训练集群的搭建流程。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
                       Nikhil Sonti 和 Nikhin Sonti 创立了 Felafax,他们的口号是在构建开源 AI 平台,为下一代 AI 硬件服务,将机器学习的训练成本降低 30%。

与英伟达相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性价比,按每美元计算,其性能表现更为出色。

最近,Felafax 的联合创始人 Nikhil Sonti 发布了一篇博客,详细分享了如何通过 8 张 AMD MI300X GPU 和 JAX 微调 LLaMA 3.1 405B 模型的方法,所有代码现已开源。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

微调大模型,amd mi300x就够了!跟着这篇博客微调llama 3.1 405b,效果媲美h100

Github 链接:https://github.com/felafax/felafax

本站对博客内容进行了不改变原意的编译、整理,以下是博客内容:

JAX 尤其适合非英伟达硬件

JAX 是一个强大的机器学习库,结合了类似 NumPy 的 API、自动微分功能以及 Google 的 XLA 编译器。它在模型并行化方面提供了优秀的 API,因此非常适合像 LLaMA 3.1 405B 这样的超大模型训练。

在使用 AMD 硬件时,JAX 有几个明显的优势:

  • 多硬件并行支持:JAX 采用 XLA(加速线性代数)编译器,将计算编译为硬件无关的中间表示(HLO),这意味着同样的 JAX 代码无需修改便可高效运行在不同硬件后端,包括 AMD GPU。
  • 独立于底层硬件:XLA 编译器的优化策略是通用的,不针对某个特定的硬件平台。这使得任何支持 XLA 的硬件设备(如 CPU、GPU、TPU)都能受益于这些优化,获得更好的性能表现。
  • 极高的适应性:从 NVIDIA 转移到 AMD(或其他硬件)时,JAX 只需做极少的代码改动。而相较之下,PyTorch 与英伟达的 CUDA 生态系统紧密耦合,迁移过程相对复杂。

因此,JAX 成为了我们在非英伟达硬件上的最佳选择。

拉取 Docker 镜像:
docker pull rocm/jax:latest
启动 Docker 容器:
# Pull the Docker Image:docker pull rocm/jax:latest # Start the Docker Container:docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest# Verify the Installation: python3 -c 'import jax; print(jax.devices())'
验证安装
python3 -c 'import jax; print (jax.devices ())'
训练使用了一个配备了 8 张 AMD MI300x GPU 的 AMD 节点。每张 MI300x 拥有 192GB 的 HBM3 内存,性能表现与最新的英伟达 H100 GPU 相比非常出色。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
                              与英伟达 H100 的比较,来源:TensorWave

训练 LLaMA 405B:性能与可扩展性

使用 JAX,可以成功地在 AMD GPU 上训练 LLaMA 405B 模型。我们使用 LoRA 微调,将所有模型权重和 LoRA 参数都设为 bfloat16,LoRA rank 设为 8,LoRA alpha 设为 16:

  • 模型大小:LLaMA 模型的权重占用了约 800GB 的显存。
  • LoRA 权重 + 优化器状态:大约占用了 400GB 的显存。
  • 显存总使用量:占总显存的 77%,约 1200GB。
  • 限制:由于 405B 模型的规模过大,batch 大小和序列长度的空间有限,使用的 batch size 为 16,序列长度为 64。
  • JIT 编译:由于空间限制,无法运行 JIT 编译版本;它可能需要比急切模式稍多的空间。
  • 训练速度:使用 JAX 急切模式,约为 35 tokens / 秒。
  • 内存效率:稳定在约 70% 左右。
  • 扩展性:在 8 张 GPU 上,使用 JAX 的扩展性接近线性。

由于硬件和显存的限制,我们无法运行 JIT 编译版本的 405B 模型,整个训练过程是在 JAX 的急切模式下执行的,因此还有很大的进步空间。 

下图中显示了在一次微调训练步骤中,8 张 GPU 的显存利用率和 rocm-smi 输出:

GPU 利用率:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
显存利用率:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
rocm-smi 输出:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100

训练设置 

将 LLaMA 3.1 从 PyTorch 移植到 JAX 
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
此前,Nikhil Sonti 分享过如何将 LLaMA 3.1 从 PyTorch 移植到 JAX。他指出,目前 90% 的大型语言模型(LLM)都运行在 NVIDIA GPU 上,但实际上还有一些同样强大且性价比更高的替代方案。例如,在 Google TPU 上训练和部署 Llama 3.1 的成本比 NVIDIA GPU 低约 30%。

然而,支持非 NVIDIA 硬件的开发工具较为匮乏。Sonti 最初尝试使用 PyTorch XLA 在 TPU 上训练 Llama 3.1,但过程并不顺利。XLA 与 PyTorch 的集成不够完善,缺少一些关键的库(如 bitsandbytes 无法正常运行),同时还遇到了一些难以解决的 HuggingFace 错误。

为此,他决定调整策略,将 Llama 3.1 从 PyTorch 移植到 JAX,成功解决了这些问题。Sonti 还录制了详细的教程视频,并开源了所有代码:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100

扣子编程
扣子编程

扣子推出的AI编程开发工具

下载
  • 方法演示:https://dub.sh/felafax-demo
  • 代码仓库:https://github.com/felafax/felafax

加载模型,并把模型参数分片
 
处理像 LLaMA 405B 这样的超大模型,需要在多个设备之间高效地进行参数分片。以下是如何通过 JAX 实现这一点的。

在 JAX 中进行参数分片

为了将巨大的 LLaMA 405B 模型高效地分布到 8 张 AMD GPU 上,需要使用 JAX 的设备网格(device mesh)功能。

部署代码:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69

JAX 的设备网格可以帮助我们把可用的设备组织成一个网格,让我们可以指定如何把模型的参数和计算分配到不同的 GPU 上。

在本文的设置中,需要创建一个形状为(1, 8, 1)的网格,并将轴分别命名为数据并行(dp)、全分片数据并行(fsdp)和模型并行(mp)。然后,为模型的每个张量定义特定的分片规则,指定这些维度如何沿着这些网格轴进行分片。
DEVICES = jax.devices () DEVICE_COUNT = len (DEVICES) DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1)) MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
可视化分片

可以使用以下代码来可视化分片结果,从而方便地验证分片规则是否按预期应用。
jax.debug.visualize_array_sharding 
分片规则

模型不同组件的分片规则如下所示:

  • 参数如何分片:

参数要在 8 个 GPU 之间分配。例如,LM head(lm_head/kernel)张量有两个轴,按照 PS ("fsdp", "mp") 进行分片。在本例中是 8 和 1,因此可以看到该张量在第一个轴上沿着 8 个 GPU 被拆分。

  • Non-Replicated 参数:

没有任何分片规范的参数会在所有设备上进行复制。例如,层归一化(attention_norm/kernel 和 ffn_norm/kernel)没有设置分片规范,是 PS (None)。

应用分片函数
 
在加载模型时,使用以下分片函数逐步对模型权重进行分片:
def make_shard_and_gather_fns (partition_specs):def make_shard_fn (partition_spec):out_sharding = NamedSharding (mesh, partition_spec)def shard_fn (tensor):return jax.device_put (tensor, out_sharding).block_until_ready ()return shard_fnshard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)return shard_fns# Create shard functions based on partitioning rulesshard_fns = make_shard_and_gather_fns (partitioning_rules)

这使得我们能够将每个参数放置在指定的设备上,并按照设定的分片进行处理。

分片训练 Batch

最初,训练 Batch 是正常创建的,但在输入模型之前,需要按照下面的代码在 GPU 上进行分片:
train_batch = jax.device_put ( train_batch,NamedSharding (self.mesh, PS ("dp", "fsdp")))

在这里,我们指定训练 Batch 应该在 "dp" 和 "fsdp" 轴上进行分片,在本例中分别对应于被分成 1 和 8 份,如果把结果可视化出来,如下所示:

分片前:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
在调用  jax.device_put 之后:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
加入 LoRA

LoRA 通过将权重更新分解为低秩矩阵,减少了可训练参数的数量,这对于微调大型模型特别有效。以下是在 AMD GPU 上微调 Llama 3.1-405 的 LoRA 的要点:

  • 将 LoRA 参数(lora_a 和 lora_b)与主模型参数分开。
  • 使用 jax.lax.stop_gradient (kernel) 来防止对主模型权重的更新。
  • 使用 lax.dot_general 进行快速、精确控制的矩阵运算。
  • LoRA 输出在添加到主输出之前会被缩放为 (self.lora_alpha/self.lora_rank)。

LoRADense 层

在此设定一个自定义的 LoRADense 层,该层集成了 LoRA 参数:
class LoRADense (nn.Module):features: intlora_rank: int = 8lora_alpha: float = 16.0@nn.compactdef __call__(self, inputs: Any) -> Any:# Original kernel parameter (frozen)kernel = self.param ('kernel', ...)y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)# LoRA parameters (trainable)lora_a = self.variable ('lora_params', 'lora_a', ..., ...)lora_b = self.variable ('lora_params', 'lora_b', ..., ...)# Compute LoRA outputlora_output = lax.dot_general (inputs, lora_a.value, ...)lora_output = lax.dot_general (lora_output, lora_b.value, ...)# Combine original output with LoRA modificationsy += (self.lora_alpha/self.lora_rank) * lora_outputreturn y.astype (self.dtype)

分片 LoRA 参数

为了高效地在设备之间分配 LoRA 参数,我们也通过 JAX 设定了分片规则,这确保了 LoRA 参数与主模型参数的分片一致,优化了内存使用和计算效率。
LoRA A matrices (lora_a)

LoRA A 矩阵(lora_a)

  • 分片规则:PS ("fsdp", "mp")
  • 可视化结果:如下图所示,lora_a 参数被分片为 (8, 1),这意味着第一个轴在 8 个设备上进行分片("fsdp" 轴),而第二个轴未进行分片。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
LoRA B 矩阵(lora_b)

  • 分片规则:PS ("mp", "fsdp")
  • 可视化结果:如下图所示,lora_b 参数被分片为 (1, 8),这意味着第二个轴在 8 个设备上进行分片(fsdp 轴),而第一个轴未进行分片。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
这种分片策略优化了参数的分配,减少了通信开销,并在训练过程中增强了并行性。它确保每个设备仅持有一部分 LoRA 参数,使得大模型如 LLaMA 405B 的高效扩展成为可能。

仅更新 LoRA 参数 

为了优化训练,在微调 LLaMA 405B 模型,只计算 LoRA 参数的梯度,保持主模型参数不变。这个方法减少了内存使用,并加速了训练,因为只更新较少的参数。可以移步 GitHub 仓库,查看实现细节。

在训练过程中,每一步都涉及将一批输入数据通过模型进行处理。由于只有 LoRA 参数是可训练的,因此模型的预测和计算的损失仅依赖于这些参数,然后对 LoRA 参数进行反向传播。只更新这些参数简化了训练过程,使得在多个 GPU 上高效微调像 LLaMA 405B 这样的大型模型成为可能。

更多研究细节,请参考原博客。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
github中文官网入口 github中文版官网网页进入
github中文官网入口 github中文版官网网页进入

github中文官网入口https://docs.github.com/zh/get-started,GitHub 是一种基于云的平台,可在其中存储、共享并与他人一起编写代码。 通过将代码存储在GitHub 上的“存储库”中,你可以: “展示或共享”你的工作。 持续“跟踪和管理”对代码的更改。

875

2026.01.21

k8s和docker区别
k8s和docker区别

k8s和docker区别有抽象层次不同、管理范围不同、功能不同、应用程序生命周期管理不同、缩放能力不同、高可用性等等区别。本专题为大家提供k8s和docker区别相关的各种文章、以及下载和课程。

257

2023.07.24

docker进入容器的方法有哪些
docker进入容器的方法有哪些

docker进入容器的方法:1. Docker exec;2. Docker attach;3. Docker run --interactive --tty;4. Docker ps -a;5. 使用 Docker Compose。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

499

2024.04.08

docker容器无法访问外部网络怎么办
docker容器无法访问外部网络怎么办

docker 容器无法访问外部网络的原因和解决方法:配置 nat 端口映射以将容器端口映射到主机端口。根据主机兼容性选择正确的网络驱动(如 host 或 overlay)。允许容器端口通过主机的防火墙。配置容器的正确 dns 服务器。选择正确的容器网络模式。排除主机网络问题,如防火墙或连接问题。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

401

2024.04.08

docker镜像有什么用
docker镜像有什么用

docker 镜像是预构建的软件组件,用途广泛,包括:应用程序部署:简化部署,提高移植性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

438

2024.04.08

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

433

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

http与https有哪些区别
http与https有哪些区别

http与https的区别:1、协议安全性;2、连接方式;3、证书管理;4、连接状态;5、端口号;6、资源消耗;7、兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

2081

2024.08.16

俄罗斯Yandex引擎入口
俄罗斯Yandex引擎入口

2026年俄罗斯Yandex搜索引擎最新入口汇总,涵盖免登录、多语言支持、无广告视频播放及本地化服务等核心功能。阅读专题下面的文章了解更多详细内容。

158

2026.01.28

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.3万人学习

Python 教程
Python 教程

共137课时 | 7.7万人学习

Java 教程
Java 教程

共578课时 | 52.6万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号