Description
Hello,
When using packing with the bfd strategy, it looks like too much truncation is applied when seq_length is smaller than the average length of the sequences we want to pack.
For example:

```python
from datasets import Dataset
from trl import pack_dataset

examples = {
    "input_ids": [[1, 2, 3, 4], [5, 6], [7, 8, 9], [10]],
    "attention_mask": [[1, 1, 1, 1], [1, 0], [1, 0, 0], [1]],
}
dataset = Dataset.from_dict(examples)
packed_dataset = pack_dataset(dataset, seq_length=3, strategy="bfd")
print(packed_dataset[:])
```

results in:

```
{'input_ids': [[1, 2, 3], [7, 8, 9], [5, 6, 10]],
 'attention_mask': [[1, 1, 1], [1, 0, 0], [1, 0, 1]],
 'seq_lengths': [[3], [3], [2, 1]]}
```

So the token '4' is missing from the training tokens.
In an extreme case:

```python
examples_2 = {
    "input_ids": [[0, 0], [1, 2, 3, 4], [5, 6, 7, 8, 9], [10]],
    "attention_mask": [[1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1]],
}
dataset_2 = Dataset.from_dict(examples_2)
print(pack_dataset(dataset_2, seq_length=1, strategy="bfd")[:])
```

results in:

```
{'input_ids': [[0], [1], [5], [10]],
 'attention_mask': [[1], [1], [1], [1]],
 'seq_lengths': [[1], [1], [1], [1]]}
```

So here we are basically applying truncation to every sequence instead of getting twelve sequences of one token each.
To put this in a more realistic setting: when I was fine-tuning on some very long sequences with a seq_length of 4096, the majority of the tokens were discarded by the bfd packing. On my dataset, the bfd method kept only 0.2% of the total training tokens.
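For reference, here is a minimal sketch of how such a retention rate can be measured (the `retained_token_ratio` helper is just for illustration, not part of trl):

```python
from datasets import Dataset
from trl import pack_dataset

def retained_token_ratio(dataset, seq_length, strategy="bfd"):
    # Ratio of tokens that survive packing to tokens in the original dataset.
    total = sum(len(ids) for ids in dataset["input_ids"])
    packed = pack_dataset(dataset, seq_length=seq_length, strategy=strategy)
    kept = sum(len(ids) for ids in packed["input_ids"])
    return kept / total

# On the first toy example above: 9 of the 10 tokens survive (token '4' is dropped).
print(retained_token_ratio(dataset, seq_length=3))  # 0.9
```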
Is this behavior normal?
I would find it useful to add an option that, instead of dropping the truncated tokens, keeps them in other sequences, even if this is less than ideal. It would be a good compromise between the current bfd and wrapped strategies.
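To illustrate the kind of compromise I have in mind, here is a rough sketch (the `chunk_examples` helper is hypothetical, not something trl provides): split every sequence into chunks of at most seq_length before BFD packing, so no tokens are silently dropped:

```python
from datasets import Dataset
from trl import pack_dataset

def chunk_examples(examples, seq_length):
    # Split each over-long sequence into consecutive chunks of at most seq_length tokens.
    out = {key: [] for key in examples}
    for idx in range(len(examples["input_ids"])):
        length = len(examples["input_ids"][idx])
        for start in range(0, length, seq_length):
            for key in examples:
                out[key].append(examples[key][idx][start:start + seq_length])
    return out

examples = {
    "input_ids": [[1, 2, 3, 4], [5, 6], [7, 8, 9], [10]],
    "attention_mask": [[1, 1, 1, 1], [1, 1], [1, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
chunked = dataset.map(chunk_examples, batched=True, fn_kwargs={"seq_length": 3})
packed = pack_dataset(chunked, seq_length=3, strategy="bfd")
print(packed[:])  # all 10 tokens are kept; nothing is truncated
```

The pre-chunking changes the boundaries of the original sequences, which is why I call it a compromise, but it avoids throwing tokens away the way bfd currently does.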