偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

從零實現(xiàn)一個17M參數(shù)的GPT預(yù)訓練模型

人工智能
今天我們使用開源的的中文數(shù)據(jù)進行模型的預(yù)訓練,下面跟著我的步驟,從零實現(xiàn)你的預(yù)訓練模型。

大家好,我是寫代碼的中年人!

今天我們使用開源的的中文數(shù)據(jù)進行模型的預(yù)訓練,下面跟著我的步驟,從零實現(xiàn)你的預(yù)訓練模型。

本文所有代碼和數(shù)據(jù)資源位置:

https://github.com/ColinAIAPP/MoiraiLM

01、預(yù)訓練模型的概念

預(yù)訓練模型(Pretrained Model)就是一個已經(jīng)在海量數(shù)據(jù)上訓練過的模型,它學會了語言的基本規(guī)律、結(jié)構(gòu)和語義,然后可以拿來做各種下游任務(wù),比如寫作、翻譯、問答、分類、生成代碼等。

那“預(yù)訓練”到底在學什么?以語言模型(LLM)為例:預(yù)訓練階段的任務(wù)通常是預(yù)測下一個詞(token)。

接下來我們就一步一步實現(xiàn)一個17M參數(shù)的預(yù)訓練模型。

02、數(shù)據(jù)準備

構(gòu)建語言模型的第一要義是高質(zhì)量的數(shù)據(jù)源。對于中文任務(wù),選擇維基百科開源中文數(shù)據(jù)集是一個理想起點。這個數(shù)據(jù)集包含數(shù)百萬條中文百科條目,涵蓋歷史、文化、科技等領(lǐng)域,總量約數(shù)GB的純文本數(shù)據(jù)。它開源且免費,可通過維基百科的官方轉(zhuǎn)儲頁面下載最新版本的XML格式文件。

要解壓處理這個文件我們要使用wikiextractor工具進行數(shù)據(jù)解壓。

安裝解壓命令:

pip install wikiextractor

解壓命令:

python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json
zhwiki-20250920-pages-articles-multistream.xml.bz2:為文件名

INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.
INFO: Preprocessed 100000 pages
INFO: Preprocessed 200000 pages
INFO: Preprocessed 300000 pages
INFO: Preprocessed 400000 pages
INFO: Preprocessed 500000 pages
INFO: Preprocessed 600000 pages
INFO: Preprocessed 700000 pages
INFO: Preprocessed 800000 pages
INFO: Preprocessed 900000 pages
INFO: Preprocessed 1000000 pages
INFO: Preprocessed 1100000 pages
INFO: Preprocessed 1200000 pages
INFO: Preprocessed 1300000 pages
INFO: Preprocessed 1400000 pages
INFO: Preprocessed 1500000 pages
INFO: Preprocessed 1600000 pages
INFO: Preprocessed 1700000 pages
INFO: Preprocessed 1800000 pages
INFO: Preprocessed 1900000 pages
INFO: Preprocessed 2000000 pages
INFO: Preprocessed 2100000 pages
INFO: Preprocessed 2200000 pages
INFO: Preprocessed 2300000 pages
INFO: Preprocessed 2400000 pages
INFO: Preprocessed 2500000 pages
INFO: Preprocessed 2600000 pages
INFO: Preprocessed 2700000 pages
INFO: Preprocessed 2800000 pages
INFO: Preprocessed 2900000 pages
INFO: Preprocessed 3000000 pages
INFO: Preprocessed 3100000 pages
INFO: Preprocessed 3200000 pages
INFO: Preprocessed 3300000 pages
INFO: Preprocessed 3400000 pages
INFO: Preprocessed 3500000 pages
INFO: Preprocessed 3600000 pages
INFO: Preprocessed 3700000 pages
INFO: Preprocessed 3800000 pages
INFO: Preprocessed 3900000 pages
INFO: Preprocessed 4000000 pages
INFO: Preprocessed 4100000 pages
INFO: Preprocessed 4200000 pages
INFO: Preprocessed 4300000 pages
INFO: Preprocessed 4400000 pages
INFO: Preprocessed 4500000 pages
INFO: Preprocessed 4600000 pages
INFO: Preprocessed 4700000 pages
INFO: Loaded 1036734 templates in 704.2s
INFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.
INFO: Using 127 extract processes.
INFO: Extracted 100000 articles (1209.6 art/s)
INFO: Extracted 200000 articles (1947.8 art/s)
INFO: Extracted 300000 articles (2325.1 art/s)
INFO: Extracted 400000 articles (3471.3 art/s)
INFO: Extracted 500000 articles (2551.1 art/s)
INFO: Extracted 600000 articles (2239.4 art/s)
INFO: Extracted 700000 articles (2299.3 art/s)
INFO: Extracted 800000 articles (1525.2 art/s)
INFO: Extracted 900000 articles (3256.1 art/s)
INFO: Extracted 1000000 articles (3485.9 art/s)
INFO: Extracted 1100000 articles (3495.0 art/s)
INFO: Extracted 1200000 articles (3330.4 art/s)
INFO: Extracted 1300000 articles (3555.6 art/s)
INFO: Extracted 1400000 articles (3456.3 art/s)
INFO: Extracted 1500000 articles (2476.1 art/s)
INFO: Extracted 1600000 articles (2268.6 art/s)
INFO: Extracted 1700000 articles (2473.5 art/s)
INFO: Extracted 1800000 articles (2305.9 art/s)
INFO: Extracted 1900000 articles (2263.9 art/s)
INFO: Extracted 2000000 articles (2136.4 art/s)
INFO: Extracted 2100000 articles (2363.0 art/s)
INFO: Extracted 2200000 articles (2601.9 art/s)
INFO: Extracted 2300000 articles (3709.0 art/s)
INFO: Extracted 2400000 articles (2723.9 art/s)
INFO: Extracted 2500000 articles (2487.1 art/s)
INFO: Extracted 2600000 articles (2621.3 art/s)
INFO: Extracted 2700000 articles (2525.4 art/s)
INFO: Extracted 2800000 articles (2666.4 art/s)
INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)

03、清洗數(shù)據(jù)

我們解壓后的數(shù)據(jù)如下圖,下面我們要把數(shù)據(jù)清洗出來。

注:

我們本步驟生成的文件為 data/cleaned_wiki_full.txt

