0

0

DeepSeek用的GRPO占用大量内存?有人给出了些破解方法

DDD

DDD

发布时间:2025-02-07 18:00:16

|

927人浏览过

|

来源于php中文网

原创

rtx 3080 移动版训练大型语言模型的实用指南

本文旨在指导 GPU 资源受限的开发者如何利用 GRPO (Group Relative Policy Optimization) 训练大型语言模型。DeepSeek-R1 的发布使得 GRPO 成为强化学习训练大型语言模型的热门方法,因为它高效且易于训练。 GRPO 通过利用模型自身生成的训练数据进行迭代改进,目标是最大化生成文本的优势函数,同时保持模型与参考策略的接近性。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

图片

选择合适的模型大小和训练方法(全参数微调或参数高效微调 - PEFT)是训练的关键。本文作者 Greg Schoeninger (Oxen.ai CEO) 使用配备 16GB 显存的 RTX 3080 笔记本电脑进行实验,并分享了其经验。

图片原文链接:https://www.php.cn/link/61d8c968f0a66dcf2b05982bdccb484b}}

作者在使用 trl 库的 GRPO 实现时,遇到了显存不足 (OOM) 错误:

  1. torch.OutOfMemoryError: CUDA out of memory.

  2. Tried to allocate 1.90 GiB. GPU 0 has a total capacity of 15.73 GiB of which 1.28 GiB is free.

  3. Including non-PyTorch memory, this process has 14.43 GiB memory in use. Of the allocated memory 11.82 GiB is allocated by PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

实验结果与内存需求分析

作者进行了一系列实验,测试不同模型大小(5亿到140亿参数)在 GSM8K 数据集上训练前 100 步的峰值内存使用情况,并比较了全参数微调和 PEFT 的内存需求。所有实验均在 Nvidia H100 上完成。

图片

使用的模型包括:

图片

GRPO 对内存需求高的原因在于其内部涉及多个模型(策略模型、参考模型、奖励模型)以及每个查询产生的多个输出。

图片

优化内存使用

8位优化器和梯度检查点技术可以有效减少内存占用。8位优化器更高效地存储优化器状态,而梯度检查点则通过在训练过程中拍摄快照来减少内存使用,虽然会降低训练速度。

代码示例

trl 库简化了 GRPO 的使用。以下代码示例展示了如何使用 trl 训练小型模型:

  1. import torch

  2. from datasets import load_dataset, Dataset

  3. from transformers import AutoTokenizer, AutoModelForCausalLM

  4. from trl import GRPOConfig, GRPOTrainer

  5. import re

  6. SYSTEM_PROMPT = """

  7. Respond in the following format:

  8. ...

  9. ...

  10. """

  11. def extract_hash_answer(text: str) -> str | None:

  12. if "####" not in text:

  13. return None

  14. return text.split("####")[1].strip()

  15. def get_gsm8k_questions(split = "train") -> Dataset:

  16. data = load_dataset('openai/gsm8k', 'main')[split]

  17. data = data.map(lambda x: {

  18. 'prompt': [

  19. {'role': 'system', 'content': SYSTEM_PROMPT},

  20. {'role': 'user', 'content': x['question']}

  21. ],

  22. 'answer': extract_hash_answer(x['answer'])

  23. })

  24. return data

  25. def extract_xml_answer(text: str) -> str:

  26. answer = text.split("")[-1]

  27. answer = answer.split("")[0]

  28. return answer.strip()

  29. def format_reward_func(completions, **kwargs) -> list[float]:

  30. """Reward function that checks if the completion has a specific format."""

  31. pattern = r"^\n.*?\n\n\n.*?\n\n$"

  32. responses = [completion[0]["content"] for completion in completions]

  33. matches = [re.match(pattern, r) for r in responses]

  34. return [0.5 if match else 0.0 for match in matches]

  35. def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:

  36. """Reward function that extracts the answer from the xml tags and compares it to the correct answer."""

    Meku
    Meku

    AI应用和网页开发工具

    下载
  37. responses = [completion[0]['content'] for completion in completions]

  38. extracted_responses = [extract_xml_answer(r) for r in responses]

  39. return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

  40. def main():

  41. dataset = get_gsm8k_questions()

  42. model_name = "meta-llama/Llama-3.2-1B-Instruct"

  43. model = AutoModelForCausalLM.from_pretrained(

  44. model_name,

  45. torch_dtype=torch.bfloat16,

  46. attn_implementation="flash_attention_2",

  47. device_map=None

  48. ).to("cuda")

  49. tokenizer = AutoTokenizer.from_pretrained(model_name)

  50. tokenizer.pad_token = tokenizer.eos_token

  51. training_args = GRPOConfig(

  52. output_dir="output",

  53. learning_rate=5e-6,

  54. adam_beta1=0.9,

  55. adam_beta2=0.99,

  56. weight_decay=0.1,

  57. warmup_ratio=0.1,

  58. lr_scheduler_type='cosine',

  59. logging_steps=1,

  60. bf16=True,

  61. per_device_train_batch_size=1,

  62. gradient_accumulation_steps=4,

  63. num_generations=4,

  64. max_prompt_length=256,

  65. max_completion_length=786,

  66. num_train_epochs=1,

  67. save_steps=100,

  68. save_total_limit=1,

  69. max_grad_norm=0.1,

  70. log_on_each_node=False,

  71. )

  72. trainer = GRPOTrainer(

  73. model=model,

  74. processing_class=tokenizer,

  75. reward_funcs=[

  76. format_reward_func,

  77. accuracy_reward_func

  78. ],

  79. args=training_args,

  80. train_dataset=dataset,

  81. )

  82. trainer.train()

  83. if __name__ == "__main__":

  84. main()

