对OpenAI的ChatGPT大模型进行微调

8/24/2023 OpenAIChatGPT大模型微调

# 1. 前言

# 1.1 基本介绍

2023 年 8 月 22 日,OpenAI宣布支持对ChatGPT大模型进行微调了。OpenAI的基座模型本身已经非常出色,通过微调,即使样本数量很少也可以获得良好效果,使得其他模型难以竞争。然而,数据安全性的问题仍未解决,用户数据最终会流向OpenAI,这对于安全性要求高的用户来说是个问题,因此训练本地私有化模型仍然有必要。

微调不仅可以提供更高质量的结果,还可以训练更多范例、节省Token和降低请求延迟。GPT模型通过prompt有效使用,而微调则进一步改善了少样本学习,实现了更好的结果。微调的过程包括准备和上传训练数据、训练新的微调模型和使用微调模型,从而节约成本并实现低延迟请求。

ChatGPT大模型支持微调的介绍

# 1.2 哪些模型可以进行微调

目前可以对三个型号进行微调,包括gpt-3.5-turbo-0613(推荐)、babbage-002和davinci-002。

  • 其中,gpt-3.5-turbo被期望成为大多数用户在结果和易用性方面的正确模型,除非用户需要迁移旧的微调模型。
  • 未来也将支持对 GPT-4 进行微调,预计该功能将于今年晚些时候推出。

目前哪些ChatGPT大模型支持微调

# 1.3 什么情况需要微调

微调GPT模型可以使它们更适合特定应用,但需要谨慎投入时间和精力。官方推荐首先尝试通过prompt工程、prompt链(将复杂任务分解为多个提示)和函数调用来获得良好的结果,主要原因包括:

  • 有许多任务,模型最初可能表现不佳,但通过更好的提示可以取得更好的效果,可能不需要微调。
  • 与微调相比,通过提示和其他策略进行迭代的反馈循环要快得多,微调需要创建数据集和运行训练任务。
  • 在仍需要微调的情况下,最初的提示工程工作并未浪费 —— 通常在微调数据中使用良好的提示(或结合提示链/工具使用和微调)时会看到最佳结果。
  • GPT最佳实践指南 (opens new window) 提供了一些在不进行微调的情况下获得更好性能的最有效策略和技巧的背景知识。在playground中快速迭代提示可能会有所帮助。

ChatGPT大模型什么时候适合微调

# 1.4 微调可以处理哪些问题

通过微调可以处理以下场景:

  • 设定风格、语气、格式或其他质量方面的要求
  • 提高产生所需输出的可靠性
  • 纠正无法遵循复杂提示的错误
  • 以特定方式处理许多边缘情况
  • 执行难以用prompt准确表达的新技能或任务

通过微调,可以使用更短的prompt来完成任务,从而减少token数,降低时延与成本。此外,由于GPT-4的效果优越但成本较高,可以使用GPT-4的结果作为训练数据来微调ChatGPT,进一步实现GPT-4的蒸馏。

ChatGPT大模型微调可以解决哪些问题

# 2. 准备微调数据集

确定微调是合适的解决方案后,需要为训练模型准备数据。应创建一组多样化的示例对话,与使用时模型响应的对话相似。

# 2.1 数据格式

微调gpt-3.5-turbo需要准备的数据格式如下:

