深入浅出了解llm模型微调

作者:timelocked

0.引言

随着人工智能行业的快速发展,以及最近大语言(LLM)模型的兴起,出现了各种各样 的大语言模型,例如gpt,qwen,grok,claude等等。

但是,在面对一些特殊性的应用场景时,这类通用型大预言模型往往表现的不尽如人意。 比如在法律,医疗领域内,需要模型具有强大的专业知识和技能,才能在这些垂直领域内 提供更精准、更有价值的服务。这类通用型大语言模型需要针对特定领域进行微调,以适 应特殊应用场景的需求。

微调(Fine-tuning)是一种让预训练模型适应特定任务或领域的技术。通过在特定领域的数据集上进行额外训练,模型可以获得该领域的专业知识和语言特点,从而提高在特定应用场景下的性能表现。

微调的主要优势包括:

  • 提高特定领域的准确性和专业性
  • 减少模型在专业领域的错误和幻觉
  • 使模型更好地理解和生成领域特定的术语和表达
  • 提升用户体验和满意度

对于法律、医疗等专业领域,微调后的模型能够更准确地理解专业术语、法规条文或医学知识,从而提供更可靠的服务和建议。这种针对性的优化使大语言模型能够在各个垂直领域发挥更大的价值。

并且微调是基于已有的大预言模型二次训练得到的,不光可以继承基础模型的语言能力和已具备的知识,还可以强化模型在特定领域的专业表现。因此,微调成为了将大语言模型应用到特定场景的关键技术,为模型赋予了更强的针对性和实用价值。本文将深入探讨LLM模型微调的原理、方法和实践应用,帮助读者全面了解这一重要技术。

以下介绍一些前置知识,以便于读者可以更好的明白

首先,所有的llm模型都是通过训练得到的,何为训练? 简单来说就是教育模型学习知识,就像是老师教小孩子一样。

那么知识的形式是什么? 我们学习可以通过视频,课本,书籍等,而模型的学习是通过一个叫作数据集的东西,所以数据集其实就是我们人类的课本,是承载模型学习知识的载体。那么实际表现就是一条条特定格式的数据。

模型是如何学会这些知识的呢?

简单来说,模型是通过参数的变化调整来学习的。一旦模型的参数发生改变,就表明模型在进行学习,所以无论是train还是fine本质都是通过外界数据的输入让其内部的参数发生改变。

具体过程如下
输入数据 → 前向传播 → 计算损失 → 反向传播 → 更新参数

本文所有实践微调实验,都会基于qwen3 系列模型。

1.什么是微调(Fine-tuning)?

微调(Fine-tuning)是一种机器学习技术,特指对预训练模型进行二次训练,使其适应特定任务或领域的过程。对于大语言模型(LLM)来说,微调是在已有基础模型上使用特定领域的数据进行进一步训练,从而让模型能够更好地理解和生成与该领域相关的内容。所以,简单来说,微调即是用已经训练好的llm在特定内容的数据集上进行二次训练,让模型在某些方面获得特定的强化。

微调也分为多种微调, 主要包括:全参数微调(Full Parameter Fine-tuning)、参数高效微调(Parameter-EfficientFine-tuning,PEFT)以及 指令微调(Instruction Tuning)。全参数微调是最为直接的方法,但需要较大的计算资源;参数高效微调(如LoRA、P-Tuning等)则通过仅更新部分参数来降低资源消耗;而指令微调则专注于使模型能够更好地理解和执行人类指令。

本文主要介绍lora微调,受限于现有设备,以及训练的速度,只能如此。所有的微调的原理其实基本都差异不大,无非是使用的架构、微调策略、学习率的区别。个人认为,要依据具体的业务场景来选择微调的方法。

2. 如何进行微调呢?(qlora微调qwen3例)

2.1 微调方式

废话不多说,本文以实践为主,直接来实战吧

这里以qwen3 的微调为例,主要的过程原理如图1所示

                                                                          图1

采用lora微调,冻结一部分底层模型权重,来进行训练(微调),如图2

                                                                      图2