import os
import json
import logging
import argparse
import re
from tqdm import tqdm


# 配置日志記錄
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s')


# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300




def clean_text(text: str) -> str:
    """
    對文本進行深度清洗。
    移除維基百科特有的格式標記、參考文獻、HTML標簽、日期和數(shù)字等。
    """
    # 移除維基鏈接 [[link|display]] 或 [[link]]
    text = re.sub(r'\[\[([^\]|]+\|)?([^\]]+)\]\]', r'\2', text)


    # 移除參考文獻標記 [1], [2], [ref], 等
    text = re.sub(r'\[\d+\]|\[ref\]|\[/ref\]|\[citation needed\]', '', text)


    # 移除HTML標簽
    text = re.sub(r'<[^>]+>', '', text)


    # 移除日期格式 (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy 等)
    text = re.sub(r'\d{1,4}[-/]\d{1,2}[-/]\d{1,4}', '', text)


    # 移除年份 (1000-2999)
    text = re.sub(r'\b[12]\d{3}\b', '', text)


    # 移除純數(shù)字(包括小數(shù))
    text = re.sub(r'\b\d+\.?\d*\b', '', text)


    # 移除重復(fù)的空白字符(但保留單個空格)
    text = re.sub(r' +', ' ', text)


    # 移除行首尾空白
    text = text.strip()


    return text




def process_extracted_wiki(extracted_dir: str, 
                          output_file: str, 
                          min_line_length: int = 20, 
                          min_article_length: int = 200):
    """
    讀取WikiExtractor輸出的JSON文件,提取、清洗文本并保存到單個文件中。


    參數(shù):
        extracted_dir: WikiExtractor輸出的目錄路徑
        output_file: 最終合并的純文本文件路徑
        min_line_length: 單行文本最小長度,用于過濾噪音(默認: 20)
        min_article_length: 文章最小長度,用于過濾短文章(默認: 200)
    """
    if not os.path.isdir(extracted_dir):
        logging.error(f"輸入的目錄不存在: {extracted_dir}")
        return


    total_articles = 0
    skipped_articles = 0


    # 第一次遍歷:獲取所有需要處理的文件列表
    file_list = []
    for root, dirs, files in os.walk(extracted_dir):
        for file_name in files:
            # 僅處理 WikiExtractor 生成的以 'wiki_' 開頭的文件
            if file_name.startswith('wiki_'):
                file_list.append(os.path.join(root, file_name))


    total_files = len(file_list)
    logging.info(f"找到 {total_files} 個文件等待處理。")


    if total_files == 0:
        logging.warning(f"目錄 {extracted_dir} 中未找到任何 'wiki_' 文件。請檢查路徑。")
        return


    # 第二次遍歷:處理文件并寫入輸出
    with open(output_file, 'w', encoding='utf-8') as f_out:
        # 使用 tqdm 包裝文件列表,顯示處理進度
        for file_path in tqdm(file_list, desc="?? 正在提取維基文本"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f_in:
                    for line_num, line in enumerate(f_in, 1):
                        try:
                            article = json.loads(line)
                            text_content = article.get('text', '').strip()


                            # --- 文本清洗和過濾 ---


                            # 1. 過濾掉過短的文章,它們通常是噪音或重定向頁
                            if len(text_content) < min_article_length:
                                skipped_articles += 1
                                continue


                            # 2. 按行處理文本,過濾短行和額外的空白
                            # 保留行結(jié)構(gòu),而不是將所有行連接成一個長句子
                            cleaned_lines = []
                            for text_line in text_content.split('\n'):
                                text_line = clean_text(text_line)
                                # 只保留足夠長的行
                                if len(text_line) >= min_line_length:
                                    cleaned_lines.append(text_line)


                            # 使用換行符連接各行,保留段落結(jié)構(gòu)
                            final_text = '\n'.join(cleaned_lines)


                            # 最終檢查:確保清洗后的文本仍然足夠長
                            if final_text and len(final_text) >= min_article_length:
                                # 文章之間用兩個換行符分隔
                                f_out.write(final_text + '\n\n')
                                total_articles += 1
                            else:
                                skipped_articles += 1


                        except json.JSONDecodeError:
                            logging.warning(f"無法解析 JSON,文件: {file_path},行號: {line_num}")
                        except Exception as e:
                            logging.error(f"處理文件 {file_path} 第 {line_num} 行時出錯: {e}")


            except Exception as e:
                logging.error(f"打開文件 {file_path} 時出錯: {e}")


    logging.info(f"  所有維基百科文本已成功提取并清洗。")
    logging.info(f"   總文章數(shù): {total_articles}")
    logging.info(f"   跳過文章數(shù): {skipped_articles}")
    logging.info(f"   文件已保存到: {output_file}")




def main():
    parser = argparse.ArgumentParser(
        descriptinotallow="從 WikiExtractor 輸出的 JSON 文件中提取并清洗純文本。",
        formatter_class=argparse.RawTextHelpFormatter
    )


    # 位置參數(shù) 1: 輸入目錄
    parser.add_argument(
        "extracted_directory",
        type=str,
        help="WikiExtractor 輸出的目錄路徑 (e.g., extracted_wiki_zh)"
    )


    # 位置參數(shù) 2: 輸出文件
    parser.add_argument(
        "output_filename",
        type=str,
        help="最終合并的純文本文件路徑 (e.g., cleaned_wiki.txt)"
    )


    # 可選參數(shù): 最小行長
    parser.add_argument(
        "--min_line_length",
        type=int,
        default=20,
        help="文章中單行文本必須達到的最小長度,用于過濾噪音。默認值: 20"
    )


    # 可選參數(shù): 最小文章長度
    parser.add_argument(
        "--min_article_length",
        type=int,
        default=200,
        help="文章最小長度,用于過濾短文章和重定向頁。默認值: 200"
    )


    args = parser.parse_args()


    process_extracted_wiki(
        args.extracted_directory, 
        args.output_filename, 
        args.min_line_length,
        args.min_article_length
    )




if __name__ == "__main__":
    main()
2025-10-01 11:10:58,772 - INFO - 找到 5 個文件等待處理。
正在提取維基文本: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:33<00:00,  6.78s/it]
2025-10-01 11:11:32,681 - INFO - 所有維基百科文本已成功提取??偽恼聰?shù): 628093。文件已保存到 data/cleaned_wiki_full.txt

