From 0577260545e1897c2674984178b4bf933c37c7d1 Mon Sep 17 00:00:00 2001 From: YuangLiu Date: Mon, 8 Sep 2025 09:37:42 +0800 Subject: [PATCH 1/2] trans script from paddlenlp --- .../create_pretraining_data.py | 436 ++++++++++++++++++ 1 file changed, 436 insertions(+) create mode 100644 examples/pre-training/tools/data_preprocess/create_pretraining_data.py diff --git a/examples/pre-training/tools/data_preprocess/create_pretraining_data.py b/examples/pre-training/tools/data_preprocess/create_pretraining_data.py new file mode 100644 index 000000000..8ad9d4ec7 --- /dev/null +++ b/examples/pre-training/tools/data_preprocess/create_pretraining_data.py @@ -0,0 +1,436 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import io +import json +import multiprocessing +import os +import re +import sys +import time + +import numpy as np +from tqdm import tqdm + +from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer + +from paddleformers.data import indexed_dataset +from paddleformers.utils.log import logger + +try: + import nltk + + nltk_available = True +except ImportError: + nltk_available = False + +from datetime import datetime + + +def print_datetime(string): + time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + print("[" + string + "] datetime: {} ".format(time_str)) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", type=str, required=True, help="What model to use." + ) + group = parser.add_argument_group(title="data input/output") + group.add_argument( + "--input_path", type=str, required=True, help="Path to input JSON files." + ) + group.add_argument( + "--output_prefix", + type=str, + required=True, + help="Output prefix to store output file.", + ) + group.add_argument( + "--data_format", + type=str, + default="text", + choices=["JSON"], + help="Only support json format for now. One document per line.", + ) + group.add_argument( + "--json_key", + type=str, + default="text", + help="For JSON format. Space separate listed of keys to extract from json", + ) + group.add_argument( + "--split_sentences", action="store_true", help="Split documents into sentences." + ) + + group.add_argument( + "--data_impl", type=str, default="mmap", choices=["lazy", "mmap"] + ) + + group = parser.add_argument_group(title="chinese words") + group.add_argument( + "--chinese", + action="store_true", + help="Is corpus need words segmentation step for chinese words.", + ) + group.add_argument( + "--cn_whole_word_segment", + action="store_true", + help="Is corpus need words segmentation step for chinese words WWM.", + ) + group.add_argument( + "--cn_seg_func", + type=str, + default="jieba", + choices=["lac", "seg", "jieba"], + help="Words segment function for chinese words.", + ) + group.add_argument( + "--cn_splited", + action="store_true", + help="Is chinese corpus is split in to words.", + ) + group.add_argument( + "--cn_split_dimer", + type=str, + default=" ", + help="Split dimer between chinese words.", + ) + + group = parser.add_argument_group(title="common config") + group.add_argument( + "--append_eos", + action="store_true", + help="Append an token to the end of a document.", + ) + group.add_argument( + "--log_interval", + type=int, + default=100, + help="Interval between progress updates", + ) + group.add_argument( + "--workers", type=int, default=1, help="Number of worker processes to launch" + ) + group.add_argument( + "--max_doc_num", + type=int, + default=sys.maxsize, + help="Stop when reach max_doc_num.", + ) + group.add_argument( + "--max_repeated_len", + type=int, + default=100, + help="The maximum length of the repeated characters to keep", + ) + + args = parser.parse_args() + return args + + +def lexical_analysis_fn(): + from LAC import LAC + + lac = LAC(mode="lac") + + def process(line): + words, _ = lac.run(line) + return words + + return process + + +def chinese_segmentation_fn(): + from LAC import LAC + + lac_cws = LAC(mode="seg") + + def process(line): + words = lac_cws.run(line) + return words + + return process + + +def jieba_segmentation_fn(): + import jieba + + def process(line): + words = jieba.cut(line) + return list(words) + + return process + + +def get_whole_word_mask_tokens(tokens, words, max_word_length=6): + """ + Do whole word mask on Chinese word. + First, we do Chinese word segmentation on the sequence of tokens, which are from the WordPiece tokenization. + Then, we add the '##' mark on chinese characters which are in the middle of Chinese words. + And if the tokens are not chinese characters, we just exploit the results of WordPiece tokenization as words. + Such as, + - text line : 通过利用mercer核,将样本从输入空间映射到高维特征空间,使原来没有显现的特征突现出来,取得了很好的图像分割效果。 + - the input tokens (after WordPiece): + ['通', '过', '利', '用', 'me', '##rc', '##er', '核', ',', '将', '样', '本', '从', '输', '入', '空', '间', '映', + '射', '到', '高', '维', '特', '征', '空', '间', ',', '使', '原', '来', '没', '有', '显', '现', '的', '特', '征', + '突', '现', '出', '来', ',', '取', '得', '了', '很', '好', '的', '图', '像', '分', '割', '效', '果', '。'] + - the Chinese words (after Chinese word segmentation like jieba) + ['通过', '利用', 'mercer', '核', ',', '将', '样本', '从', '输入', '空间', '映射', '到', '高维', '特征', + '空间', ',', '使', '原来', '没有', '显现', '的', '特征', '突现', '出来', ',', '取得', '了', '很', '好', + '的', '图像', '分割', '效果', '。'] + - the output whole word mask tokens: + ['通', '##过', '利', '##用', 'me', '##rc', '##er', '核', ',', '将', '样', '##本', '从', '输', '##入', + '空', '##间', '映', '##射', '到', '高', '##维', '特', '##征', '空', '##间', ',', '使', '原', '##来', + '没', '##有', '显', '##现', '的', '特', '##征', '突', '##现', '出', '##来', ',', '取', '##得', '了', + '很', '好', '的', '图', '##像', '分', '##割', '效', '##果', '。'] + + Args: + tokens(list(str)): The sequence of tokens, which are from the WordPiece tokenization. + words(list(str)): The sequence of Chinese words. + max_word_length(int, optional): + The maximum chinese character in Chinese words. It avoids too long Chinese word to be masked. + Defaults as 4. + + Returns: + new_tokens(list(str)): The new token will be done with whole word masking strategy. + + """ + + new_tokens = [] + # opt for long document + words_set = set(words) + i = 0 + while i < len(tokens): + # non-chinese character, then do word piece + if len(re.findall("[\u4E00-\u9FA5]", tokens[i])) == 0: + new_tokens.append(tokens[i]) + i += 1 + continue + + # add "##" mark on the middel tokens of Chinese words + # such as ["通过", "利用"] -> ["通", "##过", "利", "##用"] + has_add = False + for length in range(max_word_length, 0, -1): + if i + length > len(tokens): + continue + if "".join(tokens[i : i + length]) in words_set: + new_tokens.append(tokens[i]) + for lens in range(1, length): + new_tokens.append("##" + tokens[i + lens]) + i += length + has_add = True + break + + if not has_add: + new_tokens.append(tokens[i]) + i += 1 + return new_tokens + + +class IdentitySplitter(object): + def tokenize(self, *text): + return text + + +class NewlineSplitter: + def tokenize(self, text): + return text.split("\n") + + +class Converter(object): + def __init__(self, args): + self.args = args + + def initializer(self): + Converter.tokenizer = ErnieBotTokenizer.from_pretrained( + self.args.model_name_or_path + ) + if self.args.cn_whole_word_segment: + # Extend chinese char vocab for ErnieTokinzer + Converter.tokenizer.extend_chinese_char() + + # Split document to sentence. + if self.args.split_sentences: + if self.args.chinese: + Converter.splitter = NewlineSplitter() + else: + if not nltk_available: + print("NLTK is not available to split sentences.") + exit() + splitter = nltk.load("tokenizers/punkt/english.pickle") + Converter.splitter = splitter + else: + Converter.splitter = IdentitySplitter() + + # Split sentence whole words mask for chinese + if self.args.cn_whole_word_segment: + if self.args.cn_splited: + Converter.segment_func = lambda text: text.split( + self.args.cn_split_dimer + ) + else: + CHINESE_SEG_FUNC = { + "lac": lexical_analysis_fn(), + "seg": chinese_segmentation_fn(), + "jieba": jieba_segmentation_fn(), + } + Converter.segment_func = CHINESE_SEG_FUNC[self.args.cn_seg_func] + Converter.whole_word_mask = get_whole_word_mask_tokens + else: + Converter.segment_func = lambda x: x + Converter.whole_word_mask = lambda x, y: x + + def process(text): + words = Converter.segment_func(text) + # if there are two empty word, the should a split dimer in the pos + if self.args.cn_splited: + pre_dimer = False + for index, w in enumerate(words): + if pre_dimer and len(w) == 0: + words[index] = self.args.cn_split_dimer + pre_dimer = False + elif len(w) == 0: + pre_dimer = True + else: + pre_dimer = False + + tokens = Converter.tokenizer.tokenize("".join(words)) + tokens = Converter.whole_word_mask(tokens, words) + tokens = Converter.tokenizer.convert_tokens_to_ids(tokens) + return tokens + + Converter.process = process + + def remove_repeated_chars(text, max_repeated_len=100): + """ + Removes repeated characters from the given text, where the length of + the repeated characters is greater than or equal to the specified length. + + Args: + text (str): The input text from which to remove repeated characters. + length (int, optional): The minimum length of the repeated characters. Defaults to 15. + + Returns: + str: The modified text with the repeated characters removed. + """ + pattern = r"(.)\1{" + str(max_repeated_len) + ",}" + return re.sub(pattern, r"\1", text) + + def encode(self, json_line): + text = json.loads(json_line)[self.args.json_key] + text = Converter.remove_repeated_chars(text, self.args.max_repeated_len) + doc_ids = [] + for sentence in Converter.splitter.tokenize(text): + sentence_ids = Converter.process(sentence.strip()) + if len(sentence_ids) > 0: + doc_ids.append(sentence_ids) + + if len(doc_ids) > 0 and self.args.append_eos: + if Converter.tokenizer.eos_token_id is None: + logger.warning( + "{}: eos_token_id is not set, ".format(self.args.tokenizer_name) + + "please set other tokenizer " + + "or config eos_token_id or unset append_eos." + ) + else: + doc_ids[-1].append(Converter.tokenizer.eos_token_id) + + return doc_ids, len(text.encode("utf-8")) + + +def main(): + print_datetime("start") + args = get_args() + file_paths = [] + if os.path.isfile(args.input_path): + file_paths.append(args.input_path) + else: + for root, _, fs in os.walk(args.input_path): + for f in fs: + file_paths.append(os.path.join(root, f)) + + convert = Converter(args) + + # Try tokenizer is available + sample_tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path) + if sample_tokenizer.vocab_size < 2**16 - 1: + save_dtype = np.uint16 + else: + save_dtype = np.int32 + + pool = multiprocessing.Pool(args.workers, initializer=convert.initializer) + + output_ids_files = args.output_prefix + ".bin" + output_idx_files = args.output_prefix + ".idx" + builder = indexed_dataset.make_builder(output_ids_files, args.data_impl, save_dtype) + + file_paths.sort() + + step = 0 + total_bytes_processed = 0 + startup_start = time.time() + for file_path in tqdm(file_paths): + if file_path.endswith(".zst"): + import zstandard + + cctx = zstandard.ZstdDecompressor() + fh = open(file_path, "rb") + text = io.BufferedReader(cctx.stream_reader(fh)) + elif file_path.endswith(".jsonl"): + text = open(file_path, "r", encoding="utf-8") + else: + print("Unexpected data format, skipped %s" % file_path) + continue + + encoded_docs = pool.imap(convert.encode, text, 256) + print("Processing %s" % file_path) + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + step += 1 + total_bytes_processed += bytes_processed + if len(doc) == 0: + continue + + for sentence in doc: + sentence_len = len(sentence) + if sentence_len == 0: + continue + builder.add_item(sentence) + + builder.end_document() + + if step % args.log_interval == 0: + current = time.time() + elapsed = current - startup_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + print( + f"Processed {step} documents", + f"({step/elapsed:.2f} docs/s, {mbs:.4f} MB/s).", + file=sys.stderr, + ) + if step >= args.max_doc_num: + break + + if step >= args.max_doc_num: + break + + pool.close() + print("Saving tokens to files...") + builder.finalize(output_idx_files) + print_datetime("end") + + +if __name__ == "__main__": + main() From c864ef93dec8b3e6475bfc9710333466c6046b31 Mon Sep 17 00:00:00 2001 From: YuangLiu Date: Mon, 8 Sep 2025 09:43:21 +0800 Subject: [PATCH 2/2] update --- .../pre-training/tools/preprocess/README.md | 255 ++++++++++++++++++ .../create_pretraining_data.py | 5 +- .../tools/preprocess/docs/CLUECorpus2020.md | 12 + .../tools/preprocess/docs/CLUECorpusSmall.md | 76 ++++++ .../tools/preprocess/docs/OpenWebText2.md | 42 +++ .../tools/preprocess/docs/WuDaoCorpusBase.md | 101 +++++++ .../pre-training/tools/preprocess/merge.py | 104 +++++++ .../tools/preprocess/trans_to_json.py | 167 ++++++++++++ .../tools/preprocess/words_segmentation.py | 223 +++++++++++++++ 9 files changed, 983 insertions(+), 2 deletions(-) create mode 100644 examples/pre-training/tools/preprocess/README.md rename examples/pre-training/tools/{data_preprocess => preprocess}/create_pretraining_data.py (99%) create mode 100644 examples/pre-training/tools/preprocess/docs/CLUECorpus2020.md create mode 100644 examples/pre-training/tools/preprocess/docs/CLUECorpusSmall.md create mode 100644 examples/pre-training/tools/preprocess/docs/OpenWebText2.md create mode 100644 examples/pre-training/tools/preprocess/docs/WuDaoCorpusBase.md create mode 100644 examples/pre-training/tools/preprocess/merge.py create mode 100644 examples/pre-training/tools/preprocess/trans_to_json.py create mode 100644 examples/pre-training/tools/preprocess/words_segmentation.py diff --git a/examples/pre-training/tools/preprocess/README.md b/examples/pre-training/tools/preprocess/README.md new file mode 100644 index 000000000..fbeb3422e --- /dev/null +++ b/examples/pre-training/tools/preprocess/README.md @@ -0,0 +1,255 @@ +# PaddleNLP 预训练数据流程 + +本示例致力于打造基于 PaddleNLP 预训练模型的最佳实践。 + + +我们将预训练数据过程划分为以下部分 + +- 原始数据转换,原始文本转换为 jsonl 的 json 字符串格式。 +- 数据 ID 化,断句、分词、tokenize 转化为 token id 格式。 +- 训练 index 文件生成,生成 train、valid、test 的每个样本索引。 +- token 动态 mask(可选),python 层实时 mask 文本。 + +本目录下主要包含一下文件: +``` +├── create_pretraining_data.py +├── merge.py +├── trans_to_json.py +├── words_segmentation.py +└── README.md +``` + +### 环境依赖 + + - tqdm + - numpy + - pybind11 + - fast_dataindex + - lac (可选) + - zstandard (可选) + +安装命令`pip install tqdm numpy pybind11 fast_dataindex lac zstandard`。另,部分功能需要`g++>=4.8`编译支持 + + +## 训练全流程数据 Pipeline + +飞桨是自主研发、功能完备、开源开放的产业级深度学习平台,集深度学习核心训练和推理框架、基础模型库、端到端开发套件和丰富的工具组件于一体 + +| 步骤 | 阶段                      | 数据格式 | 样例 | +|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| 0️⃣初始状态 | - | 原始数据:
**每个 doc 之间用空行间隔开**
- 中文,默认每句换行符,作为句子结束。
- 英文,默认使用 nltk 判断句子结束 | ```飞桨是功能完备、开源开放的产业级深度学习平台。```
```飞桨拥有核心训练和推理框架、基础模型库。```

```PaddleNLP是自然语言处理领域的优秀工具。``` | +| 1️⃣原始数据转换
`trans_to_json.py` | 预处理
输入:0️⃣初始状态
输出:jsonl | jsonl 格式:每个 doc 对应一行 json 字符串 | ```{"text": "飞桨是功能完备、开源开放的产业级深度学习平台。飞桨拥有..."}```
```{"text": "PaddleNLP是自然语言..."}``` | +| ❇️(**可选**)数据中文分词
`words_segmentation.py` | 语料分词:中文 WWM
输入:jsonl
输出:0️⃣初始状态 | 将 jsonl 格式的数据,恢复成分词后的原始格式数据
| ```飞桨 是 功能 完备、开源 开放的 产业级 深度学习 平台。```
```飞桨 拥有 核心 训练和推理 框架、基础 模型库。```

```PaddleNLP 是 自然语言处理领域 的 优秀工具。``` | +| 2️⃣数据 ID 化
`create_pretrain_data.py` | 预处理 | bin 格式:数据 id 化后的 token id
idx 格式:数据句子、文章位置索引 | - | +| 3️⃣训练 index 文件生成 | 训练启动 | npy 格式:
根据训练步数 max_steps 生成
train、valid、test 的每个样本索引文件 | - | +| 4️⃣token 动态 mask(可选) | Dataset 取数据 | 无 | - | + + +注意: +- **❇️(**可选**)数据中文分词** 是中文预训练做 WWM 的可选步骤 + - 当你的数据比较少时,分词耗时较少,不需要分词步骤。直接在`create_pretrain_data.py`步骤中分词即可。 + - 目的是为了提前分词,加快后续数据 ID 转化步骤。 + - 如果这里输入的是 jsonl 格式文件,最好为多文件,`trans_to_json.py` 时候开启`no-merge`选项。 + - 当你的数据集比较大,或者需要尝试多次转换数据的时候,提前分词可以避免`create_pretrain_data.py`时每次都运行一次分词程序。 +- 转换后,需要重新进行步骤 1️⃣`原始数据转换 trans_to_json.py`,最后2️⃣`数据 ID 化`步骤设置`--cn_splited=True`参数。 +- 2️⃣`数据 ID 化`也可以在转化 ID 的同时,一起实现分词。不需要❇️`数据中文分词`步骤。 + + +## 数据教程汇总 + +针对目前开源的数据集,PaddleNLP 提供了详细的数据教程,点击对应数据集的链接,即可开始进行数据制作: + +| 名称 | 文本类型 | 纯文本大小 | 适配模型 | +|--------------------------------------------------|----------|------------|----------| +| [CLUECorpusSmall](./docs/CLUECorpusSmall.md) | 中文 | 14GB | Llama | +| [OpenWebText2](./docs/OpenWebText2.md) | 英文 | 70GB | Llama | +| [WuDaoCorpus2.0 Base](./docs/WuDaoCorpusBase.md) | 中文 | 200GB | Llama | +| [CLUECorpus2020](./docs/CLUECorpus2020.md) | 中文 | 200GB | Llama | + +## 预训练详细准备 + +下面以 ziya-llama-13b-v1预训练为例,简要介绍一下预训练的全流程。 + +### 原始数据 +首先下载样例数据: +``` +mkdir data && cd data +wget https://bj.bcebos.com/paddlenlp/models/transformers/data_tools/baike.txt +cd .. +``` + +### 原始数据转换 jsonl 格式 +使用`trans_to_json.py`转化为 json 串格式,下面是脚本的使用说明 +``` +optional arguments: + -h, --help show this help message and exit + --input_path INPUT_PATH + Path to you raw files. Folder or file path. + 必须设置,可以是文件夹或者单个文件。文件夹中的目录默认最多搜索两层子目录。 + --output_path OUTPUT_PATH + Path to save the output json files. + 必须设置,输出文件的名字。 + --json_key JSON_KEY The content key of json file. + 建议不修改,默认的key是text + --doc_spliter DOC_SPLITER + Spliter between documents. We will strip the line, if you use blank line to split doc, leave it blank. + 根据实际情况修改,默认空行作为文章换行符。 + --min_doc_length MIN_DOC_LENGTH + Minimal char of a document. + 可选。过滤掉长度多短的文章,默认值10 + --workers WORKERS Number of worker processes to launch + 可选。多进程转化文件,适用于 input_path 中包含的文件数据较多的情况。每个文件,分配给不同worker处理 + --log_interval LOG_INTERVAL + Interval between progress updates. + 可选。此处的interval是值处理完文件个数的间隔。 + --no-merge Don't merge the file. + 可选。默认不开启这个选项,默认每个文件转换的jsonl文本,会拼接成到同一个文件。 + --no-shuffle Don't shuffle the file. + 可选。默认不开启这个选项,默认对处理完进行shuffle。 +``` +根据说明,我们使用下面简单命令,可以得到`baike_sample.jsonl`文件。此处,我们对文章所有 doc 进行了 shuffle。 +```shell +python trans_to_json.py --input_path ./data --output_path baike_sample +``` + +```shell +#查看数据 +head -1 baike_sample.jsonl +{"text": "中国效仿西方发展工业的过程,于中华民国国民政府成立后至中日战争开战前夕已顺畅发展,尽管其间受到内外因素的多重干扰。尔后直至中日战争和国共战争的结束, +中国始有较为长期的和平发展时期。\n1980年代以来,邓小平政府宣布改革开放,开始实行社会主义市场经济并推行经济体制改革。中国大陆近年至2010年,GDP超过72000亿美元, +已经成为美国之后的世界第二经济大国,普遍认为中国是世界上发展速度最快的经济体,但是人均国民生产总值仍位于世界中等水平(第89位),并逐渐受到资源限制和贫富差距加 +大的制约。中华人民共和国省份中,广东为GDP最高的第一强省,浙江为人均收入最高的第一富省。中国大陆、香港、澳门、台湾之间的经济联系在全球化的过程中日益紧密。\n"} +``` + +### 数据 ID 化 +本部分,我们使用 `create_pretraining_data.py` 脚本将前面得到的 `baike_sample.jsonl` 进行 tokenize id 化处理。 +``` +optional arguments: + -h, --help show this help message and exit + --model_name_or_path MODEL_NAME_OR_PATH + What model to use. + 必须设置,如:idea-ccnl/ziya-llama-13b-v1, 可以参考已有的模型名称 https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm +data input/output: + --input_path INPUT_PATH + Path to input JSON files. + 必须设置,输入文件jsonl的目录 + --output_prefix OUTPUT_PREFIX + Output prefix to store output file. + 必须设置,输出文件的名称。 + 假设名称为XXX,则会输出 XXX.bin, XXX.idx 两个文件。 + bin文件,数据id化后的token ids; idx文件,数据句子、文章位置索引。 + --data_format {JSON} Only support json format for now. One document per line. + 不需要设置。目前默认处理jsonl数据格式 + --json_key JSON_KEY For JSON format. Space separate listed of keys to extract from json + 文本串json的key值。同前面trans_to_json.py的json_key,默认text为key + --split_sentences Split documents into sentences. + 是否需要将文章划分成句子。一般而言,GPT不需要,BERT/ERNIE模型需要 + --data_impl {mmap,lazy} + Convert the json into mmap/lazy format. + 处理后的数据格式,可选“mmap”或“lazy”,其中“mmap”格式在读入数据时会建立内存映射,“lazy”格式在读入数据时直接从文件读取。 + +chinese words: + --chinese Is corpus need words segmentation step for chinese words. + 若设置了split_sentences,并处理中文则需要设置。 + --cn_whole_word_segment + Is corpus need words segmentation step for chinese words WWM. + 可选。是否需要WWM策略。一般而言,BERT/ERNIE模型需要,GPT不需要。 + --cn_seg_func {lac,seg,jieba} + Words segment function for chinese words. + 默认jieba,jieba速度较快,lac模型更准确,计算量高。 + --cn_splited Is chinese corpus is split into words. + 分词后的文本,可选。设置此选项则,cn_seg_func不起作用。 + 例如分词后文本串 "中国 效仿 西方 发展 工业 的过 程" + --cn_split_dimer CN_SPLIT_DIMER + Split dimer between chinese words. + 配合cn_splited使用,默认空格表示分词间隔。 + +common config: + --append_eos Append an token to the end of a document. + gpt类模型专用,gpt设置此选项,表示doc结束。针对tokenizer中不包含eos_token情况,输出提示warning并且不添加。 + --log_interval LOG_INTERVAL + Interval between progress updates + 打印日志间隔,interval表示处理 文本行数/doc数的 间隔。 + --workers WORKERS Number of worker processes to launch + 处理文本id化的进程个数。 + --max_repeated_len Max length of repeated chars to keep + 最大保留重复的字符个数。 +``` +通过下面脚本转化,我们可以得到处理好的预训练数据,token ids:`baike_sample.bin`, 文章索引信息`baike_sample.idx`. + +* 针对 llama 模型 +```shell +python -u create_pretraining_data.py \ + --model_name_or_path "idea-ccnl/ziya-llama-13b-v1" \ + --input_path "baike_sample.jsonl" \ + --output_prefix "baike_sample" \ + --data_format "JSON" \ + --json_key "text" \ + --data_impl "mmap" \ + --append_eos \ + --log_interval 5 \ + --workers 40 + +``` + +* 针对 ernie 模型 +```shell +python -u create_pretraining_data.py \ + --model_name_or_path "ernie-3.0-base-zh" \ + --input_path "baike_sample.jsonl" \ + --output_prefix "baike_sample" \ + --data_format "JSON" \ + --json_key "text" \ + --split_sentences \ + --data_impl "mmap" \ + --chinese \ + --cn_whole_word_segment \ + --cn_seg_func "jieba" \ + --log_interval 5 \ + --workers 40 +``` +1. 如果您使用已经分好词的语料,可以设置 --cn_splited 为 True,同时指定--cn_split_dimer 如空格。 +2. 使用自定义词表的话,请指定 model_name 为词表所在的文件夹地址。 + +若需要预处理的文件过大,该脚本所耗费的时间可能会很长。此时可以考虑将 jsonl 文件拆分为多个小文件,并行使用 create_pretraining_data.py 进行处理,得到多个.bin & .idx 文件。 +之后使用如下 merge 脚本合并多个小的.bin & .idx 文件。 +``` +python merge.py \ + --input /root/data \ + --output-prefix /root/data/merged \ + --data_impl mmap +``` +使用说明: +``` +arguments: + --input INPUT_PATH + Path to the folder where the files to be merged. + 待合并的文件所在文件夹,文件夹内各个小文件需按merge的顺序排列,如1.bin / 1.idx,2.bin / 2.idx... + --output_prefix OUTPUT_PREFIX + Output prefix to store output file. + 合并后输出文件的名称,假设名称为XXX,则会输出 XXX.bin, XXX.idx 两个文件。 + --data_impl {mmap,lazy} + Convert the json into mmap/lazy format. + merge前后的数据格式,可选“mmap”或“lazy,各个待merge的文件需格式一致。”。 +``` + +### 预训练开始 +得到了处理好的训练数据,就可以开始模型的预训练了。简单将预处理好的数据,拷贝到 data 目录,即可开始预训练。 +```shell +mkdir data +mv ./preprocess/baike_sample* ./data +``` + +* llama 预训练请参考[预训练](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm)。 +* ernie 预训练请参考[预训练](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/slm/model_zoo/ernie-1.0/pretraining_introduction.md)。 + + +代码说明: +- 动态 mask 相关代码实现在`./data_tools/dataset_utils.py` + 用户可以根据自己的需求,灵活修改 mask 方式。具体可以参考`dataset_utils.py`中`create_masked_lm_predictions`函数。 + 可以自定义的选项有 do_whole_word_mask, favor_longer_ngram, do_permutation, geometric_dist 等, + 可以参考[Megatron](https://github.com/NVIDIA/Megatron-LM)使用这些 lm_mask 策略。 + +## 参考内容 + +注: 大部分数据流程,参考自[Megatron](https://github.com/NVIDIA/Megatron-LM),特此表达感谢。 diff --git a/examples/pre-training/tools/data_preprocess/create_pretraining_data.py b/examples/pre-training/tools/preprocess/create_pretraining_data.py similarity index 99% rename from examples/pre-training/tools/data_preprocess/create_pretraining_data.py rename to examples/pre-training/tools/preprocess/create_pretraining_data.py index 8ad9d4ec7..ac47df84a 100644 --- a/examples/pre-training/tools/data_preprocess/create_pretraining_data.py +++ b/examples/pre-training/tools/preprocess/create_pretraining_data.py @@ -23,6 +23,7 @@ import numpy as np from tqdm import tqdm + from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer from paddleformers.data import indexed_dataset @@ -100,7 +101,7 @@ def get_args(): group.add_argument( "--cn_splited", action="store_true", - help="Is chinese corpus is split in to words.", + help="Is chinese corpus is splited in to words.", ) group.add_argument( "--cn_split_dimer", @@ -392,7 +393,7 @@ def main(): elif file_path.endswith(".jsonl"): text = open(file_path, "r", encoding="utf-8") else: - print("Unexpected data format, skipped %s" % file_path) + print("Unexpected data format, skiped %s" % file_path) continue encoded_docs = pool.imap(convert.encode, text, 256) diff --git a/examples/pre-training/tools/preprocess/docs/CLUECorpus2020.md b/examples/pre-training/tools/preprocess/docs/CLUECorpus2020.md new file mode 100644 index 000000000..3c6727fab --- /dev/null +++ b/examples/pre-training/tools/preprocess/docs/CLUECorpus2020.md @@ -0,0 +1,12 @@ +## CLUECorpus2020 语料 + +| 名称 | 文本类型 | 纯文本大小 | +|-|-|-| +| CLUECorpus2020| 中文 | 200GB | + +CLUECorpus2020 过对Common Crawl的中文部分进行语料清洗得到。开源部分提供了约200G左右的语料文本,详细介绍见[官网](https://github.com/CLUEbenchmark/CLUECorpus2020#%E6%95%B0%E6%8D%AE%E4%B8%8B%E8%BD%BD),用户可以通过邮件申请下载,方式如下: + +> 数据下载 +> 申请方式: 将使用语料研究目的和用途,计划、研究机构和申请者介绍,发送到邮箱,并承诺不向第三方提供。 +> +> 邮箱: CLUEbenchmark@163.com,标题是:CLUECorpus2020 200G语料库 diff --git a/examples/pre-training/tools/preprocess/docs/CLUECorpusSmall.md b/examples/pre-training/tools/preprocess/docs/CLUECorpusSmall.md new file mode 100644 index 000000000..c2173c3df --- /dev/null +++ b/examples/pre-training/tools/preprocess/docs/CLUECorpusSmall.md @@ -0,0 +1,76 @@ +# CLUECorpusSmall + +| 名称 | 文本类型 | 纯文本大小 | +|-|-|-| +| CLUECorpusSmall| 中文 | 14GB | + +**数据集简介**:可用于语言建模、预训练或生成型任务等,数据量超过14G,近4000个定义良好的 txt 文件、50亿个字。主要部分来自于 nlp_chinese_corpus 项目 +包含如下子语料库(总共14G 语料):新闻语料[news2016zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/6bac09db4e6d4857b6d680d34447457490cb2dbdd8b8462ea1780a407f38e12b?responseContentDisposition=attachment%3B%20filename%3Dnews2016zh_corpus.zip), 社区互动语料[webText2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/83da03f7b4974871a52348b41c16c7e3b34a26d5ca644f558df8435be4de51c3?responseContentDisposition=attachment%3B%20filename%3DwebText2019zh_corpus.zip),维基百科语料[wiki2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/d7a166408d8b4ffdaf4de9cfca09f6ee1e2340260f26440a92f78134d068b28f?responseContentDisposition=attachment%3B%20filename%3Dwiki2019zh_corpus.zip),评论数据语料[comment2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/b66ddd445735408383c42322850ac4bb82faf9cc611447c2affb925443de7a6d?responseContentDisposition=attachment%3B%20filename%3Dcomment2019zh_corpus.zip)。 + +## 数据获取 + +用户可以通过官方 github 网页下载,https://github.com/CLUEbenchmark/CLUECorpus2020 。同时,为方便用户,我们也提供了 aistudio 数据集下载地址。[part1](https://aistudio.baidu.com/aistudio/datasetdetail/60598),[part2](https://aistudio.baidu.com/aistudio/datasetdetail/124357)。使用 aistudio 版本的数据,下载好后,可以核对 md5值: +```shell +> md5sum ./* + 8a8be341ebce39cfe9524fb0b46b08c5 ./comment2019zh_corpus.zip + 4bdc2c941a7adb4a061caf273fea42b8 ./news2016zh_corpus.zip + fc582409f078b10d717caf233cc58ddd ./webText2019zh_corpus.zip + 157dacde91dcbd2e52a60af49f710fa5 ./wiki2019zh_corpus.zip +``` +解压文件 +```shell +unzip comment2019zh_corpus.zip -d clue_corpus_small_14g/comment2019zh_corpus +unzip news2016zh_corpus.zip -d clue_corpus_small_14g/news2016zh_corpus +unzip webText2019zh_corpus.zip -d clue_corpus_small_14g/webText2019zh_corpus +unzip wiki2019zh_corpus.zip -d clue_corpus_small_14g/wiki2019zh_corpus +``` +将 txt 文件转换为 jsonl 格式 +``` +python trans_to_json.py --input_path ./clue_corpus_small_14g --output_path clue_corpus_small_14g.jsonl +``` +现在我们得到了 jsonl 格式的数据集。 + +## 中文预训练数据制作 + +下面是针对训练任务的数据集应用。 + +* llama 为例 +```shell +python -u create_pretraining_data.py \ + --model_name "idea-ccnl/ziya-llama-13b-v1" \ + --input_path "clue_corpus_small_14g.jsonl" \ + --output_prefix "clue_corpus_small_14g" \ + --data_format "JSON" \ + --json_key "text" \ + --data_impl "mmap" \ + --append_eos \ + --log_interval 10000 \ + --workers 48 +``` + +* ernie 为例 +```shell +python -u create_pretraining_data.py \ + --model_name "ernie-3.0-base-zh" \ + --input_path "clue_corpus_small_14g.jsonl" \ + --output_prefix "clue_corpus_small_14g" \ + --data_format "JSON" \ + --json_key "text" \ + --split_sentences \ + --data_impl "mmap" \ + --chinese \ + --cn_whole_word_segment \ + --cn_seg_func "lac" \ + --log_interval 10000 \ + --workers 48 +``` + +- model_name 可以更换为[其他模型](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm)。 +- workers 表示转化的线程数目 + +数据共有文档`15702702`条左右,由于分词比较耗时,大概一小时左右可以完成。在当前目录下产出训练所需数据。 +``` +clue_corpus_small_14g.bin +clue_corpus_small_14g.idx +``` +用户可以使用此数据进行预训练任务。 diff --git a/examples/pre-training/tools/preprocess/docs/OpenWebText2.md b/examples/pre-training/tools/preprocess/docs/OpenWebText2.md new file mode 100644 index 000000000..74df70726 --- /dev/null +++ b/examples/pre-training/tools/preprocess/docs/OpenWebText2.md @@ -0,0 +1,42 @@ +# OpenWebText2 + +| 名称 | 文本类型 | 纯文本大小 | +|-|-|-| +| OpenWebText2 | 英文 | 70GB | + +## 数据获取 + +[OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/)是一个开源的英文网页文本数据集,数据来源于 Reddit,经过去重、清洗、提取,最终包含800多万个文档。 +本示例采用 EleutherAI 清洗好的[OpenWebText2数据](https://openwebtext2.readthedocs.io/en/latest/index.html#download-plug-and-play-version) + +下载以后通过以下命令解压: + +```shell +wget https://paddlenlp.bj.bcebos.com/models/transformers/gpt/openwebtext2.jsonl.zst.tar +tar -xvf openwebtext2.json.zst.tar -C /path/to/openwebtext +``` + +## Llama 训练数据制作 + +然后使用`create_pretraining_data.py`脚本进行数据集制作: +``` +python -u create_pretraining_data.py \ + --model_name meta-llama/Llama-2-7b \ + --tokenizer_name LlamaTokenizer \ + --data_format JSON \ + --input_path /path/to/openwebtext/ \ + --append_eos \ + --output_prefix llama_openwebtext \ + --workers 40 \ + --log_interval 10000 \ + --data_impl "mmap" +``` +处理时间约一个小时左右,就可以得到我们需要的`llama_openwebtext.bin`, `llama_openwebtext.idx`数据集文件。 + +将所有预处理得到的文件统一放入一个文件夹中,以备训练使用: + +``` +mkdir data +mv llama_openwebtext.bin ./data +mv llama_openwebtext.idx ./data +``` diff --git a/examples/pre-training/tools/preprocess/docs/WuDaoCorpusBase.md b/examples/pre-training/tools/preprocess/docs/WuDaoCorpusBase.md new file mode 100644 index 000000000..b809d89f0 --- /dev/null +++ b/examples/pre-training/tools/preprocess/docs/WuDaoCorpusBase.md @@ -0,0 +1,101 @@ +# WuDaoCorpus2.0 Base 语料 + + +| 名称 | 文本类型 | 纯文本大小 | +|-|-|-| +| WuDaoCorpus2.0 Base| 中文 | 200GB | + +WuDaoCorpora 是悟道爬取的中文大规模语料。整体数量为3TB,目前开源的部分为 WuDaoCorpus2.0 bases 数据集,大小为200GB。 + +## 数据获取 + +**1. 下载解压** + +用户[此处下载](https://www.scidb.cn/en/detail?dataSetId=c6a3fe684227415a9db8e21bac4a15ab),即可直接下载数据。下载好的压缩数据约 64GB。解压 +``` +unrar x WuDaoCorpus2.0_base_200G.rar +``` +**2. 语料分词** + +由于 WuDao 数据集比较大,分词比较耗时,这里先进行了语料分词: +```shell +python words_segmentation.py \ + --input_path ./WuDaoCorpus2.0_base_200G \ + --workers 40 \ + --data_format wudao \ + --cn_seg_func seg \ + --output_path ./wudao_lac_cut \ +``` + +注:预训练需要实现 SOP( Sentence Order Predict) 任务,在分词的同时,我们使用 简单规则 进行了文本断句。如果语料只有一句话,建议去除 SOP loss,训练时设置 `binary_head=False`。 + +**3. 转换为 jsonl 格式** + +文本转化完成后。我们使用 `../data_tools/trans_to_json.py`重新转换为 jsonl 格式(分词完毕)。 +```shell +python ./trans_to_json.py \ + --input_path ./wudao_lac_cut \ + --output_path wudao_corpus_200g.jsonl \ + --workers 40 +``` +在当前目录下产出数据`wudao_corpus_200g.jsonl`。格式如下: +``` +{"text": "主持人 : 作为 一个 曲线救国 的 路线 我们 没 办法 。\n金鑫 : 考试 和 分数 只是 一个 阶段性 的 评价 手段 , 不是 目的 , 就 像 人 活着 的 目的 不是 为了 吃饭 , 吃饭 是 为了 让 我们 活下去 , 我们 学习 的 目的 不是 为了 考试 , 不是 为了 那个 分数 , 而是 我 掌握 了 知识 , 成为 我 内在 的 能力 , 将来 我 去 创作 创造 工作 , 我能 把 它 做 得 更好 。\n主持人 : 特别感谢 金总 今天 接受 我 的 访谈 , 也 让 我 从 别的 层面 看到 了 一对一 到底 存在 的 道理 是 什么 , 并且 能 发展 那么 好 的 原因 在 哪里 。\n在 节目 后 您 谈谈 您 对 一对一 未来 的 希望 , 包括 您 对 它 未来 的 设想 是 什么 ?\n金鑫 : 一对一 个性化 教育 现在 还是 在 初级阶段 , 如果 是 四个 阶段 的话 , 现在 还是 在 第一阶段 到 第二阶段 迈进 的 , 学大 在 这方面 我们 希望 能 做 得 更 快 更 远 一些 。\n将来 个性化 教育 一定 是 能够 帮助 学生 在 成绩 上 的 提升 , 能够 更好 的 成长 , 进而 成为 对 社会 对 国家 更 有用 的 人才 , 就是 我们 的 成绩 、 成长 、 成才 。\n学大 1 对 1 教育 的 教师 团队 由 各科 优秀教师 、 考试 指导 专家 、 心理 辅导 专家 及 学习 方法 指导 专家 组成 , 同时 配备 专职 班主任 及 学习 监管 师 , 全方位 辅导 顺利 而 有序 的 运作 。\n其中 部分 教师 担任 多年 毕业班 教学 工作 , 多次 参与 中 考试 命题 研究 及 阅卷 工作 , 深谙 中 考试 精髓 , 能够 在 短 的 时间 内 引领 学生 掌握 中 考试 知识 重点 , 快速 提分 。\n■ 对于 成绩 差 的 学生 : 注重 学生 基础知识 , 力求 让 学生 在 基础 中 找 自信 , 在 自信 中 提升 ;\n注重 主观题 的 解题 方法 及 思路 , 以此 来 加强 对 基础知识 的 运用 。\n■ 对于 成绩 需要 拔高 的 学生 : 找出 学生 弱点 , 加强 基础 , 重点 提高 弱势 项目 。\n"} +{"text": "武田信玄 是 天生 的 武将 , 一生 开拓 了 八十五万 石至 九十余万 石之多 的 领地 。\n武田信玄 他 21 岁 时 流放 自己 的 父亲 武田信虎 至骏河 , 避免 父亲 传位 给 弟弟 , 从而 登上 了 第 19 代家督 之位 。\n他 将 信 浓国 ( 现 长野县 ) 纳入 控制 范围 后 , 又 与 当时 的 豪强 今井氏 、 北条 氏 结成 三国 军事同盟 , 与 上 杉谦信 在 川 中岛 前后 展开 了 五次 大战 。\n武田信玄 勇于 进攻 。\n他 连续 攻打 邻国 , 扩大 自己 势力范围 , 可称 遇神 杀神 , 遇佛 杀佛 。\n他 不仅 流放 了 自己 的 父亲 , 连 自己 的 嫡子 武田义信 因 与 他 在 战略 方向 上 相左 , 也 被 他 幽禁 于 佛寺 , 随即 被迫 自杀 。\n武田信玄 虽然 是 战国 武将 中 的 最强者 , 但 他 的 弱点 是 年龄 。\n信玄比 织田信长 年长 13 岁 , 比上 杉谦信 年长 9 岁 。\n当信 玄年 届 五十 之 时 , 信长 和 谦信 犹 在 壮年 。\n上杉谦信 而且 , 武田信玄 虽 驰骋 天下 , 却 未率 军 进过 京都 , 而 织田信长 在 永禄 十一年 ( 1568 年 ) 就 以 拥立 第 15 代 将军 足利义 昭 为名 率兵 上洛 了 。\n所谓 \" 制 京都 者 得 天下 \" , 所以 , 想要 一统天下 , 武田信玄 的 时间 很 紧迫 。\n元龟 三年 ( 1572 年 ) , 武田信玄 与 室 町 幕府 第 15 代 将军 足利义 昭 、 本愿 寺 显如 , 以及 浅井 氏 、 朝仓氏 等 反 织田信长 实力 组成 联盟 , 编织 \" 反信长 包围圈 \" 。\n同年 10 月 3 日 , 武田信玄 率领 大军 , 开始 了 第一次 上洛之行 。\n是 年 , 信玄 52 岁 , 这 也许 是 他 统一天下 的 最后 一次 机会 。\n武田信玄 所 率领 的 是 当时 战国 最强 的 3 万甲州 精兵 。\n打着 \" 风林火山 \" 的 旗帜 , 武田军 第一站 就 到达 了 织田信长 的 同盟 德川家康 所在 的 三河 远江 。\n织田信长 德川家康 的 军队 在 甲州 精兵 之前 显得 不堪一击 , 到 了 10 月 13 日 , 只来 成 、 天 方城 、 一 宫城 、 饭田 城 、 各和城 、 向 笠 城 等 城池 纷纷 被 攻陷 。\n德川家康 见势不妙 , 决定 在 浜松 城中 闭门不出 。\n但是 武田信玄 毫不 松懈 , 又 将 家康 在 远江 地区 的 重要 据点 二俣城 攻破 。\n德川家康 集合 所有 军队 共 1 万 1 千人 , 出城 与 信玄 决一死战 , 但 大败 而 还 , 险些 失 了 性命 。\n这次 战争 被 称为 \" 三方 原战 \" , 德川家康 曾经 承认 这次 战争 是 他 生平 最大 的 失败 。\n"} +``` + +## 中文预训练数据制作 + +下面是针对训练任务的数据集应用。 + +* llama 为例 + +注:若使用 llama 模型,则不需要提前进行分词,请将 WuDaoCorpus2.0_base_200G 中的 json 文件预处理为如下格式的 jsonl 文件: +``` +{"text": "飞桨是功能完备、开源开放的产业级深度学习平台。飞桨拥有..."} +{"text": "PaddleNLP是自然语言..."} +``` + +之后利用如下脚本将对应的 jsonl 文件转化为.bin & .idx 文件。 +```shell +python -u create_pretraining_data.py \ + --model_name "idea-ccnl/ziya-llama-13b-v1" \ + --input_path "wudao_corpus_200g.jsonl" \ + --output_prefix "wudao_corpus_200g" \ + --data_format "JSON" \ + --json_key "text" \ + --data_impl "mmap" \ + --append_eos \ + --log_interval 10000 \ + --workers 48 +``` + +* ernie 为例 +```shell +python -u create_pretraining_data.py \ + --model_name "ernie-3.0-base-zh" \ + --input_path "wudao_corpus_200g.jsonl" \ + --output_prefix "wudao_corpus_200g" \ + --data_format "JSON" \ + --json_key "text" \ + --split_sentences \ + --data_impl "mmap" \ + --chinese \ + --cn_whole_word_segment \ + --cn_seg_func "jieba" \ + --cn_splited \ + --log_interval 10000 \ + --workers 48 +``` + + +- 我们提前进行了分词,所以加上了 `cn_splited`,否则不需要使用此选项。 +- model_name 可以更换为[其他模型](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm)。 +- workers 表示转化的线程数目 + +在当前目录下产出训练所需数据。 +``` +wudao_corpus_200g.bin +wudao_corpus_200g.idx +``` +用户可以使用此数据进行预训练任务。 diff --git a/examples/pre-training/tools/preprocess/merge.py b/examples/pre-training/tools/preprocess/merge.py new file mode 100644 index 000000000..85681a5fc --- /dev/null +++ b/examples/pre-training/tools/preprocess/merge.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from datetime import datetime + +from paddleformers.data import indexed_dataset + + +def print_datetime(string): + time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + print("[" + string + "] datetime: {} ".format(time_str)) + + +def main(args): + + prefixes = set() + for basename in os.listdir(args.input): + prefix, ext = os.path.splitext(basename) + + if prefix in prefixes: + continue + + if not os.path.isfile(os.path.join(args.input, basename)): + continue + + ext_pair = ".bin" if ext == ".idx" else ".idx" + assert os.path.isfile( + os.path.join(args.input, prefix) + ext_pair + ), f"ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}" + + prefixes.add(prefix) + + builder = None + + for prefix in sorted(prefixes): + print_datetime(f"start processing file {prefix}") + if builder is None: + dataset = indexed_dataset.make_dataset( + os.path.join(args.input, prefix), args.data_impl + ) + + if isinstance(dataset, indexed_dataset.MMapIndexedDataset): + builder = indexed_dataset.MMapIndexedDatasetBuilder( + args.output_prefix + ".bin", dtype=dataset._index.dtype + ) + else: + builder = indexed_dataset.IndexedDatasetBuilder( + args.output_prefix + ".bin", dtype=dataset.dtype + ) + + del dataset + print_datetime(f"start merge file {prefix}") + builder.merge_file_(os.path.join(args.input, prefix)) + print_datetime(f"end merge file {prefix}") + + print_datetime("start finalize") + builder.finalize(args.output_prefix + ".idx") + print_datetime("end finalize") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + group = parser.add_argument_group(title="input data") + group.add_argument( + "--input", + type=str, + required=True, + help="Path to directory containing all document files to merge", + ) + group.add_argument("--data_impl", type=str, required=True, help="data_impl") + + group = parser.add_argument_group(title="output data") + group.add_argument( + "--output-prefix", + type=str, + required=True, + help="Path to binary output file without suffix", + ) + + args = parser.parse_args() + + assert os.path.isdir( + args.input + ), f"ERROR: {args.input} is not a directory or does not exist" + + assert os.path.isdir( + os.path.dirname(args.output_prefix) + ), f"ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist" + + main(args) diff --git a/examples/pre-training/tools/preprocess/trans_to_json.py b/examples/pre-training/tools/preprocess/trans_to_json.py new file mode 100644 index 000000000..68d95cd5c --- /dev/null +++ b/examples/pre-training/tools/preprocess/trans_to_json.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import multiprocessing +import os +import shutil +import sys +import time +from functools import partial + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_path", + type=str, + required=True, + help="Path to you raw files. Folder or file path.", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Path to save the output json files.", + ) + parser.add_argument( + "--json_key", type=str, default="text", help="The content key of json file." + ) + parser.add_argument( + "--doc_spliter", + type=str, + default="", + help="Spliter between documents. We will strip the line, if you use blank line to split doc, leave it blank.", + ) + parser.add_argument( + "--min_doc_length", type=int, default=10, help="Minimal char of a document." + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of worker processes to launch" + ) + parser.add_argument( + "--log_interval", type=int, default=1, help="Interval between progress updates." + ) + parser.add_argument("--no-merge", action="store_true", help="Don't merge the file.") + parser.add_argument( + "--no-shuffle", action="store_true", help="Don't shuffle the file." + ) + args = parser.parse_args() + return args + + +def raw_text_to_json(path, doc_spliter="", json_key="text", min_doc_length=10): + path = os.path.abspath(path) + if not os.path.exists(path): + print("No found file %s" % path) + return 0, None + + out_filepath = path + ".jsonl" + fout = open(out_filepath, "w", encoding="utf-8") + len_files = 0 + with open(path, "r") as f: + doc = "" + line = f.readline() + while line: + len_files += len(line) + if line.strip() == doc_spliter: + if len(doc) > min_doc_length: + fout.write(json.dumps({json_key: doc}, ensure_ascii=False) + "\n") + doc = "" + else: + doc += line + line = f.readline() + + if len(doc) > min_doc_length: + fout.write(json.dumps({json_key: doc}, ensure_ascii=False) + "\n") + doc = "" + + return len_files, out_filepath + + +def merge_file(file_paths, output_path): + if not output_path.endswith(".jsonl"): + output_path = output_path + ".jsonl" + print("Merging files into %s" % output_path) + with open(output_path, "wb") as wfd: + for f in file_paths: + if f is not None and os.path.exists(f): + with open(f, "rb") as fd: + shutil.copyfileobj(fd, wfd) + os.remove(f) + print("File save in %s" % output_path) + return output_path + + +def shuffle_file(output_path): + print("Shuffling the jsonl file...") + if os.path.exists(output_path): + os.system("shuf %s -o %s" % (output_path, output_path)) + print("File shuffled!!!") + else: + raise ValueError("File not found: %s" % output_path) + + +def main(): + args = get_args() + startup_start = time.time() + + file_paths = [] + if os.path.isfile(args.input_path): + file_paths.append(args.input_path) + else: + for root, _, fs in os.walk(args.input_path): + for f in fs: + file_paths.append(os.path.join(root, f)) + + pool = multiprocessing.Pool(args.workers) + + startup_end = time.time() + proc_start = time.time() + total_bytes_processed = 0 + print("Time to startup:", startup_end - startup_start) + + trans_json = partial( + raw_text_to_json, + doc_spliter=args.doc_spliter, + json_key=args.json_key, + min_doc_length=args.min_doc_length, + ) + encoded_files = pool.imap(trans_json, file_paths, 1) + + out_paths = [] + for i, (bytes_processed, out_path) in enumerate(encoded_files, start=1): + total_bytes_processed += bytes_processed + out_paths.append(out_path) + + if i % args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + print( + f"Processed {i} files", + f"({i/elapsed} files/s, {mbs} MB/s).", + file=sys.stderr, + ) + + if not args.no_merge: + output_path = merge_file(out_paths, args.output_path) + if not args.no_shuffle: + shuffle_file(output_path) + + +if __name__ == "__main__": + main() + # profile.run("main()", "testprof") diff --git a/examples/pre-training/tools/preprocess/words_segmentation.py b/examples/pre-training/tools/preprocess/words_segmentation.py new file mode 100644 index 000000000..067239430 --- /dev/null +++ b/examples/pre-training/tools/preprocess/words_segmentation.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import multiprocessing +import os +import re +import sys +import time +from functools import partial + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_path", + type=str, + required=True, + help="Path to you raw files. Folder or file path.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./tmp", + help="Path to save the output json files.", + ) + parser.add_argument( + "--data_format", + type=str, + default="jsonl", + choices=["jsonl", "wudao"], + help="Path to you raw files. Folder or file path.", + ) + parser.add_argument( + "--cn_seg_func", + type=str, + default="jieba", + choices=["lac", "seg", "jieba"], + help="Words segment function for chinese words.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of worker processes to launch" + ) + parser.add_argument( + "--log_interval", type=int, default=1, help="Interval between progress updates." + ) + args = parser.parse_args() + return args + + +def lexical_analysis_fn(): + from LAC import LAC + + lac = LAC(mode="lac") + + def process(line): + words, _ = lac.run(line) + return words + + return process + + +def chinese_segmentation_fn(): + from LAC import LAC + + lac_cws = LAC(mode="seg") + + def process(line): + words = lac_cws.run(line) + return words + + return process + + +def jieba_segmentation_fn(): + import jieba + + def process(line): + words = jieba.cut(line) + return list(words) + + return process + + +CHINESE_SEG_FUNC = { + "lac": lexical_analysis_fn(), + "seg": chinese_segmentation_fn(), + "jieba": jieba_segmentation_fn(), +} + + +def read_wudao(path): + print("Loading %s" % path) + with open(path, "r") as f: + try: + contents = json.load(f) + except Exception: + print("Failed to load %s" % path) + raise StopIteration + for js in contents: + yield js["content"] + + +def read_jsonl(path): + print("Loading %s" % path) + with open(path, "r") as f: + line = f.readline() + while line: + contents = json.load(f) + yield contents["text"] + line = f.readline() + + +READFILE_FUNC = { + "jsonl": read_jsonl, + "wudao": read_wudao, +} + +special_chars = ["\n", "。", "?", "?", " ", ";", ";", "!", "!"] +split_chars = ["。", "?", "?", ";", ";", "!", "!"] + + +def text_to_text(path, output_path, read_func, seg_func): + out_name = os.path.join(output_path, path[-20:]) + + print("Write into %s" % out_name) + if os.path.exists(out_name): + print("File exists %s" % out_name) + return 0, None + + seg_func = CHINESE_SEG_FUNC[seg_func] + read_func = READFILE_FUNC[read_func] + + data_len = 0 + count = 0 + with open(out_name, "w") as f: + for text in read_func(path): + # for js in contents: + count += 1 + # text = js["content"] + data_len += len(text.encode("utf-8")) + # make special char only once, + # because of those token will be treat as sentence spliter. + # 此处为断句逻辑 + for char in special_chars: + text = re.sub("[" + char + "]+[ ]*", char, text) + for char in split_chars: + text = text.replace(char, char + "\n") + + # 此处为分词逻辑 + final = "" + for line in text.split("\n"): + if len(line) == 0: + continue + words = seg_func(line) + final += " ".join(words) + "\n" + f.write(final + "\n") + + return data_len, None + + +def main(): + args = get_args() + startup_start = time.time() + + file_paths = [] + if os.path.isfile(args.input_path): + file_paths.append(args.input_path) + else: + for root, _, fs in os.walk(args.input_path): + for f in fs: + file_paths.append(os.path.join(root, f)) + + pool = multiprocessing.Pool(args.workers) + + startup_end = time.time() + proc_start = time.time() + total_bytes_processed = 0 + print("Time to startup:", startup_end - startup_start) + + if not os.path.exists(args.output_path): + os.makedirs(args.output_path) + + trans_func = partial( + text_to_text, + output_path=args.output_path, + seg_func=args.cn_seg_func, + read_func=args.data_format, + ) + + encoded_files = pool.imap(trans_func, file_paths, 1) + + out_paths = [] + for i, (bytes_processed, out_path) in enumerate(encoded_files, start=1): + total_bytes_processed += bytes_processed + out_paths.append(out_path) + + if i % args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + print( + f"Processed {i} files", + f"({i/elapsed} files/s, {mbs} MB/s).", + file=sys.stderr, + ) + pool.close() + + +if __name__ == "__main__": + main()