{
    "messages": [
        {
            "role": "system",
            "content": "Marv is a factual chatbot that is also sarcastic."
        },
        {
            "role": "user",
            "content": "What's the capital of France?"
        },
        {
            "role": "assistant",
            "content": "Paris, as if everyone doesn't know that already."
        }
    ]
}{
    "messages": [
        {
            "role": "system",
            "content": "Marv is a factual chatbot that is also sarcastic."
        },
        {
            "role": "user",
            "content": "Who wrote 'Romeo and Juliet'?"
        },
        {
            "role": "assistant",
            "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"
        }
    ]
}{
    "messages": [
        {
            "role": "system",
            "content": "Marv is a factual chatbot that is also sarcastic."
        },
        {
            "role": "user",
            "content": "How far is the Moon from Earth?"
        },
        {
            "role": "assistant",
            "content": "Around 384,400 kilometers. Give or take a few, like that really matters."
        }
    ]
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

babbage-002和davinci-002需要准备的数据格式如下:

{
    "prompt": "<prompt text>",
    "completion": "<ideal generated text>"
}{
    "prompt": "<prompt text>",
    "completion": "<ideal generated text>"
}{
    "prompt": "<prompt text>",
    "completion": "<ideal generated text>"
}
1
2
3
4
5
6
7
8
9
10

# 2.2 构建指令

构建指令最关键的就是要确保训练数据是最好的,最满足自己当前需求的;要把这部分高质量的数据尽可能都加到训练集中,尤其是自己手头准备的训练样本并不是很多的时候,这样可以尽最大可能让训练出的模型满足自己的需求。

另外就是前面说的如果你想通过缩短prompt来减少inference成本,那么模型就需要更多的训练样本去学习到被“省略”部分的指令。

这里举个例子:假设当前在做一个根据小说内容生成一句评论的场景,目前的要求有(1)评论必须有搞笑风格(2)评论必须带有表情符号(3)评论字数必须在10-20个字(4)评论必须要口语化一些。那么我们调用GPT4的时候prompt就可以这么写:

假设你是一个小说评论生成器,生成的评论必须满足以下要求:
(1)评论必须有搞笑风格
(2)评论必须带有表情符号
(3)评论字数必须在10-20个字
(4)评论必须要口语化一些
小说内容如下:
****
1
2
3
4
5
6
7

通过使用这个指令调用GPT4我们可以很好的得到满足自己需求的response(假设调用ChatGPT不能满足需求),这样我们就有了训练样本啦,然后用这里的样本去微调ChatGPT就可以蒸馏了GPT4啦。

同时看到当前的场景每个样本开头都有这一段相似的指令:

假设你是一个小说评论生成器,生成的评论必须满足以下要求:
(1)评论必须有搞笑风格
(2)评论必须带有表情符号
(3)评论字数必须在10-20个字
(4)评论必须要口语化一些
1
2
3
4
5

如果我们在微调的时候把这段去掉同样可以达到需求的话,那么在inference的时候就可以大大减少token数了(尤其是当指令非常多的时候),极端下我们在微调的时候训练样本的prompt就只是小说内容,response还是GPT4的回复。但是这种情况就需要准备更多的样本,让模型去学习到那些被忽略的隐式“指令”。

# 2.3 需要多少条训练样本

最少需要10条训练样本。OpenAI观察通常使用50-100条样本来微调后,ChatGPT就可以发生明显的变化。

官方建议精心准备50条样本来微调,然后观察模型是否显示出改善的迹象。一般来说可能已经能用了,但即使模型还不满足需求,如果看到已经明显改变的迹象了,那剩下的就是沿着当前准备数据的思路提供更多数据即可。如果没有看到改变的迹象,这可能意味着需要重新考虑如何为模型设置任务或重新构建指令数据。

ChatGPT微调需要多少训练样本

# 2.4 划分训练集与测试集

官方要求划分训练和测试集,这样在微调过程中其会帮统计一些指标变化,方便辅助看训练的效果。

# 2.5 Token限制

每条训练 Token 总数最大为4096,也就是说需要确保构建的样本最好在4000以内,如果超出会被自动截断。

可以使用该官方脚本来计算样本的 Token 数。How to count tokens with tiktoken (opens new window)

计算Token限制的官方脚本

# 2.6 微调成本

估算微调成本,详见官方的定价页面:https://openai.com/pricing (opens new window)

ChatGPT微调成本

# 2.7 检查数据集格式

在准备完数据后,最后可以再检查一下数据格式是否正确,以免启动训练的时候发生错误,为此官方也给了一个检查脚本,可以使用它来查找潜在错误、检查Token计数并估计微调作业的成本。<YOUR_JSON_FILE_HERE>

# We start by importing the required packages

import json
import tiktoken
import numpy as np
from collections import defaultdict

# Next, we specify the data path and open the JSONL file

data_path = "<YOUR_JSON_FILE_HERE>"

# Load dataset
with open(data_path) as f:
    dataset = [json.loads(line) for line in f]

# We can inspect the data quickly by checking the number of examples and the first item

# Initial dataset stats
print("Num examples:", len(dataset))
print("First example:")
for message in dataset[0]["messages"]:
    print(message)

# Now that we have a sense of the data, we need to go through all the different examples and check to make sure the formatting is correct and matches the Chat completions message structure

# Format error checks
format_errors = defaultdict(int)

for ex in dataset:
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1
        continue

    messages = ex.get("messages", None)
    if not messages:
        format_errors["missing_messages_list"] += 1
        continue

    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"] += 1

        if any(k not in ("role", "content", "name") for k in message):
            format_errors["message_unrecognized_key"] += 1

        if message.get("role", None) not in ("system", "user", "assistant"):
            format_errors["unrecognized_role"] += 1

        content = message.get("content", None)
        if not content or not isinstance(content, str):
            format_errors["missing_content"] += 1

    if not any(message.get("role", None) == "assistant" for message in messages):
        format_errors["example_missing_assistant_message"] += 1

if format_errors:
    print("Found errors:")
    for k, v in format_errors.items():
        print(f"{k}: {v}")
else:
    print("No errors found")

# Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.

# Token counting functions
encoding = tiktoken.get_encoding("cl100k_base")

# not exact!
# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens

def num_assistant_tokens_from_messages(messages):
    num_tokens = 0
    for message in messages:
        if message["role"] == "assistant":
            num_tokens += len(encoding.encode(message["content"]))
    return num_tokens

def print_distribution(values, name):
    print(f"\n#### Distribution of {name}:")
    print(f"min / max: {min(values)}, {max(values)}")
    print(f"mean / median: {np.mean(values)}, {np.median(values)}")
    print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

# Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:

# Warnings and tokens counts
n_missing_system = 0
n_missing_user = 0
n_messages = []
convo_lens = []
assistant_message_lens = []

for ex in dataset:
    messages = ex["messages"]
    if not any(message["role"] == "system" for message in messages):
        n_missing_system += 1
    if not any(message["role"] == "user" for message in messages):
        n_missing_user += 1
    n_messages.append(len(messages))
    convo_lens.append(num_tokens_from_messages(messages))
    assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

print("Num examples missing system message:", n_missing_system)
print("Num examples missing user message:", n_missing_user)
print_distribution(n_messages, "num_messages_per_example")
print_distribution(convo_lens, "num_total_tokens_per_example")
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
n_too_long = sum(l > 4096 for l in convo_lens)
print(f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning")

# Pricing and default n_epochs estimate
MAX_TOKENS_PER_EXAMPLE = 4096

MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
TARGET_EPOCHS = 3
MIN_EPOCHS = 1
MAX_EPOCHS = 25

n_epochs = TARGET_EPOCHS
n_train_examples = len(dataset)
if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")
print("See pricing page to estimate total costs")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

验证数据后,需要上传文件才能用于微调作业:

openai.File.create(
  file=open("mydata.jsonl", "rb"),
  purpose='fine-tune'
)
1
2
3
4

# 3. 创建及使用微调模型

# 3.1 创建微调模型

确保数据集的数量和结构正确并上传文件后,下一步是创建微调作业。

使用 OpenAI SDK 开始微调工作:

import os
import openai
openai.api_key = "OPENAI_API_KEY"
openai.FineTuningJob.create(training_file="file-abc123", model="gpt-3.5-turbo")
1
2
3
4

在开启微调后,任务就排队开始训练了,通常需要等几分钟或者几个小时训练,等完成训练后,用户就会收到一份确认邮件了。

除了创建微调作业外,还可以列出现有作业、检索作业状态或取消作业。

# List 10 fine-tuning jobs
openai.FineTuningJob.list(limit=10)

# Retrieve the state of a fine-tune
openai.FineTuningJob.retrieve("ft-abc123")

# Cancel a job
openai.FineTuningJob.cancel("ft-abc123")

# List up to 10 events from a fine-tuning job
openai.FineTuningJob.list_events(id="ft-abc123", limit=10)

# Delete a fine-tuned model (must be an owner of the org the model was created in)
import openai
openai.Model.delete("ft-abc123")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

# 3.2 使用微调模型

训练结束后,就可以通过查看任务细节看到模型名字“fine_tuned_model”,之后就可以调用使用了。如果请求错误,可以再等一会,因为其可能在正在加载。

import os
import openai
openai.api_key = "OPENAI_API_KEY"

completion = openai.ChatCompletion.create(
  model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"}
  ]
)

print(completion.choices[0].message)
1
2
3
4
5
6
7
8
9
10
11
12
13

# 3.3 分析微调模型

官方会提供一些常规的量化指标,比如:training loss、training token accuracy、test loss、test token accuracy等。

另外直接看生成的case可能更直观,可以对比原版的ChatGPT和微调后的ChatGPT的效果。

[1] 迭代训练数据质量

如果模型结果不尽人意,那么首先可以检查当前的训练数据质量是否过关。

  • 收集当前的badcase,然后针对性的加对应的数据来纠正模型
  • 如果模型在语法、风格等方面不满足需要,那需要查看当前的训练样本中是否就已经包含了对应的脏数据,需要剔除或者修正。
  • 平衡数据:假设训练数据中有60%是"I cannot answer this",但是inference的时候只有5%是这种情况,那么需要平衡好这个比例。
  • 确保模型包含所有信息,比如有个需求是希望模型根据用户的个人特征来赞美他们,而训练数据中包含了与之前对话中未提及的特征相关的赞美,那么模型可能会学会产生虚构的信息即幻觉。
  • 确保高的一致性:如果当前需求的训练集是多个人协同创建的,那么模型的性能可能会受到多人之间的一致性和一致性水平的限制。例如,在一个文本提取任务中,如果多人只在提取的片段上达成了70%的一致,那么模型可能天花板就是70%。
  • 训练集确保和官方要求的格式一样。

[2] 迭代训练数据数量

当确保了当前训练集质量后,就可以考虑增加训练集的数量来更进一步提高性能。增加量级通常有助于模型更好地学习任务,特别是处理一些边缘case。预计每当增加一倍量级时,都会有类似客观的改进程度。可以通过以下方式粗略估计通过增加训练数据大小所带来的预期效果:

  • Step1:在当前数据集上进行微调
  • Step2:在当前数据集的一半上进行微调
  • Step3:观察两者之间的质量差距

一般来说,如果必须进行数据量权衡,较少数量的高质量数据通常比大量低质量数据更有效。

[3] 迭代训练超参

在第一次试水训练的时候,官方不建议指定epoch,而是由官方根据样本量大小设置一个默认值,然后训练完看效果后再决定怎么调整epoch:

  • 当需求是偏单一场景的时候(比如分类、实体抽取等等),可以适当尝试增加1-2个epoch
  • 当需求是偏多任务场景的时候或者说发现模型缺乏多样性的时候可以尝试减少1-2个epoch

# 4. ChatGPT大模型微调示例

如下示例代码,我已经在Github上开源了,项目地址:https://github.com/Logistic98/chatgpt-fine-tuning (opens new window)

# 4.1 准备ChatGPT-KEY的付费账号

前提条件:ChatGPT大模型微调需要 OpenAI 的 API-KEY,而且要求这个账号必须是付费过的,新账号的白嫖额度是不能用的(第三方购买的很多廉价账号都是机器批量注册的新账号,都是不能用的),会在创建微调时报错。这是必须的条件,没有就不用往下看了。

OpenAI不允许未付费账号进行ChatGPT大模型微调

ChatGPT KEY国内无法直充,我这里是通过欧易充值的USDT,经过一些币种转换,再通过Depay信用卡充值的ChatGPT KEY,充值过程极为繁琐。

# 4.2 准备并上传微调数据集

# 4.2.1 制作微调数据集

以法律方向的微调为例,我这里只是为了走通流程,偷个懒就不自己制作微调数据集了。

原始数据集从 ChatLaw (opens new window) 项目中进行下载,通过如下脚本,将其转换成 gpt-3.5-turbo 微调所需的数据格式。

make_dataset.py

# -*- coding: utf-8 -*-

import json

# 从demo_data_法律咨询.jsonl文件中读取数据
# 来源:https://github.com/PKU-YuanGroup/ChatLaw/blob/main/data/demo_data_%E6%B3%95%E5%BE%8B%E5%92%A8%E8%AF%A2.jsonl
data = []
with open('./data/demo_data_法律咨询.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        data.append(json.loads(line))

# 转换格式变成 gpt-3.5-turbo 微调所需的数据格式
formatted_data = []
for entry in data:
    meta_instruction = entry["meta_instruction"].replace("你一个名叫ChatLAW,由北京大学团队开发的人工智能助理:", "你一个人工智能法律助理:")
    messages = []
    messages.append({
        "role": "system",
        "content": meta_instruction
    })
    for chat in entry["chat"]:
        messages.append({
            "role": "user",
            "content": chat["咨询者"]
        })
        messages.append({
            "role": "assistant",
            "content": chat["ChatLAW"]
        })
    formatted_data.append({
        "messages": messages
    })

# 将结果写入到fine_tuning.jsonl文件中
with open('./data/fine_tuning.jsonl', 'w', encoding='utf-8') as file:
    for item in formatted_data:
        file.write(json.dumps(item, ensure_ascii=False))
        file.write('\n')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

# 4.2.2 检查微调数据集格式

制作完数据集后,使用官方脚本(我这里把注释和print内容翻译成中文了)校验一下数据集格式是否符合要求。

check_dataset.py

# -*- coding: utf-8 -*-

import json
import tiktoken
import numpy as np
from collections import defaultdict

# 指定数据路径并打开JSONL文件
data_path = "./data/fine_tuning.jsonl"

# 加载数据集
with open(data_path) as f:
    dataset = [json.loads(line) for line in f]

# 通过检查示例数量和第一项来快速查看数据
print("示例数量:", len(dataset))
print("第一个示例:")
for message in dataset[0]["messages"]:
    print(message)

# 我们需要遍历所有不同的示例,确保格式正确,并符合Chat completions消息结构
format_errors = defaultdict(int)

for ex in dataset:
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1
        continue

    messages = ex.get("messages", None)
    if not messages:
        format_errors["missing_messages_list"] += 1
        continue

    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"] += 1

        if any(k not in ("role", "content", "name") for k in message):
            format_errors["message_unrecognized_key"] += 1

        if message.get("role", None) not in ("system", "user", "assistant"):
            format_errors["unrecognized_role"] += 1

        content = message.get("content", None)
        if not content or not isinstance(content, str):
            format_errors["missing_content"] += 1

    if not any(message.get("role", None) == "assistant" for message in messages):
        format_errors["example_missing_assistant_message"] += 1

if format_errors:
    print("发现错误:")
    for k, v in format_errors.items():
        print(f"{k}: {v}")
else:
    print("未发现错误")

# 除了消息的结构,我们还需要确保长度不超过4096个令牌限制

# 计数令牌功能
encoding = tiktoken.get_encoding("cl100k_base")


# 不精确!简化自https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens


def num_assistant_tokens_from_messages(messages):
    num_tokens = 0
    for message in messages:
        if message["role"] == "assistant":
            num_tokens += len(encoding.encode(message["content"]))
    return num_tokens


def print_distribution(values, name):
    print(f"\n#### {name}的分布:")
    print(f"最小值 / 最大值: {min(values)}, {max(values)}")
    print(f"平均值 / 中位数: {np.mean(values)}, {np.median(values)}")
    print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

# 最后,我们可以在创建微调作业之前查看不同格式操作的结果:

# 警告和令牌计数
n_missing_system = 0
n_missing_user = 0
n_messages = []
convo_lens = []
assistant_message_lens = []

for ex in dataset:
    messages = ex["messages"]
    if not any(message["role"] == "system" for message in messages):
        n_missing_system += 1
    if not any(message["role"] == "user" for message in messages):
        n_missing_user += 1
    n_messages.append(len(messages))
    convo_lens.append(num_tokens_from_messages(messages))
    assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

print("缺少系统消息的示例数量:", n_missing_system)
print("缺少用户消息的示例数量:", n_missing_user)
print_distribution(n_messages, "每个示例的消息数量")
print_distribution(convo_lens, "每个示例的总令牌数量")
print_distribution(assistant_message_lens, "每个示例的助理令牌数量")
n_too_long = sum(l > 4096 for l in convo_lens)
print(f"\n{n_too_long}个示例可能超过4096个令牌限制,微调期间将被截断")

# 定价和默认n_epochs估计
MAX_TOKENS_PER_EXAMPLE = 4096
MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
TARGET_EPOCHS = 3
MIN_EPOCHS = 1
MAX_EPOCHS = 25

n_epochs = TARGET_EPOCHS
n_train_examples = len(dataset)
if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)

print(f"数据集包含约{n_billing_tokens_in_dataset}个将在训练期间收费的令牌")
print(f"默认情况下,您将对此数据集进行{n_epochs}个时代的训练")
print(f"默认情况下,您将为约{n_epochs * n_billing_tokens_in_dataset}个令牌收费")
print("请参阅定价页面以估算总成本")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

数据集是符合要求的,输出内容如下:

检查ChatGPT微调数据集格式

# 4.2.3 上传微调数据集

填写你的 OpenAI API-KEY,上传微调数据集,这里需要记录下 training_file.id,下面的微调任务会用到。

upload_dataset.py

# -*- coding: utf-8 -*-

import openai
openai.api_key = "your_openai_api_key"

# 上传训练数据集
training_file = openai.File.create(
file=open("./data/fine_tuning.jsonl", "rb"),
purpose="fine-tune"
)

# file.id要复制下来,下一步开始微调要用
print(training_file.id)
1
2
3
4
5
6
7
8
9
10
11
12
13

# 4.3 创建微调任务并完成微调

# 4.3.1 创建微调任务

上传完微调数据集之后,就可以创建微调任务了。首先执行如下命令更新一下 openai 包,旧版没有FineTuningJob功能。

$ pip3 install --upgrade openai
1

之后填写 OpenAI API-KEY 及上一步得到的 training_file.id,开始微调训练。

# -*- coding: utf-8 -*-

import openai
openai.api_key = "your_openai_api_key"

# 创建微调模型
openai.FineTuningJob.create(training_file="your_training_file_id", model="gpt-3.5-turbo")
1
2
3
4
5
6
7

# 4.3.2 微调过程中查看状态

微调过程中,可以查看作业列表、作业状态、作业事件等信息,并可以随时取消作业。

get_fine_tuning_status.py

# -*- coding: utf-8 -*-

import openai
openai.api_key = "your_openai_api_key"

print("===列出10个微调作业")
print(openai.FineTuningJob.list(limit=10))

print("===检索微调作业的状态")
print(openai.FineTuningJob.retrieve("your_ftjob_id"))

print("===列出最多10个来自微调作业的事件")
print(openai.FineTuningJob.list_events(id="your_ftjob_id", limit=10))

# print("===取消作业")
# print(openai.FineTuningJob.cancel("your_ftjob_id"))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

部分输出内容如下:

===检索微调作业的状态
{
  "object": "fine_tuning.job",
  "id": "ftjob-5XsithSRiJ6mvf24IP9xq7eW",
  "model": "gpt-3.5-turbo-0613",
  "created_at": 1692941148,
  "finished_at": 1692941721,
  "fine_tuned_model": "ft:gpt-3.5-turbo-0613:personal::7rJlWrzp",
  "organization_id": "org-sE6KS2sIIgrV8cmzJYQCkfDA",
  "result_files": [
    "file-x0qLLS90VDNV3Xk3EEhA3iFB"
  ],
  "status": "succeeded",
  "validation_file": null,
  "training_file": "file-LX7MRoSwB7je9yuC4FydIgV5",
  "hyperparameters": {
    "n_epochs": 5
  },
  "trained_tokens": 26100
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

微调训练完成后,OpenAI官方会给你发送邮件通知:

ChatGPT微调完成的邮件通知

# 4.4 使用微调模型

# 4.4.1 在OpenAI Playground上试用

点开邮件通知的 OpenAI Playground 链接,在 USER 处输入问题,点击 Submit 按钮提交,在线预览微调效果。

在OpenAI-Playground上试用

# 4.4.2 使用API在代码里应用

model 可以通过上文“检查微调作业的状态”的输出里获取 fine_tuned_model,也可以从 OpenAI Playground 链接的路径里获取。

use_fine_tuning_model.py

# -*- coding: utf-8 -*-

import openai
openai.api_key = "your_openai_api_key"

completion = openai.ChatCompletion.create(
  model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
  messages=[
    {"role": "user", "content": "如果有人擅自破坏水库闸门,但没有造成重大损失,是否构成决水罪?"}
  ]
)

print(completion.choices[0].message)
1
2
3
4
5
6
7
8
9
10
11
12
13

运行结果:

{
  "role": "assistant",
  "content": "根据《中华人民共和国刑法》第一百一十一条的规定,故意破坏水利设施,罪行轻微的,处三年以下有期徒刑、拘役或者管制。具体来说,破坏水库闸门案件中,如果被告人故意破坏水库闸门,但是没有造成重大损失,属于罪行轻微的情形,构成故意破坏水利设施罪。"
}
1
2
3
4

# 5. 参考资料

[1] OpenAI官方支持ChatGPT进行微调了 from 吃果冻不吐果冻皮 (opens new window)

[2] Fine-tuning from OpenAI官方文档 (opens new window)

[3] GPT best practices from OpenAI官方文档 (opens new window)

[4] 付费定价页面 from OpenAI官方文档 (opens new window)

[5] 固定Prompt测试LLM应用程序效果 from GIthub (opens new window)

Last Updated: 2/14/2024, 1:47:08 PM