diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp index b655a619b26..3de47598426 100644 --- a/examples/models/llama/main.cpp +++ b/examples/models/llama/main.cpp @@ -67,6 +67,11 @@ DEFINE_int32( DEFINE_bool(warmup, false, "Whether to run a warmup run."); +DEFINE_bool( + ignore_eos, + false, + "Whether to ignore EOS token and continue generating until max_new_tokens is reached."); + DEFINE_string( etdump_path, "etdump.in", @@ -165,6 +170,8 @@ int32_t main(int32_t argc, char** argv) { executorch::extension::llm::GenerationConfig config{ .temperature = temperature}; + config.ignore_eos = FLAGS_ignore_eos; + if (FLAGS_max_new_tokens != -1) { config.max_new_tokens = FLAGS_max_new_tokens; } else { diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 095f82f75bb..33f92a5bd3f 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -238,6 +238,7 @@ void start_runner( }; executorch::extension::llm::GenerationConfig config{ true, + false, -1, false, FLAGS_seq_len, diff --git a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp index 0b3f2ee4ad1..7cadc0bb0dd 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp @@ -301,6 +301,7 @@ void start_multimodal_runner( // Configure generation executorch::extension::llm::GenerationConfig config{ true, + false, -1, false, FLAGS_seq_len, diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h index ef93f32319c..e50c6cdc074 100644 --- a/extension/llm/runner/irunner.h +++ b/extension/llm/runner/irunner.h @@ -28,6 +28,9 @@ struct GenerationConfig { // Whether to echo the input prompt in the output bool echo = true; + // Whether to ignore EOS token and continue generating until max_new_tokens + bool ignore_eos = false; + // Maximum number of new tokens to generate // If the max_context_len metadata that's serialized in the .pte file exists, // then the number of prompt tokens + max_new_tokens won't exceed diff --git a/extension/llm/runner/multimodal_runner.cpp b/extension/llm/runner/multimodal_runner.cpp index 5c0c1e658a7..96d14e2a855 100644 --- a/extension/llm/runner/multimodal_runner.cpp +++ b/extension/llm/runner/multimodal_runner.cpp @@ -194,6 +194,9 @@ Error MultimodalRunner::generate( "Max new tokens %d is less than or equal to 0", max_new_tokens); + // Set ignore_eos based on config + text_token_generator_->set_ignore_eos(config.ignore_eos); + // Generate tokens using the text token generator std::vector prompt_tokens = {prefill_next_token}; auto generate_result = text_token_generator_->generate( diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index d3dda02cb88..92dbced9560 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -193,6 +193,9 @@ Error TextLLMRunner::generate( // start the main loop prompt_tokens.push_back(cur_token); + // Set ignore_eos based on config + text_token_generator_->set_ignore_eos(config.ignore_eos); + // Generate max_new_tokens - 1 because prefill already generated 1 token. auto generate_result = text_token_generator_->generate( prompt_tokens, diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index b7fca420bc3..128de05d1d9 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -32,6 +32,10 @@ class ET_EXPERIMENTAL TextTokenGenerator { use_kv_cache_(use_kv_cache), stats_(stats) {} + void set_ignore_eos(bool ignore_eos) { + ignore_eos_ = ignore_eos; + } + virtual ~TextTokenGenerator() = default; /** @@ -125,7 +129,7 @@ class ET_EXPERIMENTAL TextTokenGenerator { } // data-dependent terminating condition: we have n_eos_ number of EOS - if (eos_ids_->find(cur_token) != eos_ids_->end()) { + if (!ignore_eos_ && eos_ids_->find(cur_token) != eos_ids_->end()) { printf("\n"); ET_LOG(Info, "\nReached to the end of generation"); break; @@ -169,6 +173,7 @@ class ET_EXPERIMENTAL TextTokenGenerator { TextDecoderRunner* text_decoder_runner_; std::unique_ptr> eos_ids_; bool use_kv_cache_; + bool ignore_eos_ = false; // state machine bool should_stop_ = false;