Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions extension/llm/runner/irunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ struct GenerationConfig {
// Whether to echo the input prompt in the output
bool echo = true;

// Whether to ignore EOS token and continue generating until max_new_tokens
bool ignore_eos = false;

// Maximum number of new tokens to generate
// If the max_context_len metadata that's serialized in the .pte file exists,
// then the number of prompt tokens + max_new_tokens won't exceed
Expand Down
3 changes: 3 additions & 0 deletions extension/llm/runner/multimodal_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ Error MultimodalRunner::generate(
"Max new tokens %d is less than or equal to 0",
max_new_tokens);

// Set ignore_eos based on config
text_token_generator_->set_ignore_eos(config.ignore_eos);

// Generate tokens using the text token generator
std::vector<uint64_t> prompt_tokens = {prefill_next_token};
auto generate_result = text_token_generator_->generate(
Expand Down
3 changes: 3 additions & 0 deletions extension/llm/runner/text_llm_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ Error TextLLMRunner::generate(
// start the main loop
prompt_tokens.push_back(cur_token);

// Set ignore_eos based on config
text_token_generator_->set_ignore_eos(config.ignore_eos);

// Generate max_new_tokens - 1 because prefill already generated 1 token.
auto generate_result = text_token_generator_->generate(
prompt_tokens,
Expand Down
7 changes: 6 additions & 1 deletion extension/llm/runner/text_token_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class ET_EXPERIMENTAL TextTokenGenerator {
use_kv_cache_(use_kv_cache),
stats_(stats) {}

void set_ignore_eos(bool ignore_eos) {
ignore_eos_ = ignore_eos;
}

virtual ~TextTokenGenerator() = default;

/**
Expand Down Expand Up @@ -125,7 +129,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
}

// data-dependent terminating condition: stop once cur_token is one of the
// EOS ids, unless ignore_eos_ is set
if (eos_ids_->find(cur_token) != eos_ids_->end()) {
if (!ignore_eos_ && eos_ids_->find(cur_token) != eos_ids_->end()) {
printf("\n");
ET_LOG(Info, "\nReached to the end of generation");
break;
Expand Down Expand Up @@ -169,6 +173,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
TextDecoderRunner* text_decoder_runner_;
std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
bool use_kv_cache_;
bool ignore_eos_ = false;

// state machine
bool should_stop_ = false;
Expand Down
Loading