trl 项目地址:https://www.php.cn/link/ccb8dbcf2c004cbbae8858760e4a22fa

超参数调整与VRAM估算

num_generations 超参数会显著影响 VRAM 消耗。建议在内存瓶颈解决前使用 num_generations=4

图片

GitHub 问题讨论:https://www.php.cn/link/3057aa0acb6d937295819f3d94f015e9

其他影响 VRAM 的因素包括 batch_sizegradient_accumulation_stepsmax_prompt_lengthmax_completion_length 和 LoRA 的 target_modules

图片

最后,作者分享了 10 亿参数 Llama 3.2 模型的训练结果,展示了 GRPO 在提高模型准确率方面的潜力。

通过合理的参数设置和优化技术,即使使用资源有限的 RTX 3080 移动版 GPU,也能有效训练大型语言模型。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
github中文官网入口 github中文版官网网页进入
github中文官网入口 github中文版官网网页进入

github中文官网入口https://docs.github.com/zh/get-started,GitHub 是一种基于云的平台,可在其中存储、共享并与他人一起编写代码。 通过将代码存储在GitHub 上的“存储库”中,你可以: “展示或共享”你的工作。 持续“跟踪和管理”对代码的更改。

722

2026.01.21

http与https有哪些区别
http与https有哪些区别

http与https的区别:1、协议安全性;2、连接方式;3、证书管理;4、连接状态;5、端口号;6、资源消耗;7、兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

2068

2024.08.16

Python 自然语言处理(NLP)基础与实战
Python 自然语言处理(NLP)基础与实战

本专题系统讲解 Python 在自然语言处理(NLP)领域的基础方法与实战应用,涵盖文本预处理(分词、去停用词)、词性标注、命名实体识别、关键词提取、情感分析,以及常用 NLP 库(NLTK、spaCy)的核心用法。通过真实文本案例,帮助学习者掌握 使用 Python 进行文本分析与语言数据处理的完整流程,适用于内容分析、舆情监测与智能文本应用场景。

10

2026.01.27

拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

109

2026.01.26

edge浏览器怎样设置主页 edge浏览器自定义设置教程
edge浏览器怎样设置主页 edge浏览器自定义设置教程

在Edge浏览器中设置主页,请依次点击右上角“...”图标 > 设置 > 开始、主页和新建标签页。在“Microsoft Edge 启动时”选择“打开以下页面”,点击“添加新页面”并输入网址。若要使用主页按钮,需在“外观”设置中开启“显示主页按钮”并设定网址。

16

2026.01.26

苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

138

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

7

2026.01.26

windows安全中心怎么关闭 windows安全中心怎么执行操作
windows安全中心怎么关闭 windows安全中心怎么执行操作

关闭Windows安全中心(Windows Defender)可通过系统设置暂时关闭,或使用组策略/注册表永久关闭。最简单的方法是:进入设置 > 隐私和安全性 > Windows安全中心 > 病毒和威胁防护 > 管理设置,将实时保护等选项关闭。

6

2026.01.26

2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】
2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】

铁路12306提供起售时间查询、起售提醒、购票预填、候补购票及误购限时免费退票五项服务,并强调官方渠道唯一性与信息安全。

122

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号