04、訓練分詞器

我們使用SentencePiece訓練分詞器,本次我們訓練的分詞庫大小為16k,你也可以訓練32k的分詞庫。相關(guān)代碼及過程如下:

注:

我們本步驟生成的文件為

workdir/spm_wiki_16k.model

workdir/spm_wiki_16k.vocab

import sys
import sentencepiece as spm
import argparse
import os
from tqdm import tqdm


# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000




def get_corpus_size(input_file: str) -> int:
    """計算語料的總行數(shù)和文件大小"""
    try:
        file_size_bytes = os.path.getsize(input_file)
        file_size_mb = file_size_bytes / (1024 * 1024)
        print(f"語料文件大小: {file_size_mb:.2f} MB")


        # 計算行數(shù)和總字符數(shù)
        line_count = 0
        total_chars = 0
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="統(tǒng)計語料信息"):
                line_count += 1
                total_chars += len(line)


        print(f"語料總行數(shù) (文章數(shù)): {line_count}")
        print(f"總字符數(shù): {total_chars:,}")
        print(f"平均每行字符數(shù): {total_chars / line_count:.1f}")
        return file_size_bytes
    except Exception as e:
        print(f"警告:無法計算文件大小或行數(shù):{e}")
        return 0




def train_spm_model(input_file: str, 
                    model_prefix: str, 
                    vocab_size: int,
                    model_type: str = 'bpe',
                    character_coverage: float = 0.9995):
    """
    訓練一個SentencePiece分詞器模型。


    參數(shù):
        input_file: 訓練語料文件路徑
        model_prefix: 輸出模型文件的前綴
        vocab_size: 詞匯表大小
        model_type: 分詞算法類型 ('bpe', 'unigram', 'char', 'word')
        character_coverage: 字符覆蓋率 (0-1,通常 0.995-0.9995)
    """
    if not os.path.exists(input_file):
        print(f"錯誤:輸入語料文件未找到:{input_file}")
        sys.exit(1)


    # 確保輸出目錄存在
    output_dir = os.path.dirname(model_prefix)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"已創(chuàng)建輸出目錄: {output_dir}")


    # 打印語料規(guī)模信息
    print("\n=== 語料分析 ===")
    get_corpus_size(input_file)


    # 構(gòu)建訓練參數(shù)
    # 對于 1.5GB 語料,建議啟用 train_extremely_large_corpus=True 加速
    train_params = {
        'input': input_file,
        'model_prefix': model_prefix,
        'vocab_size': vocab_size,
        'model_type': model_type,
        'character_coverage': character_coverage,
        'num_threads': 32,        # 增加到32(最大化CPU利用)
        'bos_id': 0,
        'eos_id': 1,
        'unk_id': 2,
        'pad_id': -1,
        'normalization_rule_name': 'identity',
        'input_sentence_size': 2000000, # 5000000,         # 增加到500萬句子采樣
        'train_extremely_large_corpus': True,   # 必須啟用
        'shuffle_input_sentence': True,
        'seed_sentencepiece_size': 2000000,     # 添加種子句子大小
        'hard_vocab_limit': False,              # 允許超過目標詞匯量以獲得更好質(zhì)量
    }


    print("\n=== SentencePiece 訓練參數(shù) ===")
    for key, value in train_params.items():
        print(f"  {key}: {value}")
    print("=" * 35)


    print("\n正在訓練 SentencePiece 模型...")
    print("   (請稍候,進度由 SentencePiece 輸出)\n")


    try:
        # 執(zhí)行訓練
        spm.SentencePieceTrainer.train(**train_params)


        print("\n分詞器模型訓練完成!")
        print(f"   模型文件: {model_prefix}.model")
        print(f"   詞匯表文件: {model_prefix}.vocab")


        # 驗證模型是否成功創(chuàng)建
        if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
            model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024
            print(f"\n模型文件大小: {model_size_kb:.2f} KB")


            # 加載模型進行快速測試
            print("\n進行快速測試...")
            sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")


            test_text = "這是一個分詞測試句子。"
            tokens = sp.encode(test_text, out_type=str)
            ids = sp.encode(test_text, out_type=int)


            print(f" 測試文本: {test_text}")
            print(f" 分詞結(jié)果: {tokens}")
            print(f" Token IDs: {ids}")
        else:
            print("\n警告:模型文件生成失敗,請檢查輸入數(shù)據(jù)或參數(shù)")


    except Exception as e:
        print(f"\n訓練過程出錯: {e}")
        sys.exit(1)




def main():
    parser = argparse.ArgumentParser(
        descriptinotallow="使用 SentencePiece 訓練分詞器模型。",
        formatter_class=argparse.RawTextHelpFormatter
    )


    parser.add_argument(
        "input_file",
        type=str,
        help="訓練語料的路徑 (e.g., data/cleaned_wiki_full.txt)"
    )


    parser.add_argument(
        "model_prefix",
        type=str,
        help="訓練模型文件的輸出前綴 (e.g., workdir/spm_wiki)"
    )


    parser.add_argument(
        "vocab_size",
        type=int,
        help="詞匯表大小 (e.g., 32000)"
    )


    parser.add_argument(
        "--model_type",
        type=str,
        default='bpe',
        choices=['bpe', 'unigram', 'char', 'word'],
        help="分詞算法類型 (默認: bpe)"
    )


    parser.add_argument(
        "--character_coverage",
        type=float,
        default=0.9995,
        help="字符覆蓋率,范圍 [0-1]。對于小詞表(8K),建議用0.99或更小"
    )


    args = parser.parse_args()


    print("\n" + "="*50)
    print("SentencePiece 分詞器訓練程序")
    print("="*50)
    print(f"輸入語料: {args.input_file}")
    print(f"輸出模型前綴: {args.model_prefix}")
    print(f"詞匯表大小: {args.vocab_size}")
    print(f"分詞算法: {args.model_type}")
    print(f"字符覆蓋率: {args.character_coverage}")
    print("="*50 + "\n")


    train_spm_model(
        args.input_file, 
        args.model_prefix, 
        args.vocab_size,
        args.model_type,
        args.character_coverage
    )




if __name__ == "__main__":
    main()