r是lora微调的关键,r是LoRA中最关键的超参数,控制着低秩分解的秩大小,直接决定了模型的表达能力和效率平衡。

简单来说,r就代表了lora可以动的参数大小,r的值越大,代表微调的参数越多,反之越小,动的参数越小。以下是常见的r值对照表

微调中最关键的数值就是r,一般需要小运行几次,来找到最佳的r值是多少。

对比传统的微调,lora微调在尽力不损失原有模型能力的情况下,大幅度减少计算量与参数量。

 

2.2 微调实验准备工作

a.实验服务器配置

本实验采用了一台ubuntu服务器,具体配置如下
cpu:i510400f

gpu: rtx3080 20g 、1050ti 4g

RAM: 32g ddr4 2666

ROM:ssd 1t

b.实验用到的基础模型以及数据集

模型: llm 模型 qwen3 1.7b q4量化版本

数据集:来自于huggingface :

https://huggingface.co/datasets/ystemsrx/Erotic_Literature_Collection?not-for-all-audiences=true

主要关于中文色情文学的,用于提升qwen3在色情写作方面的能力

c. 实验具体流程

1. 新建用于实验的conda环境

本次实验完全基于conda虚拟环境进行构建,新建任意名称的虚拟环境。

需要以下的资源

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
accelerate                1.10.1                   pypi_0    pypi
aiohappyeyeballs          2.6.1                    pypi_0    pypi
aiohttp                   3.12.15                  pypi_0    pypi
aiosignal                 1.4.0                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
async-timeout             5.0.1                    pypi_0    pypi
attrs                     25.3.0                   pypi_0    pypi
bitsandbytes              0.47.0                   pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2025.9.9             h06a4308_0  
certifi                   2025.8.3                 pypi_0    pypi
charset-normalizer        3.4.3                    pypi_0    pypi
click                     8.2.1                    pypi_0    pypi
datasets                  4.0.0                    pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
expat                     2.7.1                h6a678d5_0  
filelock                  3.13.1                   pypi_0    pypi
frozenlist                1.7.0                    pypi_0    pypi
fsspec                    2024.6.1                 pypi_0    pypi
gitdb                     4.0.12                   pypi_0    pypi
gitpython                 3.1.45                   pypi_0    pypi
hf-xet                    1.1.9                    pypi_0    pypi
huggingface-hub           0.34.4                   pypi_0    pypi
idna                      3.10                     pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
libxcb                    1.17.0               h9b100fa_0  
libzlib                   1.3.1                hb25bd0a_0  
markupsafe                2.1.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
multidict                 6.6.4                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.5                  h7934f7d_0  
networkx                  3.3                      pypi_0    pypi
numpy                     2.1.2                    pypi_0    pypi
nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
nvidia-cudnn-cu11         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.3.0.86                pypi_0    pypi
nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
nvidia-nccl-cu11          2.21.5                   pypi_0    pypi
nvidia-nvtx-cu11          11.8.86                  pypi_0    pypi
openssl                   3.0.17               h5eee18b_0  
packaging                 25.0                     pypi_0    pypi
pandas                    2.3.2                    pypi_0    pypi
peft                      0.17.1                   pypi_0    pypi
pillow                    11.0.0                   pypi_0    pypi
pip                       25.2               pyhc872135_0  
platformdirs              4.4.0                    pypi_0    pypi
propcache                 0.3.2                    pypi_0    pypi
protobuf                  6.32.1                   pypi_0    pypi
psutil                    7.0.0                    pypi_0    pypi
pthread-stubs             0.3                  h0ce48e5_1  
pyarrow                   21.0.0                   pypi_0    pypi
pydantic                  2.11.7                   pypi_0    pypi
pydantic-core             2.33.2                   pypi_0    pypi
python                    3.10.18              h1a3bd86_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
pytz                      2025.2                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.3                  hc2a1206_0  
regex                     2025.9.1                 pypi_0    pypi
requests                  2.32.5                   pypi_0    pypi
safetensors               0.6.2                    pypi_0    pypi
sentry-sdk                2.37.1                   pypi_0    pypi
setuptools                78.1.1          py310h06a4308_0  
six                       1.17.0                   pypi_0    pypi
smmap                     5.0.2                    pypi_0    pypi
sqlite                    3.50.2               hb25bd0a_1  
sympy                     1.13.3                   pypi_0    pypi
tiktoken                  0.11.0                   pypi_0    pypi
tk                        8.6.15               h54e0aa7_0  
tokenizers                0.22.0                   pypi_0    pypi
torch                     2.7.1+cu118              pypi_0    pypi
torchaudio                2.7.1+cu118              pypi_0    pypi
torchvision               0.22.1+cu118             pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.56.1                   pypi_0    pypi
triton                    3.3.1                    pypi_0    pypi
trl                       0.23.0                   pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
typing-inspection         0.4.1                    pypi_0    pypi
tzdata                    2025.2                   pypi_0    pypi
urllib3                   2.5.0                    pypi_0    pypi
wandb                     0.21.4                   pypi_0    pypi
wheel                     0.45.1          py310h06a4308_0  
xorg-libx11               1.8.12               h9b100fa_1  
xorg-libxau               1.0.12               h9b100fa_0  
xorg-libxdmcp             1.1.5                h9b100fa_0  
xorg-xorgproto            2024.1               h5eee18b_1  
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.6.4                h5eee18b_1  
yarl                      1.20.1                   pypi_0    pypi
zlib                      1.3.1                hb25bd0a_0  

