文章目录
- 前言:显存焦虑症
- 一、推理 (Inference) 显存开销
- 二、全量训练 (Full Training) 显存开销
- 三、微调 (LoRA/QLoRA) 显存开销
- 四、终极速查表 (Cheat Sheet)
- 五、避坑指南
前言:显存焦虑症
做大模型(LLM)开发,最令人绝望的报错不是代码逻辑错误,而是那行冰冷的:CUDA Out of Memory (OOM)。
无论是自己部署本地知识库,还是尝试微调一个垂直领域的模型,“显存到底够不够”永远是第一个要计算的问题。
很多同学有一个误区:“7B 的模型文件只有 14GB,为什么我 24G 的 3090 跑训练还是直接爆显存?”
这篇文章将从底层原理出发,详细拆解推理、全量微调、LoRA微调三种场景下的显存计算公式,并附带 7B 和 70B 模型的实战估算。建议收藏备用!
一、推理 (Inference) 显存开销
推理是相对最省显存的环节。显存占用主要由两部分组成:静态的权重和动态的 KV Cache。
1. 模型权重 (Model Weights)
这是“入场券”,模型加载进显存就需要占用的空间。取决于模型的参数量和精度。
通用公式:
权重显存 ≈ 参数量(B) × 精度字节数 \text{权重显存} \approx \text{参数量(B)} \times \text{精度字节数}权重显存≈参数量(B)×精度字节数
- FP16 / BF16 (主流): 每个参数 2 Bytes。
- INT8 量化: 每个参数 1 Byte。
- INT4 量化: 每个参数 0.5 Bytes。
举例 (7B 模型):
- FP16:7 × 2 = 14 GB 7 \times 2 = 14 \text{ GB}7×2=14GB
- INT4:7 × 0.5 = 3.5 GB 7 \times 0.5 = 3.5 \text{ GB}7×0.5=3.5GB
2. KV Cache (隐形杀手)
这是推理时的动态瓶颈。随着Batch Size (并发数)和Sequence Length (上下文长度)的增加,显存线性暴涨。这也是为什么长文本模型推理特别吃显存。
计算公式:
KV Cache = 2 × 层数 × 隐藏层维度 × 序列长度 × Batch Size × 精度字节数 \text{KV Cache} = 2 \times \text{层数} \times \text{隐藏层维度} \times \text{序列长度} \times \text{Batch Size} \times \text{精度字节数}KV Cache=2×层数×隐藏层维度×序列长度×Batch Size×精度字节数
Tips:
Llama-3、Qwen-2 等新模型采用了GQA (Grouped Query Attention)技术,能将 KV Cache 的显存占用降低 4-8 倍,极大缓解了长文本压力。
二、全量训练 (Full Training) 显存开销
训练之所以比推理“贵”那么多,是因为我们需要存储大量的中间状态来支持反向传播。
1. 显存占用的“四大金刚”
在混合精度训练(FP16/BF16)+ AdamW 优化器的标准设定下,显存被以下四部分瓜分:
- 模型权重 (Model Weights): FP16 格式。
- 梯度 (Gradients): 对应每个参数的梯度,FP16 格式。
- 优化器状态 (Optimizer States):显存大户!AdamW 需要存动量(Momentum)和方差(Variance),且为了精度通常用 FP32 存储。
- 占用:约12 Bytes / 参数(包含 FP32 的主权重备份)。
- 激活值 (Activations): 前向传播的中间结果。与 Batch Size 和 序列长度 成正比。
2. 估算公式
训练总显存 ≈ 静态部分 ( 16 × Φ ) + 动态激活值 \text{训练总显存} \approx \text{静态部分}(16 \times \Phi) + \text{动态激活值}训练总显存≈静态部分(16×Φ)+动态激活值
其中Φ \PhiΦ是模型参数量。
- 静态部分: 权重(2) + 梯度(2) + 优化器(12) =16 Bytes / 参数。
- 动态部分: 需预留 20%-30% 显存给激活值(取决于
Context Length)。
残酷的现实:
训练一个 7B 模型,起步就要:7 × 16 = 112 GB 7 \times 16 = 112 \text{ GB}7×16=112GB显存。
单张 A100 (80G) 都跑不动全量微调!必须上多卡或 DeepSpeed Zero-3。
三、微调 (PEFT: LoRA & QLoRA) 显存开销
对于个人开发者和中小企业,PEFT (Parameter-Efficient Fine-Tuning)是唯一的出路。
1. LoRA (Low-Rank Adaptation)
- 原理: 冻结主模型,只训练旁路 Adapter。
- 显存:
- 主模型权重(FP16):2 × Φ 2 \times \Phi2×Φ
- 优化器状态:极小(只针对 Adapter,忽略不计)。
- 激活值:依然很大!因为要做前向传播。
- 估算: 约为推理显存的1.5 倍。
2. QLoRA (Quantized LoRA) —— 省显存的神
- 原理: 主模型用 4-bit (NF4) 加载并冻结。
- 显存:
- 主模型权重(INT4):0.5 × Φ 0.5 \times \Phi0.5×Φ
- LoRA 参数 + 优化器:少许。
- 激活值:通过 Gradient Checkpointing 技术大幅压缩。
- 估算: 7B 模型仅需6-8 GB显存即可微调!
四、终极速查表 (Cheat Sheet)
假设Context Length = 4096,Batch Size = 1(训练时)。
(注:数据为估算值,实际受框架 Overhead 影响可能波动 10-20%)
| 模型规模 | 场景 | 精度/方法 | 显存需求估算 | 推荐硬件 |
|---|---|---|---|---|
| 7B | 推理 | INT4 | ~6 GB | RTX 3060 / 4060 |
| 7B | 推理 | FP16 | ~15 GB | RTX 3090 / 4060 Ti (16G) |
| 7B | 微调 | QLoRA | ~8 GB | RTX 3060 / 2080 Ti |
| 7B | 微调 | LoRA (FP16) | ~24 GB | RTX 3090 / 4090 |
| 7B | 训练 | 全量 (Full) | ~120 GB | 2x A100 (80G) |
| 72B | 推理 | INT4 | ~42 GB | 2x 3090 / 1x A6000 |
| 72B | 推理 | FP16 | ~150 GB | 2x A100 (80G) |
| 72B | 微调 | QLoRA | ~48 GB | 2x 3090 / 2x 4090 |
五、避坑指南
- 一定要开 Gradient Checkpointing: 训练时显存不够,第一时间开这个。它能用“时间换空间”,显存占用通常能直接减半(牺牲 20% 训练速度)。
- Flash Attention 是标配: 尤其是现在的模型窗口越来越大(32k, 128k),不开 Flash Attention 2,显存分分钟爆掉。
- DeepSpeed ZeRO-2 Offload: 只有一张消费级显卡(如 12G 的 3060)想跑大一点的模型?开启 DeepSpeed 的 Offload 功能,把优化器状态踢到 CPU 内存里,能救命。
- Batch Size 的陷阱: 微调时如果显存紧张,把 Batch Size 设为 1,然后把
Gradient Accumulation Steps设大(比如 16 或 32),效果是一样的,但显存占用极低。
觉得有用的话,欢迎点赞收藏,防止下次 OOM 找不到解决办法!🚀