開始訓練SentencePiece分詞器...
   輸入語料: data/cleaned_wiki_full.txt
   輸出模型前綴: workdir/spm_wiki_16k
   詞匯表大小: 16000
   語料文件大小: 1697.54 MB
Counting lines: 1256186it [00:05, 230354.42it/s]
   語料總行數(shù) (文章數(shù)): 1256186


--- SentencePiece 訓練參數(shù) ---
--input=data/cleaned_wiki_full.txt
--model_prefix=workdir/spm_wiki_16k
--vocab_size=16000
--model_type=bpe
--character_coverage=0.9995
--num_threads=16
--bos_id=0
--eos_id=1
--unk_id=2
--pad_id=-1 
------------------------------


? 正在啟動訓練... 請注意觀察 SentencePiece 自身的進度輸出。
sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1 
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: data/cleaned_wiki_full.txt
  input_format: 
  model_prefix: workdir/spm_colinai_16000
  model_type: BPE
  vocab_size: 16000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 2
  bos_id: 0
  eos_id: 1
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ? 
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  differential_privacy_clipping_threshold: 0
}
normalizer_spec {
  name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv: 
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txt
trainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).
trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.
trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.
trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentences
trainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=281809036
trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.
trainer_interface.cc(562) LOG(INFO) Alphabet size=8686
trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882
trainer_interface.cc(611) LOG(INFO) Done! 3885388
.....

05、原始文本轉(zhuǎn)為Token ID 序列

在訓練大型語言模型的準備階段,將海量文本語料轉(zhuǎn)化為模型可處理的數(shù)字格式至關(guān)重要。本次將原始文本語料編碼為整數(shù) Token ID 序列。為了克服單次加載大文件的內(nèi)存限制,腳本采用了分塊讀取機制,支持以自定義大小逐塊處理語料。所有 Token ID 最終被匯總并轉(zhuǎn)化為高效率的 torch.int32 PyTorch 張量,直接存儲為 .pt 文件。這不僅優(yōu)化了數(shù)據(jù)格式,方便后續(xù) PyTorch DataLoader 快速讀取,同時也提供了關(guān)鍵的統(tǒng)計信息和完整性驗證,是構(gòu)建 LLM 數(shù)據(jù)集的穩(wěn)定且高性能的預(yù)處理方案。

import sys
import torch
import sentencepiece as spm
import argparse
from tqdm import tqdm
import os
import numpy as np


# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt




