GaLore及BAdam实现低显存全量微调

4/27/2024 GaLoreBAdam8-bit Adam Optimizer降低显存

# 1. 前言

# 1.1 背景介绍

LLM 训练通常需要比较大的显存,主要是模型权重和优化器状态。节约显存常见的方法有 LoRA,然而其往往用于微调阶段,或需要满秩热启动,导致预训练依旧需要很大的显存。梯度低秩投影(GaLore)这是一种允许全参数学习的训练策略,但比常见的 LoRA 等方案更省显存。可以减少多达 65.5% 的显存。此方案可以在更省显存的同时基本不影响模型效果,但是训练时间会变得很长。

与 8-bit Adam 结合,8-bit GaLore 可以进一步减少高达 82.5% 的优化器内存和 63.3% 的总训练内存。甚至实现了在 24GB 显存的消费级 GPU(如 NVIDIA RTX 4090)上训练 7B 模型,而无需模型并行,Checkpointing 和 Offload 策略。

# 1.2 GaLore技术概述

GaLore 将低秩投影应用到模型训练的梯度上,可以大幅节约显存占用,为消费级显卡全量微调训练大模型提供了一种可能。

  • 梯度低秩投影(GaLore)是一种全量参数学习的训练策略,但比常见的低秩自适应方法(如LoRA)更节省显存。其关键思想是利用权重矩阵 W 的梯度缓慢变化的低秩结构,而不是试图将权重矩阵本身近似为低秩。
  • 作为一种梯度投影方法,GaLore 与优化器的选择无关,只需两行代码即可轻松插入现有优化器,GaLore目前实现了GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor 三种优化器。

项目地址:https://github.com/jiaweizzhao/GaLore (opens new window)

论文地址:GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection (opens new window)

GaLore

注:目前GaLore仅支持单GPU训练,该技术还处在开发阶段,官方说正式版将会支持多GPU训练。

# 1.3 其他的类似技术

# 1.3.1 BAdam技术

基本介绍:BAdam的核心思想是依次求解块坐标优化子问题。从实现的角度来看,该算法在参数的一小部分(通常是一个变压器层)上运行 Adam 的更新,因此与全参数 Adam 微调相比,需要的显存要少得多。使用 BAdam 只需要对原始代码进行一行修改。

# 1.3.2 8-bit Adam Optimizer技术

基本思想:一种对 optimizer 进行量化的方法,在不修改超参,不影响模型精度的情况下,把 adam / momentum 的状态量量化至 int8,从而缓解训练时的显存压力。原始的 adam 优化器对于每个参数都需要 m 和 v 两个 fp32 的参数,相当于每 1B 的参数都需要 8G 的存储空间,占了整体的很大一部分。所以如果能够把 optimizer state 量化下来,就能适当缓解显存的压力。

# 2. 准备测试环境

# 2.1 租用GPU服务器

实验环境:租用的AutoDL的GPU服务器,NVIDIA RTX 4090D / 24GB,Ubuntu20.04,Python 3.10, CUDA 11.8,数据盘额外扩容了100GB

由于这家的服务器都是境内的,拉取Github代码和HuggingFace模型都会受到墙的干扰,建议配置一下代理。

$ source /etc/network_turbo
1

# 2.2 安装基础环境

安装conda环境

$ curl -O https://repo.anaconda.com/archive/Anaconda3-2019.03-Linux-x86_64.sh   // 从官网下载安装脚本
$ bash Anaconda3-2019.03-Linux-x86_64.sh           // 阅读协议确认安装,安装完成后再输入yes以便不需要手动将Anaconda添加到PATH
$ conda create -n fine_tuning_env python=3.10      // 安装虚拟环境,fine_tuning_env是给虚拟环境起的别名(任意即可)
$ source /root/miniconda3/etc/profile.d/conda.sh   // conda初始化
$ conda activate fine_tuning_env                   // 激活虚拟环境
1
2
3
4
5

安装其他版本的CUDA/cuDNN