这里提供一个包安装脚本

#!/bin/bash

# Qwen3微调环境安装脚本
# 基于现有qwen3_wiki环境生成

echo "🚀 开始创建Qwen3微调环境..."

# 检查conda是否安装
if ! command -v conda &> /dev/null; then
    echo "❌ Conda未找到,请先安装Anaconda或Miniconda"
    exit 1
fi

# 设置环境名称
ENV_NAME="qwen3_fine_tuning"

# 检查环境是否已存在
if conda env list | grep -q "$ENV_NAME"; then
    echo "⚠️  环境 $ENV_NAME 已存在,是否删除并重新创建? (y/n)"
    read -r response
    if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
        conda env remove -n "$ENV_NAME" -y
        echo "🗑️  已删除旧环境"
    else
        echo "❌ 取消安装"
        exit 1
    fi
fi

# 创建基础conda环境(Python 3.10.18)
echo "📦 创建基础conda环境..."
conda create -n "$ENV_NAME" python=3.10.18 -y

# 激活环境
echo "🔧 激活环境..."
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_NAME"

# 验证环境激活
if [[ "$CONDA_DEFAULT_ENV" != "$ENV_NAME" ]]; then
    echo "❌ 环境激活失败"
    exit 1
fi

echo "✅ 环境已激活: $CONDA_DEFAULT_ENV"

# 升级pip
echo "📈 升级pip..."
pip install --upgrade pip

# 安装PyTorch及CUDA相关包
echo "🔥 安装PyTorch (CUDA 11.8)..."
pip install torch==2.7.1+cu118 torchaudio==2.7.1+cu118 torchvision==0.22.1+cu118 --index-url https://download.pytorch.org/whl/cu118

# 安装核心AI/ML包
echo "🤖 安装核心AI/ML包..."
pip install transformers==4.56.1
pip install datasets==4.0.0
pip install accelerate==1.10.1
pip install peft==0.17.1
pip install trl==0.23.0

# 安装量化和优化包
echo "⚡ 安装量化和优化包..."
pip install bitsandbytes==0.47.0
pip install triton==3.3.1

# 安装数据处理包
echo "📊 安装数据处理包..."
pip install numpy==2.1.2
pip install pandas==2.3.2
pip install pyarrow==21.0.0

# 安装工具包
echo "🛠️ 安装工具包..."
pip install tqdm==4.67.1
pip install click==8.2.1
pip install pyyaml==6.0.2
pip install regex==2025.9.1
pip install tiktoken==0.11.0
pip install tokenizers==0.22.0
pip install safetensors==0.6.2