def preprocess(sp_model_path: str, 
               corpus_path: str, 
               output_path: str,
               chunk_size_mb: int = 50):
    """
    分塊讀取語料,編碼為 Token ID,并保存為 PyTorch 文件。


    參數(shù):
        sp_model_path: SentencePiece 模型文件路徑
        corpus_path: 輸入語料文件路徑
        output_path: 輸出 .pt 文件路徑
        chunk_size_mb: 每次處理的文本大?。∕B),默認 50MB
    """
    # 驗證文件存在
    if not os.path.exists(sp_model_path):
        print(f"錯誤:分詞器模型文件未找到: {sp_model_path}")
        sys.exit(1)


    if not os.path.exists(corpus_path):
        print(f"錯誤:語料文件未找到: {corpus_path}")
        sys.exit(1)


    # 加載分詞器
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        vocab_size = sp.get_piece_size()
        print(f"   分詞器加載成功")
        print(f"   詞匯表大小: {vocab_size}")
        print(f"   特殊 Token: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")
    except Exception as e:
        print(f"加載分詞器失敗: {e}")
        sys.exit(1)


    # 確保輸出目錄存在
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)


    print(f"\n 開始處理語料...")
    print(f"   輸入文件: {corpus_path}")
    print(f"   輸出文件: {output_path}")
    print(f"   塊大小: {chunk_size_mb} MB\n")


    # 計算總大小用于進度條
    total_bytes = os.path.getsize(corpus_path)
    chunk_size_bytes = chunk_size_mb * 1024 * 1024


    token_ids = []
    tokens_processed = 0
    chunks_processed = 0


    try:
        with open(corpus_path, 'r', encoding='utf-8') as f:
            with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="? 編碼語料") as pbar:


                while True:
                    chunk = f.read(chunk_size_bytes)
                    if not chunk:
                        break


                    # 直接編碼(cleaned_wiki_full.txt 已經(jīng)過清洗)
                    ids = sp.encode(chunk, out_type=int)
                    token_ids.extend(ids)


                    # 更新進度條
                    bytes_read = len(chunk.encode('utf-8'))
                    pbar.update(bytes_read)


                    tokens_processed += len(ids)
                    chunks_processed += 1


                    # 定期顯示進度信息
                    if chunks_processed % 10 == 0:
                        pbar.set_postfix({
                            'chunks': chunks_processed,
                            'tokens': f'{tokens_processed:,}'
                        })


        print(f"\n 編碼完成")
        print(f"   處理塊數(shù): {chunks_processed}")
        print(f"   總 Token 數(shù): {tokens_processed:,}")


        # 轉(zhuǎn)換為 PyTorch 張量
        print(f"\n轉(zhuǎn)換為張量并保存...")
        final_tensor = torch.tensor(token_ids, dtype=torch.int32)


        print(f"   張量形狀: {final_tensor.shape}")
        print(f"   張量大小: {final_tensor.numel():,}")
        print(f"   數(shù)據(jù)類型: {final_tensor.dtype}")
        print(f"   占用內(nèi)存: {final_tensor.numel() * 4 / (1024**3):.2f} GB")


        # 驗證 Token ID 范圍
        min_id = final_tensor.min().item()
        max_id = final_tensor.max().item()
        print(f"   Token ID 范圍: [{min_id}, {max_id}]")


        if max_id >= vocab_size or min_id < 0:
            print(f"   警告: 檢測到越界 Token ID!")
            print(f"   詞匯表大小: {vocab_size}")
            print(f"   最大 ID: {max_id}")


        # 保存張量
        torch.save(final_tensor, output_path)
        file_size_mb = os.path.getsize(output_path) / (1024 ** 2)
        print(f"\nToken ID 已保存到 {output_path}")
        print(f"   文件大小: {file_size_mb:.2f} MB")


        # 驗證保存的文件
        print(f"\n驗證保存的文件...")
        loaded_tensor = torch.load(output_path)
        print(f"   加載成功,形狀: {loaded_tensor.shape}")
        print(f"   是否相同: {torch.equal(final_tensor, loaded_tensor)}")


        print(f"\n? 預(yù)處理完成!")


    except Exception as e:
        print(f"\n處理過程中出錯: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)




def main():
    parser = argparse.ArgumentParser(
        descriptinotallow="將清洗后的文本語料轉(zhuǎn)換為 Token ID 二進制文件。",
        formatter_class=argparse.RawTextHelpFormatter
    )


    parser.add_argument(
        "model_path",
        type=str,
        help="SentencePiece 模型文件路徑 (e.g., workdir/spm_wiki.model)"
    )


    parser.add_argument(
        "corpus_path",
        type=str,
        help="輸入語料文件路徑 (e.g., data/cleaned_wiki_full.txt)"
    )


    parser.add_argument(
        "output_path",
        type=str,
        help="輸出 Token ID 文件路徑 (e.g., workdir/wiki_tokens.pt)"
    )


    parser.add_argument(
        "--chunk_size",
        type=int,
        default=50,
        help="每次處理的文本大?。∕B),默認 50MB。更大的塊更快,但占用更多內(nèi)存。"
    )


    args = parser.parse_args()


    print("\n" + "="*60)
    print("數(shù)據(jù)預(yù)處理程序 - 文本到 Token ID")
    print("="*60)
    print(f"SentencePiece 模型: {args.model_path}")
    print(f"輸入語料: {args.corpus_path}")
    print(f"輸出文件: {args.output_path}")
    print(f"塊大小: {args.chunk_size} MB")
    print("="*60 + "\n")


    preprocess(
        args.model_path,
        args.corpus_path,
        args.output_path,
        args.chunk_size
    )




if __name__ == "__main__":
    main()

06、進行模型預(yù)訓練

"""
GPT 高性能訓練腳本
"""


from __future__ import annotations
import sys
import os
import math
import json
from datetime import datetime
from typing import Optional


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm


# ==================== 配置參數(shù) ====================
class Config:
    BLOCK_SIZE = 512 #256
    BATCH_SIZE = 32 #64
    GRAD_ACCUM_STEPS = 4 #1
    MODEL_DIM = 384 #256
    N_LAYERS = 5 #2
    NUM_HEADS = 6 #4
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4
    VOCAB_SIZE = None


    EPOCHS = 1
    MAX_STEPS = 10000 # 此處根據(jù)自己的硬件和時間定義步數(shù)
    WARMUP_STEPS = 500
    LR = 1e-4
    MIN_LR = 1e-5
    WEIGHT_DECAY = 0.01
    GRAD_CLIP = 1.0
    DROPOUT = 0.1


    CHECKPOINT_EVERY = 5000
    LOG_EVERY = 100


    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    CHECKPOINT_DIR = "./checkpoints"
    LATEST_CHECKPOINT = "latest_checkpoint.pth"


    NUM_WORKERS = 8
    SEED = 42


    # 啟用 bfloat16 (推薦用于現(xiàn)代 GPU)
    DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16


CFG = Config()


if CFG.DEVICE == 'cuda':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()
    # 檢查是否使用了 bfloat16
    if CFG.DTYPE == torch.bfloat16:
        print("使用 bfloat16 混合精度 (推薦)")
    else:
        print("使用 float16 混合精度")


# ==================== 工具函數(shù) ====================
def print_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        reserved = torch.cuda.memory_reserved() / (1024**3)
        print(f"GPU顯存: {allocated:.2f}GB / {reserved:.2f}GB")


def set_seed(seed: int):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(CFG.SEED)


# ==================== 數(shù)據(jù)集 ====================
class TextDataset(Dataset):
    def __init__(self, token_ids: torch.Tensor, block_size: int):
        self.ids = token_ids.long()
        self.block_size = block_size


    def __len__(self):
        return max(0, self.ids.size(0) - self.block_size)


    def __getitem__(self, idx):
        x = self.ids[idx: idx + self.block_size]
        y = self.ids[idx + 1: idx + 1 + self.block_size]
        return x, y


# ==================== RoPE 位置編碼  ====================
class RotaryPositionalEmbedding(nn.Module):
    """RoPE 實現(xiàn)"""
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"


        # 基頻:theta_i = 10000^(-2i/d)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)


        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)


    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return


        # m: (seq_len,), theta_i: (head_dim//2,)
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)  # (seq_len, head_dim//2)


        # 構(gòu)造完整的旋轉(zhuǎn)矩陣(每個復(fù)數(shù)對重復(fù))
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, head_dim)


        cos = emb.cos()[None, None, :, :]  # (1, 1, seq_len, head_dim)
        sin = emb.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim)


        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len


    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """應(yīng)用RoPE旋轉(zhuǎn)"""
    # x: (B, H, T, D), cos/sin: (1, 1, T, D)
    # 使用(x, y) -> (x*cos-y*sin, x*sin+y*cos)
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """將向量旋轉(zhuǎn)90度"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Flash Attention ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5


        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim)


    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"


        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)


        # 應(yīng)用RoPE
        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)


        # 注意力計算
        # 注意:這里如果使用 torch.nn.functional.scaled_dot_product_attention 配合 torch.compile 會更快
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if causal_mask is not None:
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


# ==================== 前饋網(wǎng)絡(luò) ====================
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim * 2)


    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


    def forward(self, x):
        return self.net(x)


# ==================== Transformer Block ====================
class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)


    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


# ==================== GPT 模型(已移除 pos_emb) ====================
class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        # self.pos_emb = nn.Embedding(block_size, dim) # 移除:與 RoPE 沖突
        self.dropout = nn.Dropout(dropout)


        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])


        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


        if tie_weights:
            self.lm_head.weight = self.token_emb.weight


        self.block_size = block_size
        self.apply(self._init_weights)


        n_params = sum(p.numel() for p in self.parameters())
        print(f"模型參數(shù): {n_params/1e6:.2f}M")


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"


        token_emb = self.token_emb(idx)


       
        x = self.dropout(token_emb) # token embedding


        causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits


# ==================== 檢查點管理 ====================
def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):
    os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    state = {
        'step': step,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict,
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }


    if scaler is not None and hasattr(scaler, "state_dict"):
        state['scaler_state_dict'] = scaler.state_dict()


    if lr_scheduler is not None:
        state['lr_scheduler_state_dict'] = {
            'current_step': lr_scheduler.current_step,
            'warmup_steps': lr_scheduler.warmup_steps,
            'total_steps': lr_scheduler.total_steps,
            'base_lr': lr_scheduler.base_lr,
            'min_lr': lr_scheduler.min_lr,
        }


    torch.save(state, checkpoint_path)


    try:
        with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2)
    except Exception:
        pass


    print(f" 檢查點已保存: {checkpoint_path} (step {step}, loss {loss:.4f})")


def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):
    if not os.path.exists(checkpoint_path):
        return None


    checkpoint = torch.load(checkpoint_path, map_locatinotallow=CFG.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


    if checkpoint.get('scaler_state_dict') is not None and scaler is not None:
        try:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        except Exception as e:
            print(f"無法恢復(fù)scaler: {e}")


    if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:
        try:
            sched_state = checkpoint['lr_scheduler_state_dict']
            lr_scheduler.current_step = sched_state['current_step']
            lr_scheduler.warmup_steps = sched_state['warmup_steps']
            lr_scheduler.total_steps = sched_state['total_steps']
            lr_scheduler.base_lr = sched_state['base_lr']
            lr_scheduler.min_lr = sched_state['min_lr']
        except Exception as e:
            print(f"無法恢復(fù)lr_scheduler: {e}")


    torch.set_rng_state(checkpoint['torch_rng_state'])
    if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:
        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])


    print(f"檢查點已加載: {checkpoint_path}")
    print(f"    Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
    return checkpoint['step']


# ==================== 學習率調(diào)度器 ====================
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):
        self.optimizer = optimizer
        self.warmup_steps = max(0, int(warmup_steps))
        self.total_steps = max(1, int(total_steps))
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.current_step = 0


    def get_lr(self, step: int = None) -> float:
        """計算給定step的學習率(不修改optimizer)"""
        if step is None:
            step = self.current_step


        if step < self.warmup_steps and self.warmup_steps > 0:
            return self.base_lr * (step / float(self.warmup_steps))
        else:
            denom = max(1, (self.total_steps - self.warmup_steps))
            progress = (step - self.warmup_steps) / denom
            progress = min(1.0, max(0.0, progress))
            return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


    def step(self):
        """執(zhí)行一次步長更新"""
        lr = self.get_lr(self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.current_step += 1
        return lr


# ==================== 訓練循環(huán) ====================
def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):
    # 檢測fused優(yōu)化器支持
    fused = False
    try:
        fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)
    except Exception:
        fused = False


    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CFG.LR,
        betas=(0.9, 0.95),
        weight_decay=CFG.WEIGHT_DECAY,
        fused=fused
    )


    # 使用配置中的 DTYPE
    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))
    loss_fn = nn.CrossEntropyLoss()


    total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs
    lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)


    model.train()
    start_step = 0
    best_loss = float('inf')


    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    if resume and os.path.exists(checkpoint_path):
        loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)
        if loaded_step is not None:
            start_step = loaded_step


    global_step = start_step
    grad_accum_counter = 0
    accumulated_loss = 0.0


    print("\n" + "="*60)
    print("開始訓練...")
    print("="*60)
    print_gpu_memory()
    print()


    # 自動選擇是否需要 scaler.scale()
    use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)


    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", initial=global_step % len(train_loader) if epoch == 0 else 0)
        num_batches = 0
        last_lr = None


        for batch_idx, (xb, yb) in enumerate(pbar):
            # 跳過已訓練的批次 (如果從中間恢復(fù))
            if global_step > start_step and batch_idx < (start_step % len(train_loader)):
                 continue


            xb = xb.to(CFG.DEVICE, non_blocking=True)
            yb = yb.to(CFG.DEVICE, non_blocking=True)


            with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):
                logits = model(xb)
                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
                loss_item = loss.item()
                loss = loss / CFG.GRAD_ACCUM_STEPS


            if use_scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()


            grad_accum_counter += 1
            accumulated_loss += loss_item
            num_batches += 1
            # 這里的 global_step 計數(shù)是基于數(shù)據(jù)批次的,而不是優(yōu)化器步數(shù),用于日志和檢查點
            # 真正的優(yōu)化器步數(shù)會在下面更新


            # 梯度累積:達到閾值時執(zhí)行優(yōu)化步驟
            if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:


                # 優(yōu)化器步進 (這是真正的 global_step 增長點)
                lr_scheduler.step() # 先更新 LR
                global_step += 1 # 只有進行了一次優(yōu)化器步進,才算一個 global_step


                if use_scaler:
                    scaler.unscale_(optimizer)


                # 梯度裁剪 (在 unscale 后或非 AMP 模式下)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)


                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()


                optimizer.zero_grad()
                grad_accum_counter = 0
                last_lr = lr_scheduler.get_lr(global_step) # 獲取當前步的LR


            # 日志輸出
            if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):
                # accumulated_loss 是累積的原始損失, num_batches 是累積的批次數(shù)
                avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0
                pbar.set_postfix({
                    'step': global_step,
                    'loss': f'{avg_loss:.4f}',
                    'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'
                })
                # 重置累積值以便計算下一個 LOG_EVERY 間隔的平均損失
                accumulated_loss = 0.0
                num_batches = 0




            # 保存檢查點
            if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:
                # 使用上一個日志點計算的 avg_loss
                current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item


                config_dict = {
                    'vocab_size': CFG.VOCAB_SIZE,
                    'block_size': CFG.BLOCK_SIZE,
                    'model_dim': CFG.MODEL_DIM,
                    'n_layers': CFG.N_LAYERS,
                    'num_heads': CFG.NUM_HEADS,
                    'created_at': datetime.now().isoformat()
                }
                save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)
                torch.cuda.empty_cache()


            if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
                break


        # 處理 epoch 結(jié)束時剩余的梯度 (如果 grad_accum_counter > 0)
        if grad_accum_counter > 0:
            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)


            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()


            optimizer.zero_grad()
            lr_scheduler.step()
            global_step += 1
            grad_accum_counter = 0




        # 此時 pbar.total_loss 已累積
        if num_batches > 0:
             final_avg_loss = accumulated_loss / num_batches
        else:
             final_avg_loss = float('inf')




        if final_avg_loss < best_loss:
            best_loss = final_avg_loss
            best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_path)
            print(f"最佳模型已保存 (loss: {best_loss:.4f})")


        print(f"\n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")


        if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
            break


    print("\n訓練完成!")


# ==================== 主函數(shù) ====================
def main():
    if len(sys.argv) < 4:
        print("用法: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")
        sys.exit(1)


    sp_model_path, token_file_path, out_path = sys.argv[1:4]
    resume = "--resume" in sys.argv


    if not os.path.exists(token_file_path):
        print(f" Token文件不存在: {token_file_path}")
        sys.exit(1)


    # 檢查 CFG.DTYPE 是否為 bfloat16 但環(huán)境不支持
    if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        print("警告: bfloat16 不受當前 CUDA 設(shè)備支持,自動回退到 float16。")
        CFG.DTYPE = torch.float16


    sp = spm.SentencePieceProcessor(model_file=sp_model_path)
    CFG.VOCAB_SIZE = sp.get_piece_size()


    print("="*60)
    print("GPT 語言模型訓練")
    print("="*60)
    print(f"分詞器: {sp_model_path}")
    print(f"Token文件: {token_file_path}")
    print(f"輸出模型: {out_path}")
    print(f"設(shè)備: {CFG.DEVICE}")
    print(f"\n模型配置:")
    print(f"    - VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    print(f"    - BLOCK_SIZE: {CFG.BLOCK_SIZE}")
    print(f"    - MODEL_DIM: {CFG.MODEL_DIM}")
    print(f"    - N_LAYERS: {CFG.N_LAYERS}")
    print(f"    - NUM_HEADS: {CFG.NUM_HEADS}")
    print(f"\n訓練配置:")
    print(f"    - BATCH_SIZE: {CFG.BATCH_SIZE}")
    print(f"    - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")
    print(f"    - 有效BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")
    print(f"    - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")
    print("="*60)


    print(f"\n加載Token文件: {token_file_path}")
    ids = torch.load(token_file_path)
    print(f"已加載 {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")


    dataset = TextDataset(ids, CFG.BLOCK_SIZE)
    del ids
    torch.cuda.empty_cache()


    # 改進:啟用 shuffle=True 進行預(yù)訓練
    num_workers = CFG.NUM_WORKERS
    try:
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True, # 啟用 Shuffle
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=num_workers,
            persistent_workers=True if num_workers > 0 else False
        )
    except Exception as e:
        print(f"DataLoader錯誤: {e}, 改用num_workers=0")
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=0
        )


    model = GPTModel(
        CFG.VOCAB_SIZE,
        CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=CFG.DROPOUT
    ).to(CFG.DEVICE)


    # 嘗試編譯(容錯)
    try:
        model = torch.compile(model, mode='reduce-overhead')
        print("已啟用 torch.compile() 加速")
    except Exception as e:
        print(f"跳過 torch.compile(): {e}")


    train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)


    torch.save(model.state_dict(), out_path)
    print(f"\n最終模型已保存到 {out_path}")
    print_gpu_memory()


if __name__ == "__main__":
    main()

07、進行模型推理測試

import torch
from torch import nn
import sentencepiece as spm
from typing import Optional


# ==================== 配置參數(shù) (必須與訓練時一致) ====================
# 使用與訓練腳本中完全相同的配置
class Config:
    BLOCK_SIZE = 512
    # 模型尺寸參數(shù) (必須與訓練時一致)
    MODEL_DIM = 384
    N_LAYERS = 5
    NUM_HEADS = 6
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4 
    VOCAB_SIZE = None


    # 推理設(shè)置
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # 推理通常使用 float32 獲得最佳兼容性和精度
    DTYPE = torch.float32 


CFG = Config()


# ==================== RoPE 位置編碼 (與訓練腳本保持一致) ====================
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)


    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len


    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Attention, FFN, Block, Model (與訓練腳本保持一致) ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # 推理時通常不使用 Dropout,但模型結(jié)構(gòu)需要保持一致
        self.attn_dropout = nn.Dropout(attn_dropout) 
        self.rope = RotaryPositionalEmbedding(self.head_dim)


    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)


        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)


        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # 注意:在推理時,通常使用 KV-Cache,這里簡化為完整計算
        if T > 1: # 僅在序列長度大于 1 時應(yīng)用 mask
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))


        attn = torch.softmax(scores, dim=-1)
        # 推理時禁用 dropout
        # attn = self.attn_dropout(attn) 
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # 必須保持與訓練腳本中完全相同的 nn.Sequential 結(jié)構(gòu)
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),   # net.1: Dropout (必須保留,占位)
            nn.Linear(hidden_dim, dim), # net.2: Linear (與訓練時一致)
            nn.Dropout(dropout),   # net.3: Dropout (必須保留,占位)
        )


    def forward(self, x):
        # 在推理時, model.eval() 會自動禁用所有 nn.Dropout 層,但結(jié)構(gòu)不變
        return self.net(x)


# 確保 GLU 的定義如下(與訓練時一致):
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # GLU 內(nèi)部只有一個 nn.Linear
        self.linear = nn.Linear(in_dim, out_dim * 2)


    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)


    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0, # 推理時 dropout=0
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)


        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])


        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


        if tie_weights:
            self.lm_head.weight = self.token_emb.weight


        self.block_size = block_size


    def forward(self, idx):
        B, T = idx.shape
        token_emb = self.token_emb(idx)
        x = token_emb # 推理時不使用 dropout


        causal_mask = None # Attention 模塊內(nèi)部處理 Causal Mask
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits




# ==================== 推理和生成函數(shù) ====================


@torch.no_grad()
def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor, 
                  prompt: str, max_new_tokens: int, temperature: float = 0.8, 
                  top_k: int = 50):


    model.eval()
    device = CFG.DEVICE


    # 1. 編碼輸入
    input_ids = sp.encode_as_ids(prompt)
    if not input_ids:
        return "無法編碼輸入。"


    # 將輸入轉(zhuǎn)換為模型期望的格式 (B, T)
    x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)


    # 2. 循環(huán)生成
    for _ in range(max_new_tokens):
        # 裁剪輸入以適應(yīng)模型的 BLOCK_SIZE
        # 在實際部署中,這里應(yīng)該使用 KV Cache,但此處簡化為完整前向傳播
        idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]


        # 獲取 logits
        logits = model(idx_cond)


        # 只取最后一個時間步的 logits
        logits = logits[:, -1, :] 


        # 應(yīng)用溫度縮放
        logits = logits / temperature


        # 3. Top-K 采樣
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')


        # 計算概率并采樣
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)


        # 4. 停止條件
        # 檢查是否生成了 EOS token (假設(shè) </s> 是 ID 3, 請根據(jù)您的分詞器調(diào)整)
        # 默認使用 SentencePiece 的 <eos> ID
        if idx_next.item() == sp.eos_id():
            break


        # 將新生成的 token 添加到序列中
        x = torch.cat((x, idx_next), dim=1)


        # 檢查是否達到最大序列長度 (防止溢出)
        if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:
            break


    # 5. 解碼輸出
    output_ids = x[0].tolist()
    # 查找輸入 prompt 的長度,只解碼新生成的 token
    start_index = len(input_ids)


    return sp.decode_ids(output_ids[start_index:])




# ==================== 主執(zhí)行函數(shù) ====================


def main_infer(sp_model_path: str, model_weights_path: str):
    print("="*50)
    print(f"GPT 模型推理模式")
    print(f"設(shè)備: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")
    print("="*50)


    # 1. 加載分詞器
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        CFG.VOCAB_SIZE = sp.get_piece_size()
        print(f"加載分詞器成功,VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    except Exception as e:
        print(f"無法加載分詞器模型 {sp_model_path}: {e}")
        return


    # 2. 實例化模型
    model = GPTModel(
        vocab_size=CFG.VOCAB_SIZE,
        block_size=CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=0.0 # 推理時設(shè)置 dropout 為 0
    ).to(CFG.DEVICE).to(CFG.DTYPE)


    # 3. 加載權(quán)重
    try:
        # 檢查是否是 torch.compile 后的狀態(tài)字典
        weights = torch.load(model_weights_path, map_locatinotallow=CFG.DEVICE)


        # 如果權(quán)重是 DDP 或 torch.compile 包裝后的,需要解包
        if any(k.startswith('_orig_mod.') for k in weights.keys()):
            weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}


        model.load_state_dict(weights, strict=True)
        print(f"成功加載模型權(quán)重: {model_weights_path}")
    except Exception as e:
        print(f"無法加載或匹配模型權(quán)重: {e}")
        # 如果加載失敗,打印預(yù)期鍵和實際鍵,方便調(diào)試
        # print("\n--- 預(yù)期模型鍵 (部分) ---")
        # print(list(model.state_dict().keys())[:5])
        # print("\n--- 載入權(quán)重鍵 (部分) ---")
        # print(list(weights.keys())[:5])
        return


    # 4. 進入交互循環(huán)
    print("\n--- 進入交互模式 ---")
    print(f"輸入 'exit' 或 'quit' 退出。")
    print(f"輸入 'config' 查看當前生成參數(shù)。")
    print("----------------------")


    max_tokens = 100
    temperature = 0.8
    top_k = 50


    while True:
        try:
            prompt = input(">>> 輸入提示詞: ")


            if prompt.lower() in ['exit', 'quit']:
                break


            if prompt.lower() == 'config':
                print(f"  Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")
                new_max = input("  設(shè)置 Max Tokens (回車跳過): ")
                new_temp = input("  設(shè)置 Temperature (回車跳過): ")
                new_k = input("  設(shè)置 Top K (回車跳過): ")


                if new_max: max_tokens = int(new_max)
                if new_temp: temperature = float(new_temp)
                if new_k: top_k = int(new_k)
                continue


            if not prompt.strip():
                continue


            print("生成中...")


            # 執(zhí)行生成
            output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)


            print(f"--- 模型回復(fù) ---\n{output.strip()}")
            print("----------------")


        except KeyboardInterrupt:
            print("\n退出生成...")
            break
        except Exception as e:
            print(f"發(fā)生錯誤: {e}")




if __name__ == "__main__":
    import sys


    if len(sys.argv) != 3:
        print("用法: python infer.py <spm模型路徑> <模型權(quán)重文件路徑>")
        # 示例用法 (請根據(jù)您的實際文件路徑修改):
        # python infer.py tokenizer.model final_model.pth
        sys.exit(1)


    sp_model_path = sys.argv[1]
    model_weights_path = sys.argv[2]


    main_infer(sp_model_path, model_weights_path)

我們看到模型大概可以預(yù)測我們輸入的下一個詞,因我們訓練的參數(shù)和步數(shù)很低,模型輸出的亂七八糟!

本次總結(jié)

本次我們做了數(shù)據(jù)準備、數(shù)據(jù)清洗、分詞器訓練、模型訓練、推理等,請根據(jù)步驟進行執(zhí)行代碼,你便可以得到一個17M參數(shù)的小模型。后面我們再加大參數(shù)進行訓練,再進行監(jiān)督微調(diào)。

責任編輯:龐桂玉 來源: 寫代碼的中年人
相關(guān)推薦

2025-10-24 10:34:55

2020-09-24 11:46:03

Promise

2021-03-23 15:21:00

人工智能機器學習技術(shù)

2020-03-17 10:45:11

GitHub代碼開發(fā)者

2021-08-17 11:08:08

參數(shù)M6模型

2019-04-24 15:06:37

Http服務(wù)器協(xié)議

2024-11-04 00:24:56

2021-01-25 13:45:14

模型人工智能深度學習

2024-12-23 12:52:29

2021-06-30 07:19:36

網(wǎng)絡(luò)安全

2021-08-04 05:49:40

數(shù)據(jù)庫數(shù)時序數(shù)據(jù)庫技術(shù)

2021-10-28 09:19:29

模型人工智能Facebook

2014-09-25 09:51:29

Android App個人博客

2022-11-01 14:50:00

數(shù)據(jù)計算

2016-09-14 17:48:44

2023-04-06 08:01:30

RustMutex

2019-07-21 19:45:23

GitHub代碼開發(fā)者

2017-06-06 10:14:55

KerasTensorFlow深度學習

2021-09-26 10:47:12

預(yù)訓練模型GPT

2024-05-10 10:01:26

自動駕駛模型
點贊
收藏

51CTO技術(shù)棧公眾號