$ conda search cudatoolkit
$ conda install cudatoolkit==11.8.0
$ conda list cudatoolkit
$ conda search cudnn --channel nvidia
$ conda install cudnn=8.9.2.26
$ conda list cudnn
1
2
3
4
5
6

注:默认镜像都内置了最原生的CUDA和cuDNN,如果您自己安装了cudatoolkits等,那么一般会默认优先使用conda中安装的cudatoolkits。

# 2.3 下载模型文件

由于全量微调需要大量的计算资源,即便用了GaLore也是用时间来换取的,因此这里使用了参数量较小的 Qwen1.5-0.5B 作为实验模型。

安装huggingface_hub依赖:

$ pip3 install huggingface_hub
1

download_model.py

# -*- coding: utf-8 -*-

import os
from huggingface_hub import snapshot_download

# 模型仓库的标识
repo_id = "Qwen/Qwen1.5-0.5B"

# 下载模型到指定目录
local_dir = "/root/autodl-tmp/Qwen-1.5-0.5B"

# 检查目录是否存在,如果不存在则创建
if not os.path.exists(local_dir):
    os.makedirs(local_dir)

snapshot_download(repo_id=repo_id, local_dir=local_dir)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

# 3. 单卡全量微调

# 3.1 LLaMA-Factory项目

# 3.1.1 LLaMA-Factory基本介绍

LLaMA-Factory 是一个易于使用的大模型微调框架,旨在简化大型语言模型的微调过程,提供了一套完整的工具和接口,使得用户能够轻松地对预训练的模型进行定制化的训练和调整,以适应特定的应用场景。