# 安装网络和API包
echo "🌐 安装网络和API包..."
pip install requests==2.32.5
pip install aiohttp==3.12.15
pip install huggingface-hub==0.34.4

# 安装监控和日志包
echo "📈 安装监控包..."
pip install wandb==0.21.4
pip install sentry-sdk==2.37.1
pip install psutil==7.0.0

# 安装其他依赖包
echo "📚 安装其他依赖包..."
pip install pillow==11.0.0
pip install jinja2==3.1.4
pip install markupsafe==2.1.5
pip install packaging==25.0
pip install filelock==3.13.1
pip install fsspec==2024.6.1
pip install platformdirs==4.4.0
pip install sympy==1.13.3
pip install mpmath==1.3.0
pip install networkx==3.3
pip install typing-extensions==4.12.2

# 安装验证包
echo "📝 安装验证相关包..."
pip install pydantic==2.11.7
pip install pydantic-core==2.33.2
pip install annotated-types==0.7.0

# 安装时间处理包
echo "⏰ 安装时间处理包..."
pip install python-dateutil==2.9.0.post0
pip install pytz==2025.2
pip install tzdata==2025.2

# 安装其他工具包
echo "🔧 安装其他工具包..."
pip install six==1.17.0
pip install urllib3==2.5.0
pip install certifi==2025.8.3
pip install charset-normalizer==3.4.3
pip install idna==3.10
pip install attrs==25.3.0
pip install multiprocess==0.70.16
pip install dill==0.3.8
pip install xxhash==3.5.0
pip install protobuf==6.32.1

# 安装Git相关包
echo "📂 安装Git工具..."
pip install gitpython==3.1.45
pip install gitdb==4.0.12
pip install smmap==5.0.2

# 安装异步相关包
echo "🔄 安装异步处理包..."
pip install aiohappyeyeballs==2.6.1
pip install aiosignal==1.4.0
pip install async-timeout==5.0.1
pip install frozenlist==1.7.0
pip install multidict==6.6.4
pip install propcache==0.3.2
pip install yarl==1.20.1

# 安装HuggingFace相关包
echo "🤗 安装HuggingFace生态包..."
pip install hf-xet==1.1.9
pip install typing-inspection==0.4.1

# 验证安装
echo "✅ 验证安装..."

# 创建验证脚本
cat > verify_installation.py << 'EOF'
#!/usr/bin/env python3
"""验证Qwen3微调环境安装"""

import sys
import importlib

def check_package(package_name, display_name=None):
    """检查包是否正确安装"""
    if display_name is None:
        display_name = package_name
    
    try:
        module = importlib.import_module(package_name)
        version = getattr(module, '__version__', 'unknown')
        print(f"✅ {display_name:20} {version}")
        return True
    except ImportError as e:
        print(f"❌ {display_name:20} 未安装 - {e}")
        return False