LLaMA-Factory具备以下特性:

  • 多种模型:LLaMA,LLaVA,Mistral,Mixtral-MoE,Qwen,Yi,Gemma,Baichuan,ChatGLM,Phi。
  • 集成方法:(Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO。
  • 多种精度:32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA,2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8。
  • 先进算法:GaLore,BAdam,DoRA,LongLoRA,LLaMA Pro,Mixture-of-Depths,LoRA+,LoftQ,Agent tuning。
  • 实用技巧:FlashAttention-2,Unsloth,RoPE scaling,NEFTune,rsLoRA。
  • 实验监控:LlamaBoard,TensorBoard,Wandb,MLflow。
  • 极速推理:OpenAI风格的API,Gradio UI和CLI、vLLM支持。

LLaMA-Factory项目已经集成支持了GaLore、BAdam等先进的优化算法。

LLaMA-Factory支持GaLore及BAdam

# 3.1.2 准备微调代码及显存监控脚本

拉取代码并安装依赖:

$ git clone https://github.com/hiyouga/LLaMA-Factory.git
$ cd /root/LLaMA-Factory
$ pip3 install -r requirements.txt
1
2
3

该项目下的 data 目录自带了大量的开源数据集,以下均采用 oaast_sft_zh.json 数据集进行测试。

oaast_sft_zh测试数据集

由于需要监控微调过程的显存占用,这里简单写了个Python脚本去实现。

$ pip3 install nvidia-ml-py matplotlib
1

monitor.py

# -*- coding: utf-8 -*-

import pynvml
import matplotlib.pyplot as plt
import datetime
import time
import signal

# 初始化NVML
pynvml.nvmlInit()
# 获取第一个GPU的句柄
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
memory_usage = []
times = []
running = True

def signal_handler(sig, frame):
    global running
    running = False

# 注册信号处理器,以便于接收到Ctrl+C时能够停止收集数据
signal.signal(signal.SIGINT, signal_handler)

print("开始监控GPU显存使用情况,按Ctrl+C停止...")

# 收集数据,直到接收到停止信号
try:
    while running:
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        memory_used = round(info.used / 1024 ** 2, 2)
        memory_usage.append(memory_used)
        times.append(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        time.sleep(1)
finally:
    pynvml.nvmlShutdown()

# 计算最大值
max_usage = max(memory_usage, default=0)
max_index = memory_usage.index(max_usage) if memory_usage else -1

# 创建图表
plt.figure(figsize=(10, 5))
plt.plot(times, memory_usage, label='Memory Usage (MB)', linestyle='-', color='gray')
# 突出显示开始、最大值、结束的点
highlight_indices = [0, max_index, len(memory_usage) - 1]
highlight_times = [times[i] for i in highlight_indices]
highlight_usages = [memory_usage[i] for i in highlight_indices]
plt.scatter(highlight_times, highlight_usages, color='red', s=100, zorder=5)
# 添加特定时间点的标注
plt.annotate(f'Start: {memory_usage[0]:.2f} MB', (times[0], memory_usage[0]),
             textcoords="offset points", xytext=(0,10), ha='center', va='bottom')
plt.annotate(f'Max: {max_usage:.2f} MB', (times[max_index], max_usage),
             textcoords="offset points", xytext=(0,10), ha='center', va='bottom')
plt.annotate(f'End: {memory_usage[-1]:.2f} MB', (times[-1], memory_usage[-1]),
             textcoords="offset points", xytext=(0,-15), ha='center', va='top')

# 设置x轴仅显示开始、最大、结束的时间点
plt.xticks(highlight_times, rotation=45)
plt.legend()
plt.tight_layout()

# 保存图表
plt.savefig('/root/autodl-tmp/gpu_memory_usage.png')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

# 3.2 三种方式单卡全量微调

# 3.2.1 不使用优化算法单卡全量微调

$ cd /root/LLaMA-Factory/examples/full_multi_gpu
1

这里不使用 DeepSpeed 分布式训练技术,修改 single_node.sh 脚本的内容如下,并执行。

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python3 ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path /root/autodl-tmp/Qwen-1.5-0.5B \
    --dataset oaast_sft_zh \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type full \
    --output_dir /root/autodl-tmp/Qwen-1.5-0.5B/full \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

训练过程中的显存占用:最大显存 11.64GB

不使用优化算法单卡全量微调的显存占用

训练结束的日志信息:训练耗时 6min44s

不使用优化算法单卡全量微调的日志

# 3.2.2 使用GaLore单卡全量微调

$ cd /root/LLaMA-Factory/examples/extras/galore
$ pip3 install galore_torch
1
2

修改 sft.sh 脚本的内容如下,并执行。

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python3 ../../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path /root/autodl-tmp/Qwen-1.5-0.5B \
    --dataset oaast_sft_zh \
    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --use_galore \
    --galore_layerwise \
    --galore_target mlp,self_attn \
    --galore_rank 128 \
    --galore_scale 2.0 \
    --output_dir /root/autodl-tmp/Qwen-1.5-0.5B/galore_full \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --pure_bf16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

训练过程中的显存占用:最大显存 8.24GB

使用GaLore单卡全量微调的显存占用

训练结束的日志信息:训练耗时 7min48s

使用GaLore单卡全量微调的训练日志

# 3.2.3 使用BAdam单卡全量微调

$ cd /root/LLaMA-Factory/examples/extras/badam
$ pip3 install badam
1
2

修改 sft.sh 脚本的内容如下,并执行。

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python3 ../../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path /root/autodl-tmp/Qwen-1.5-0.5B \
    --dataset oaast_sft_zh \
    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --use_badam \
    --badam_switch_mode descending \
    --badam_switch_block_every 50 \
    --badam_verbose 2 \
    --output_dir /root/autodl-tmp/Qwen-1.5-0.5B/badam_full \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --pure_bf16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

训练过程中的显存占用:最大显存 8.30GB

使用BAdam单卡全量微调的显存占用

训练结束的日志信息:训练耗时 2min56s

使用BAdam单卡全量微调的训练日志

# 4. 参考资料

[1] GaLore:梯度低秩投影,消费级显卡训练 LLaMA-7B from 微信公众号 (opens new window)

[2] GaLore:通过梯度低秩投影进行内存高效的 LLM 训练 from Github (opens new window)

[3] BAdam:一种用于大型语言模型的内存高效全参数训练方法 from Github (opens new window)

[4] LLaMA-Factory:大模型微调框架 from Github (opens new window)

Last Updated: 5/2/2024, 11:38:13 AM