def main():
    print("🔍 验证Qwen3微调环境安装")
    print("=" * 50)
    
    # 核心包检查
    core_packages = [
        ('torch', 'PyTorch'),
        ('transformers', 'Transformers'),
        ('datasets', 'Datasets'),
        ('accelerate', 'Accelerate'),
        ('peft', 'PEFT'),
        ('trl', 'TRL'),
        ('bitsandbytes', 'BitsAndBytes'),
    ]
    
    print("\n🔥 核心AI/ML包:")
    core_success = 0
    for pkg, name in core_packages:
        if check_package(pkg, name):
            core_success += 1
    
    # 工具包检查
    tool_packages = [
        ('numpy', 'NumPy'),
        ('pandas', 'Pandas'),
        ('tqdm', 'TQDM'),
        ('wandb', 'Weights & Biases'),
        ('tiktoken', 'TikToken'),
    ]
    
    print("\n🛠️ 工具包:")
    tool_success = 0
    for pkg, name in tool_packages:
        if check_package(pkg, name):
            tool_success += 1
    
    # CUDA检查
    print("\n🚀 CUDA检查:")
    try:
        import torch
        print(f"✅ CUDA可用: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"✅ CUDA版本: {torch.version.cuda}")
            print(f"✅ GPU数量: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f"✅ GPU {i}: {torch.cuda.get_device_name(i)}")
        else:
            print("⚠️  CUDA不可用,将使用CPU")
    except Exception as e:
        print(f"❌ CUDA检查失败: {e}")
    
    # 总结
    print("\n📊 安装总结:")
    print(f"核心包: {core_success}/{len(core_packages)}")
    print(f"工具包: {tool_success}/{len(tool_packages)}")
    
    if core_success == len(core_packages) and tool_success == len(tool_packages):
        print("🎉 环境安装完成!可以开始Qwen3微调了")
        return True
    else:
        print("⚠️  部分包安装失败,请检查错误信息")
        return False

if __name__ == "__main__":
    main()
EOF

# 运行验证
python verify_installation.py

# 创建快速测试脚本
cat > quick_test.py << 'EOF'
#!/usr/bin/env python3
"""快速功能测试"""

print("🧪 快速功能测试")
print("=" * 30)

try:
    # 测试PyTorch
    import torch
    print(f"✅ PyTorch: {torch.__version__}")
    
    # 测试Transformers
    from transformers import AutoTokenizer
    print("✅ Transformers导入成功")
    
    # 测试PEFT
    from peft import LoraConfig
    print("✅ PEFT LoRA配置导入成功")
    
    # 测试基本张量操作
    x = torch.randn(2, 3)
    y = torch.matmul(x, x.T)
    print("✅ PyTorch张量操作正常")
    
    # 测试CUDA(如果可用)
    if torch.cuda.is_available():
        x_cuda = x.cuda()
        print("✅ CUDA张量操作正常")
    
    print("\n🎉 所有测试通过!环境就绪")
    
except Exception as e:
    print(f"❌ 测试失败: {e}")

EOF

python quick_test.py

# 创建环境使用说明
cat > README.md << 'EOF'
# Qwen3微调环境使用说明

## 环境激活
```bash
conda activate qwen3_fine_tuning

核心包版本

  • Python: 3.10.18
  • PyTorch: 2.7.1+cu118
  • Transformers: 4.56.1
  • PEFT: 0.17.1
  • TRL: 0.23.0
  • BitsAndBytes: 0.47.0

快速开始

1. 基础LoRA微调

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# 加载模型
model_name = "Qwen/Qwen3-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# LoRA配置
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.1,
)

# 应用LoRA
model = get_peft_model(model, lora_config)

2. 使用TRL进行SFT

from trl import SFTTrainer
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    num_train_epochs=3,
    logging_steps=10,
    save_steps=500,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=your_dataset,
)

trainer.train()

监控和日志

  • 使用Weights & Biases: wandb login
  • 查看GPU使用: nvidia-smi
  • 监控训练进度: 查看TensorBoard日志

故障排除

  1. CUDA内存不足: 减少batch_size或使用gradient_checkpointing
  2. 包冲突: 重新创建环境
  3. 权限问题: 检查文件夹写入权限

EOF

输出安装完成信息

echo ""
echo "🎉 Qwen3微调环境安装完成!"
echo ""
echo "📝 使用说明:"
echo "1. 激活环境: conda activate $ENV_NAME"
echo "2. 查看README: cat README.md"
echo "3. 验证安装: python verify_installation.py"
echo "4. 快速测试: python quick_test.py"
echo ""
echo "🚀 现在可以开始Qwen3微调了!"


请按照上述的名称与版本来构建环境,注意,如果采用不同的版本可能会出现兼容性的问题。

### 2.构建训练数据集

采用如下脚本进行构建
 

```python
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
import json
import os

"""
使用ystemsrx/Erotic_Literature_Collection
构建本地数据集用于微调Qwen
采用保守的处理方式:
- 保守的上下文使用:避免信息过载
- 统一的格式:所有样本都遵循相同的格式
- 高质量过滤:确保每个样本都有效
- 适度长度:控制token数量,适合快速训练
"""

def load_and_preprocess_data():
    """加载并预处理ystemsrx/Erotic_Literature_Collection数据集 - 保守版"""
    
    print("正在加载Erotic_Literature_Collection数据集...")
    # 从huggingface 上加载数据集
    dataset = load_dataset("ystemsrx/Erotic_Literature_Collection")
    
    print(dataset)
    print(dataset.keys())
    print(f"训练集样本数: {len(dataset['train'])}")
    
    # 加载tokenizer
    model_name = "Qwen/Qwen-1_8B"  # 保持与你的原始代码一致
    print(f"加载tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    
    # 修复token问题
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "<|endoftext|>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    print(f"EOS token: {tokenizer.eos_token}")
    
    # 手动处理数据,避免map函数的问题
    max_length = 512  # 保守的长度设置
    
    def _process_examples(dataset_split):
        processed_data = []
        print(f"开始处理数据...")
        
        for i, example in enumerate(dataset_split):
            if i % 1000 == 0:
                print(f"已处理: {i}/{len(dataset_split)}")
            
            try:
                # 提取文本内容
                text = example.get('text', '') or ""
                
                # 基本验证
                if not text or len(text.strip()) < 50:  # 跳过太短的文本
                    continue
                
                # 清理文本
                text = clean_text(text)
                if not text:
                    continue
                
                # 构建训练格式 - 使用简单的续写格式
                prompt = "请继续这个故事:\n"
                
                # 截取适当长度的文本作为prompt
                if len(text) > 1000:
                    # 取前面一部分作为prompt,后面作为completion
                    split_point = min(300, len(text) // 3)
                    story_prompt = text[:split_point].rsplit('.', 1)[0] + '.'  # 在句子边界截断
                    story_completion = text[len(story_prompt):].strip()
                    
                    if len(story_completion) < 100:  # 确保completion有足够内容
                        continue
                        
                    full_prompt = prompt + story_prompt + "\n\n"
                    full_text = full_prompt + story_completion + tokenizer.eos_token
                else:
                    # 短文本直接作为completion
                    full_text = prompt + text + tokenizer.eos_token
                
                # 编码检查长度
                input_ids = tokenizer.encode(full_text)
                if len(input_ids) > max_length:
                    # 尝试截断
                    truncated_text = tokenizer.decode(input_ids[:max_length-1]) + tokenizer.eos_token
                    input_ids = tokenizer.encode(truncated_text)
                
                # 添加到处理后的数据
                processed_data.append({
                    "text": full_text,
                    "input_ids": input_ids,
                    "token_count": len(input_ids)
                })
                
            except Exception as e:
                print(f"样本 {i} 处理失败: {e}")
                continue
        
        print(f"处理完成! 有效样本数: {len(processed_data)}")
        return processed_data
    
    def clean_text(text):
        """清理文本内容"""
        if not isinstance(text, str):
            return ""
        
        # 基础清理
        text = text.strip()
        
        # 移除过多的换行符和空白
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        text = '\n'.join(lines)
        
        # 移除特殊字符序列
        import re
        text = re.sub(r'\s+', ' ', text)  # 多个空格合并为一个
        text = re.sub(r'\n{3,}', '\n\n', text)  # 多个换行合并
        
        # 长度检查
        if len(text) < 50 or len(text) > 5000:  # 跳过太短或太长的文本
            return ""
        
        return text.strip()
    
    # 处理数据
    processed_data = _process_examples(dataset['train'])
    
    # 创建新的Dataset
    processed_dataset = Dataset.from_list(processed_data)
    
    return {"train": processed_dataset}, tokenizer

def save_processed_data(dataset, tokenizer, output_dir="./processed_erotic_literature"):
    """保存处理后的数据"""
    os.makedirs(output_dir, exist_ok=True)
    
    print(f"保存数据到 {output_dir}...")
    
    # 保存数据集
    dataset['train'].save_to_disk(f"{output_dir}/train_data")
    
    # 保存tokenizer
    tokenizer.save_pretrained(f"{output_dir}/tokenizer")
    
    # 保存数据集信息
    train_data = dataset['train']
    lengths = [item['token_count'] for item in train_data]
    
    info = {
        "total_samples": len(train_data),
        "tokenizer": tokenizer.name_or_path,
        "max_length": 512,
        "avg_length": sum(lengths) / len(lengths),
        "min_length": min(lengths),
        "max_length_actual": max(lengths),
        "format": "story_continuation"
    }
    
    with open(f"{output_dir}/dataset_info.json", "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    
    print("数据保存完成!")
    return info

def load_processed_dataset(dataset_path="./processed_erotic_literature"):
    """加载已处理的数据集"""
    from datasets import load_from_disk
    
    print(f"从 {dataset_path} 加载数据集...")
    
    train_dataset = load_from_disk(f"{dataset_path}/train_data")
    tokenizer = AutoTokenizer.from_pretrained(f"{dataset_path}/tokenizer", trust_remote_code=True)
    
    with open(f"{dataset_path}/dataset_info.json", "r", encoding="utf-8") as f:
        info = json.load(f)
    
    print("数据集加载完成!")
    print(f"数据集信息: {info}")
    
    return train_dataset, tokenizer, info

if __name__ == "__main__":
    try:
        # 处理数据
        dataset, tokenizer = load_and_preprocess_data()
        
        # 保存数据
        print("保存预处理数据...")
        info = save_processed_data(dataset, tokenizer)
        
        print("数据预处理完成!")
        print(f"训练样本数: {info['total_samples']}")
        print("文件保存位置:")
        print("  - 数据集: ./processed_erotic_literature/train_data")
        print("  - Tokenizer: ./processed_erotic_literature/tokenizer")
        
        # 显示统计信息
        print(f"序列长度统计:")
        print(f"  - 平均长度: {info['avg_length']:.1f}")
        print(f"  - 最大长度: {info['max_length_actual']}")
        print(f"  - 最小长度: {info['min_length']}")
        
        # 显示样本
        print("\n样本预览:")
        sample = dataset['train'][0]
        print(f"Token数: {sample['token_count']}")
        print(f"内容预览: {sample['text'][:200]}...")
        
        # 数据集划分建议
        total_samples = len(dataset['train'])
        train_size = int(total_samples * 0.9)
        eval_size = total_samples - train_size
        print(f"\n建议数据划分:")
        print(f"  - 训练集: {train_size} 样本")
        print(f"  - 验证集: {eval_size} 样本")
        
    except Exception as e:
        print(f"处理失败: {e}")
        import traceback
        traceback.print_exc()

ok ,代码有些难懂对吗?别急,让我来讲解下

首先要从huggingface 上下载数据集,需要用到dataset这个包,然后即可在python代码中引入数据

    print("正在加载Erotic_Literature_Collection数据集...")
    # 从huggingface 上加载数据集
    dataset = load_dataset("ystemsrx/Erotic_Literature_Collection")
    
    print(dataset)
    print(dataset.keys())
    print(f"训练集样本数: {len(dataset['train'])}")

这个dataset 是一个字典,可以工具规定的key来获取数据,至于具体的数据结构会写在网页中
例如本次所用到的数据集数据结构就如下图:

image.png

当本地没有数据的时候,会去huggingface的官方下载,并且存储在本地。

加载好了数据,我们还需要让qwen3 能看懂这些数据,因为现有数据集都是自然语言的,qwen3 模型并不能理解,模型的眼中只能理解特定的编码,所以就需要将自然语言进行转译,而tokenizer的作用就是这个,充当一个翻译的角色。每个不同的模型在训练的初期就已经定义好了自己的tokenizer,我这里只需要下载并且引用即可。具体操作如下,需要transformers库中的from_pretrained来进行加载

 # 加载tokenizer
    model_name = "Qwen/Qwen-1_8B"  
    print(f"加载tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

加载完毕了,这里需要注意下,由于qwen3 公开的tokenizer 并没有没有预配置标准的EOS和PAD token,但是实际上qwen 使用了 "<|endoftext|>" 作为结束标记,以及确实没有PAD token ,这在批次训练的时候很致命,所以需要修改下:

# 修复token问题
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "<|endoftext|>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

加上eos,指定为<|endoftext|>,然后加上eos_token,这样即可处理多种长度的文本了。


解决了以上问题,我们现在可以来构建我们所需要的真正的数据集了,前面说了,我们的原始数据是一篇篇文章对吧,但是这显然无法直接进行训练,我们想要通过微调增强模型在色情文学中的续写能力,于是这次我们的策略如下:
• 从原始文本中截出一部分作为 prompt(context),剩下部分作为 “completion”(模型要生成的目标)。
• 模型训练时,以 prompt 为输入,以 completion 为真实目标,用交叉熵损失训练模型预测 completion 中的 token。
注意这个过程中我没有进行人工标注标签,而是 从文本自身拆出输入 vs 输出对,模型自己学“续写”语言的任务。

所以本次的训练数据集,不能完全算是无监督的数据集,也不能算是传统意义上的有监督类的训练,介于二者之间吧。

一下是实际的代码:

 max_length = 512  # 保守的长度设置
    
    def _process_examples(dataset_split):
        processed_data = []
        print(f"开始处理数据...")
        
        for i, example in enumerate(dataset_split):
            if i % 1000 == 0:
                print(f"已处理: {i}/{len(dataset_split)}")
            
            try:
                # 提取文本内容
                text = example.get('text', '') or ""
                
                # 基本验证
                if not text or len(text.strip()) < 50:  # 跳过太短的文本
                    continue
                
                # 清理文本
                text = clean_text(text)
                if not text:
                    continue
                
                # 构建训练格式 - 使用简单的续写格式
                prompt = "请继续这个故事:\n"
                
                # 截取适当长度的文本作为prompt
                if len(text) > 1000:
                    # 取前面一部分作为prompt,后面作为completion
                    split_point = min(300, len(text) // 3)
                    story_prompt = text[:split_point].rsplit('.', 1)[0] + '.'  # 在句子边界截断
                    story_completion = text[len(story_prompt):].strip()
                    
                    if len(story_completion) < 100:  # 确保completion有足够内容
                        continue
                        
                    full_prompt = prompt + story_prompt + "\n\n"
                    full_text = full_prompt + story_completion + tokenizer.eos_token
                else:
                    # 短文本直接作为completion
                    full_text = prompt + text + tokenizer.eos_token
                
                # 编码检查长度
                input_ids = tokenizer.encode(full_text)
                if len(input_ids) > max_length:
                    # 尝试截断
                    truncated_text = tokenizer.decode(input_ids[:max_length-1]) + tokenizer.eos_token
                    input_ids = tokenizer.encode(truncated_text)
                
                # 添加到处理后的数据
                processed_data.append({
                    "text": full_text,
                    "input_ids": input_ids,
                    "token_count": len(input_ids)
                })
                
            except Exception as e:
                print(f"样本 {i} 处理失败: {e}")
                continue
        
        print(f"处理完成! 有效样本数: {len(processed_data)}")
        return processed_data
  • 每次的 prompt前缀都一样,提示模型继续续写。
  • 如果文本很长 (> 1000 字符),就划分为两部分:
    • story_prompt:截取前面一部分(取 split_point,这个点是 min(300, len(text)//3),确保不要太少也不要太多),并尝试在一个句号 “.” 处分割,以保证不要中断在句子中间。
    • story_completion:剩下的后部分作为模型要“生成”的目标。
  • 如果 story_completion 太短(<100字符),则跳过这个样本(认为不够生成意义)。
  • full_prompt 是 prompt 前缀 + story_prompt + 两个换行(分割清晰)。
  • full_text 则是 prompt + prompt 部分 + completion 部分 + tokenizer.eos_token(加上结束符号,以标记序列结束)。
  • 如果文本不长,就把整个 text 作为 completion(prompt 后直接接),同样在末尾加 eos_token。
  • 把 full_text 用 tokenizer.encode 编码成 token id 列表 input_ids。
  • 如果编码后长度超过允许的 max_length,就尝试截断:先切掉到 max_length-1,然后再在末尾加上 eos_token,然后重新 encode。
  • 这样保证 input_ids 的长度 ≤ max_length。
  • 最终把这个样本加入结果中,带上文本、编码和长度。