diff --git a/Makefile b/Makefile index 097536fd0a6..459f2f50c11 100644 --- a/Makefile +++ b/Makefile @@ -127,7 +127,7 @@ help: @echo " llava-cpu - Build Llava runner with CPU backend" @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" - @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" + @echo " gemma4_31b-cuda - Build Gemma 4 31B runner + OpenAI serving worker with CUDA backend" @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner + OpenAI serving worker (CUDA)" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @@ -444,11 +444,13 @@ qwen3_5_moe-cuda: gemma4_31b-cuda: @echo "==> Building and installing ExecuTorch with CUDA..." cmake --workflow --preset llm-release-cuda - @echo "==> Building Gemma 4 31B runner with CUDA..." + @echo "==> Building Gemma 4 31B runner + serving worker with CUDA..." cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda @echo "" @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" + @echo " Serving worker: cmake-out/examples/models/gemma4_31b/gemma4_31b_worker" + @echo " Launch: see examples/models/gemma4_31b/README.md (Serving)" gemma4_31b-mlx: @echo "==> Building and installing ExecuTorch with MLX..." diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt index 52419eb95bc..02100eb2f03 100644 --- a/examples/models/gemma4_31b/CMakeLists.txt +++ b/examples/models/gemma4_31b/CMakeLists.txt @@ -15,6 +15,9 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) set(_common_include_directories ${EXECUTORCH_ROOT}/..) +set(_json_include + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include +) # gflags set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) @@ -58,9 +61,13 @@ endif() # Tokenizer (HuggingFace tokenizer.json) list(APPEND link_libraries tokenizers::tokenizers) -add_executable(gemma4_31b_runner main.cpp) +if(EXECUTORCH_BUILD_CUDA) + add_executable(gemma4_31b_runner main.cpp gemma4_31b_engine.cpp) +else() + add_executable(gemma4_31b_runner main.cpp) +endif() target_include_directories( - gemma4_31b_runner PUBLIC ${_common_include_directories} + gemma4_31b_runner PUBLIC ${_common_include_directories} ${_json_include} ) target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries}) @@ -71,6 +78,21 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") endif() endif() +if(EXECUTORCH_BUILD_CUDA) + add_executable(gemma4_31b_worker gemma4_31b_worker.cpp gemma4_31b_engine.cpp) + target_include_directories( + gemma4_31b_worker PUBLIC ${_common_include_directories} ${_json_include} + ) + target_link_libraries(gemma4_31b_worker PUBLIC ${link_libraries}) + + if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(gemma4_31b_worker) + if(NOT APPLE AND NOT MSVC) + target_link_options(gemma4_31b_worker PRIVATE "LINKER:-s") + endif() + endif() +endif() + if(TARGET mlxdelegate) executorch_target_copy_mlx_metallib(gemma4_31b_runner) endif() diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json index 23a7d42e035..5d6019f1911 100644 --- a/examples/models/gemma4_31b/CMakePresets.json +++ b/examples/models/gemma4_31b/CMakePresets.json @@ -13,7 +13,7 @@ }, { "name": "gemma4-31b-cuda", - "displayName": "Gemma 4 31B runner (CUDA)", + "displayName": "Gemma 4 31B runner + serving worker (CUDA)", "inherits": ["gemma4-31b-base"], "cacheVariables": { "EXECUTORCH_BUILD_CUDA": "ON" @@ -39,9 +39,9 @@ "buildPresets": [ { "name": "gemma4-31b-cuda", - "displayName": "Build Gemma 4 31B runner (CUDA)", + "displayName": "Build Gemma 4 31B runner + serving worker (CUDA)", "configurePreset": "gemma4-31b-cuda", - "targets": ["gemma4_31b_runner"] + "targets": ["gemma4_31b_runner", "gemma4_31b_worker"] }, { "name": "gemma4-31b-mlx", @@ -53,7 +53,7 @@ "workflowPresets": [ { "name": "gemma4-31b-cuda", - "displayName": "Configure and build Gemma 4 31B runner (CUDA)", + "displayName": "Configure and build Gemma 4 31B runner + serving worker (CUDA)", "steps": [ { "type": "configure", diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index ae3bcb24c19..384fd671a76 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -139,11 +139,12 @@ model produces sensible text. ## Build the runner ```bash -make gemma4_31b-cuda # Linux — CUDA backend -make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon) +make gemma4_31b-cuda # Linux — CUDA runner + serving worker +make gemma4_31b-mlx # macOS — MLX runner (serving later) ``` -The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. +The CUDA build also produces +`cmake-out/examples/models/gemma4_31b/gemma4_31b_worker`. ## Run the .pte @@ -162,3 +163,29 @@ Pass `--raw_prompt` to skip template wrapping for pre-formatted input. For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA graph (decode is fully static — `T=1`). + +## Serving + +The CUDA OpenAI-compatible server is a Python control plane plus a C++ model worker. +The worker owns the ExecuTorch model and speaks the shared JSONL protocol used by +the generic LLM server. + +```bash +LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \ +python -m executorch.examples.models.gemma4_31b.serve \ + --model-path ./gemma4_31b_exports/model.pte \ + --data-path ./gemma4_31b_exports/aoti_cuda_blob.ptd \ + --tokenizer-path ./gemma4_31b_int4/tokenizer.json \ + --hf-tokenizer ./gemma4_31b_int4 \ + --model-id gemma4-31b \ + --max-sessions 1 +``` + +The launcher defaults to Gemma's `<|tool_call>call:...` parser. Use +`--tool-parser hermes`, `--tool-parser qwen`, or `--tool-parser none` if the +model/template you are testing emits a different tool-call format. + +Named sessions and warm resume require worker capacity above one. CUDA exports +with `get_mutable_buffer_metadata` can use per-session mutable rebinding and +advertise `--max-sessions > 1`; older exports fail closed to a single scratch +session. MLX serving is intentionally left for a later change. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index d84e2c03a7f..0b4baa88fb1 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -24,6 +24,7 @@ """ import argparse +import json import os import torch @@ -135,6 +136,11 @@ def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: # Export + lower +def _mutable_buffer_metadata(model: nn.Module) -> str: + mutable = [name for name, _ in model.named_buffers() if ".kv_cache." in name] + return json.dumps({"version": 1, "mutable_buffers": mutable}) + + def export_and_lower( model: Gemma4_31B, config: Gemma4_31BConfig, @@ -181,6 +187,7 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 materialize_runtime_buffers(model, dtype=torch.bfloat16) + mutable_buffer_metadata = _mutable_buffer_metadata(model) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). @@ -248,6 +255,8 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - "get_vocab_size": config.vocab_size, "get_n_layers": config.num_hidden_layers, "get_max_prefill_chunk": max_prefill, + "get_min_prefill_chunk": 5, + "get_mutable_buffer_metadata": mutable_buffer_metadata, "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, diff --git a/examples/models/gemma4_31b/gemma4_31b_engine.cpp b/examples/models/gemma4_31b/gemma4_31b_engine.cpp new file mode 100644 index 00000000000..f7ed42718d0 --- /dev/null +++ b/examples/models/gemma4_31b/gemma4_31b_engine.cpp @@ -0,0 +1,598 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifdef EXECUTORCH_BUILD_CUDA +#include +#include +#include +#endif + +namespace executorch::extension::llm { + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::TensorPtr; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; +using SizesType = executorch::aten::SizesType; + +namespace { + +Result read_sampled_token( + const executorch::aten::Tensor& output, + float temperature) { +#ifdef EXECUTORCH_BUILD_CUDA + (void)temperature; + const void* ptr = output.const_data_ptr(); + cudaPointerAttributes attrs{}; + const bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + float val = 0.0f; + if (on_device) { + if (cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost) != + cudaSuccess) { + ET_LOG(Error, "read_sampled_token: cudaMemcpy D2H failed"); + return Error::Internal; + } + } else { + std::memcpy(&val, ptr, sizeof(float)); + } + return static_cast(llrintf(val)); +#else + (void)output; + (void)temperature; + return Error::NotSupported; +#endif +} + +Result> build_gemma_module( + const Gemma4_31BConfig& config) { + std::vector data_files; + if (!config.data_path.empty()) { + data_files.push_back(config.data_path); + } + auto module = std::make_unique( + config.model_path, + data_files, + Module::LoadMode::MmapUseMlockIgnoreErrors, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + +#ifdef EXECUTORCH_BUILD_CUDA + if (config.enable_cuda_graph) { + executorch::runtime::BackendOptions<2> cuda_opts; + ET_CHECK_OK_OR_RETURN_ERROR( + cuda_opts.set_option("enable_cuda_graph_for_method", "decode")); + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::runtime::set_option("CudaBackend", cuda_opts.view())); + } + { + executorch::runtime::BackendOptions<1> backend_options; + ET_CHECK_OK_OR_RETURN_ERROR( + backend_options.set_option("weight_sharing_across_methods", true)); + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::runtime::set_option("CudaBackend", backend_options.view())); + } + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("prefill")); + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("decode")); +#else + (void)module; + ET_LOG(Error, "Gemma4_31BEngine is implemented for CUDA only"); + return Error::NotSupported; +#endif + return module; +} + +void add_token_piece( + ::tokenizers::Tokenizer* tokenizer, + std::unordered_set& ids, + const char* piece) { + if (auto id = tokenizer->piece_to_id(piece); id.ok()) { + ids.insert(*id); + } +} + +#ifdef EXECUTORCH_BUILD_CUDA +Error register_mutable_fqns(Module* module, int mutable_ctx) { + auto res = module->execute("get_mutable_buffer_metadata"); + if (res.error() != Error::Ok) { + ET_LOG( + Info, "Gemma4_31BEngine: no mutable-buffer metadata; capacity stays 1"); + return res.error(); + } + const auto& outs = res.get(); + if (outs.empty() || !outs[0].isString()) { + ET_LOG(Error, "get_mutable_buffer_metadata did not return a string"); + return Error::InvalidProgram; + } + std::string json_str(outs[0].toString()); + auto j = nlohmann::json::parse(json_str, nullptr, /*allow_exceptions=*/false); + if (j.is_discarded() || !j.is_object() || j.value("version", 0) != 1 || + !j.contains("mutable_buffers") || !j["mutable_buffers"].is_array()) { + ET_LOG(Error, "get_mutable_buffer_metadata has invalid schema"); + return Error::InvalidProgram; + } + std::vector fqns; + for (const auto& f : j["mutable_buffers"]) { + if (!f.is_string() || f.get().empty()) { + ET_LOG(Error, "mutable_buffers entries must be non-empty strings"); + return Error::InvalidProgram; + } + fqns.push_back(f.get()); + } + if (fqns.empty()) { + ET_LOG(Error, "mutable_buffers must be non-empty for multi-session"); + return Error::InvalidProgram; + } + ::executorch::backends::cuda::mutable_state_register_fqns(mutable_ctx, fqns); + return Error::Ok; +} +#endif + +class Gemma4_31BSession : public LLMSession { + public: + Gemma4_31BSession( + Module* module, + std::mutex* exec_mutex, + int mutable_ctx, + int session_token, + std::atomic* live_sessions, + ::tokenizers::Tokenizer* tokenizer, + std::unordered_map metadata, + std::unordered_set eos_ids, + int64_t max_prefill_chunk, + int64_t min_prefill_chunk) + : module_(module), + exec_mutex_(exec_mutex), + mutable_ctx_(mutable_ctx), + session_token_(session_token), + live_sessions_(live_sessions), + tokenizer_(tokenizer), + metadata_(std::move(metadata)), + eos_ids_(std::move(eos_ids)), + max_prefill_chunk_(max_prefill_chunk), + min_prefill_chunk_(min_prefill_chunk) { + decode_tokens_ = from_blob( + decode_token_data_, {1, 1}, executorch::aten::ScalarType::Long); + decode_pos_ = + from_blob(decode_pos_data_, {1}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + temp_tensor_ = + from_blob(&temp_val_, {1}, executorch::aten::ScalarType::Float); +#endif + } + + ~Gemma4_31BSession() override { +#ifdef EXECUTORCH_BUILD_CUDA + if (session_token_ != ::executorch::backends::cuda::kNoMutableSession) { + ::executorch::backends::cuda::mutable_state_destroy_session( + mutable_ctx_, session_token_); + } +#endif + if (live_sessions_ != nullptr) { + live_sessions_->fetch_sub(1); + } + } + + Error prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling) override { + if (tokens.empty()) { + return Error::InvalidArgument; + } + float first_token_temp = temperature_; + if (initial_sampling != nullptr) { + if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 || + initial_sampling->seed != 0) { + ET_LOG( + Error, + "Gemma4_31BSession: only temperature is supported; top_p/top_k/seed " + "are not implemented"); + return Error::NotSupported; + } + first_token_temp = initial_sampling->temperature; + } + const int64_t T = static_cast(tokens.size()); + const auto ctx_it = metadata_.find(kMaxContextLen); + if (ctx_it != metadata_.end() && pos_ + T >= ctx_it->second) { + ET_LOG(Error, "prefill_tokens would leave no room to generate"); + return Error::InvalidArgument; + } + + stop_.store(false, std::memory_order_relaxed); + int64_t offset = 0; + while (offset < T) { + int64_t chunk = T - offset; + if (max_prefill_chunk_ > 0) { + chunk = std::min(chunk, max_prefill_chunk_); + } +#ifdef EXECUTORCH_BUILD_CUDA + if (chunk > 1 && chunk < min_prefill_chunk_) { + chunk = 1; + } +#endif + auto sampled = + run_prefill_chunk(tokens.data() + offset, chunk, first_token_temp); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + pending_ = sampled.get(); + pos_ += chunk; + offset += chunk; + } + prev_decode_token_ = tokens.back(); + return Error::Ok; + } + + Result decode_one(const SamplingConfig& sampling) override { + if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + ET_LOG( + Error, + "Gemma4_31BSession: only temperature is supported; top_p/top_k/seed " + "are not implemented"); + return Error::NotSupported; + } + ET_CHECK_OR_RETURN_ERROR( + pending_.has_value(), + InvalidState, + "decode_one requires a pending token; call prefill_tokens() first"); + temperature_ = sampling.temperature; + + const uint64_t token = pending_.value(); + const bool is_eos = eos_ids_.find(token) != eos_ids_.end(); + const uint64_t prev = prev_decode_token_.value_or(token); + auto dec = tokenizer_->decode(prev, token); + if (!dec.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(dec.error())); + return Error::InvalidArgument; + } + std::string text_piece = std::move(*dec); + + if (is_eos || stop_.load(std::memory_order_relaxed)) { + pending_.reset(); + return DecodeResult{ + token, std::move(text_piece), is_eos, /*is_terminal=*/true}; + } + + const auto ctx_it = metadata_.find(kMaxContextLen); + if (ctx_it != metadata_.end()) { + ET_CHECK_OR_RETURN_ERROR( + pos_ < ctx_it->second, + InvalidArgument, + "decode_one would exceed context capacity"); + } + + decode_token_data_[0] = static_cast(token); + decode_pos_data_[0] = pos_; + std::vector inputs; + inputs.push_back(EValue(decode_tokens_)); + inputs.push_back(EValue(decode_pos_)); +#ifdef EXECUTORCH_BUILD_CUDA + set_temp(temperature_); + inputs.push_back(EValue(temp_tensor_)); + const char* method = "decode"; +#else + (void)inputs; + return Error::NotSupported; +#endif + auto sampled = + run_locked(method, inputs, temperature_, /*sync_after=*/false); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + pending_ = sampled.get(); + prev_decode_token_ = token; + pos_ += 1; + return DecodeResult{ + token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false}; + } + + Error seek(int64_t pos) override { + (void)pos; + return Error::NotSupported; + } + + int64_t position() const override { + return pos_; + } + + Error reset() override { + pos_ = 0; + pending_.reset(); + prev_decode_token_.reset(); + stop_.store(false, std::memory_order_relaxed); + return Error::Ok; + } + + void stop() override { + stop_.store(true, std::memory_order_relaxed); + } + + private: +#ifdef EXECUTORCH_BUILD_CUDA + void set_temp(float t) { + temp_val_ = (t <= 0.0f) ? 1e-6f : t; + } +#endif + + Result + run_prefill_chunk(const uint64_t* tokens, int64_t T, float temperature) { + std::vector token_data(tokens, tokens + T); + std::vector pos_data(T); + for (int64_t i = 0; i < T; ++i) { + pos_data[i] = pos_ + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, static_cast(T)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {static_cast(T)}, + executorch::aten::ScalarType::Long); + std::vector inputs; + inputs.push_back(EValue(tokens_tensor)); + inputs.push_back(EValue(pos_tensor)); +#ifdef EXECUTORCH_BUILD_CUDA + set_temp(temperature); + inputs.push_back(EValue(temp_tensor_)); + const char* method = (T >= min_prefill_chunk_) ? "prefill" : "decode"; +#else + (void)inputs; + (void)temperature; + return Error::NotSupported; +#endif + return run_locked(method, inputs, temperature, /*sync_after=*/true); + } + + Result run_locked( + const char* method, + std::vector& inputs, + float temperature, + bool sync_after) { + std::lock_guard guard(*exec_mutex_); +#ifdef EXECUTORCH_BUILD_CUDA + if (mutable_ctx_ != 0) { + ::executorch::backends::cuda::mutable_state_set_active( + mutable_ctx_, session_token_); + } +#endif + auto res = module_->execute(method, inputs); +#ifdef EXECUTORCH_BUILD_CUDA + if (mutable_ctx_ != 0) { + ::executorch::backends::cuda::mutable_state_set_active( + mutable_ctx_, ::executorch::backends::cuda::kNoMutableSession); + } +#endif + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + auto sampled = read_sampled_token(res.get()[0].toTensor(), temperature); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); +#ifdef EXECUTORCH_BUILD_CUDA + if (sync_after && cudaDeviceSynchronize() != cudaSuccess) { + ET_LOG(Error, "run_locked: cudaDeviceSynchronize failed"); + return Error::Internal; + } +#else + (void)sync_after; +#endif + return sampled.get(); + } + + Module* module_; + std::mutex* exec_mutex_; + int mutable_ctx_; + int session_token_; + std::atomic* live_sessions_; + ::tokenizers::Tokenizer* tokenizer_; + std::unordered_map metadata_; + std::unordered_set eos_ids_; + int64_t max_prefill_chunk_; + int64_t min_prefill_chunk_; + + int64_t pos_ = 0; + std::optional pending_; + std::optional prev_decode_token_; + float temperature_ = -1.0f; + std::atomic stop_{false}; + + int64_t decode_token_data_[1] = {0}; + int64_t decode_pos_data_[1] = {0}; + TensorPtr decode_tokens_; + TensorPtr decode_pos_; +#ifdef EXECUTORCH_BUILD_CUDA + float temp_val_ = 1e-6f; + TensorPtr temp_tensor_; +#endif +}; + +} // namespace + +Result> Gemma4_31BEngine::create( + const Gemma4_31BConfig& config) { + if (config.model_path.empty() || config.tokenizer_path.empty()) { + ET_LOG( + Error, "Gemma4_31BEngine: model_path and tokenizer_path are required"); + return Error::InvalidArgument; + } + + auto tokenizer = std::make_unique<::tokenizers::HFTokenizer>(); + if (tokenizer->load(config.tokenizer_path) != ::tokenizers::Error::Ok) { + ET_LOG(Error, "Gemma4_31BEngine: failed to load tokenizer"); + return Error::InvalidArgument; + } + + std::vector data_files; + if (!config.data_path.empty()) { + data_files.push_back(config.data_path); + } + auto meta_module = std::make_unique( + config.model_path, data_files, Module::LoadMode::File); + auto metadata_result = get_llm_metadata(tokenizer.get(), meta_module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Gemma4_31BEngine: failed to read metadata"); + return metadata_result.error(); + } + + auto eos_ids = get_eos_ids(tokenizer.get(), meta_module.get()); + eos_ids.insert(static_cast(config.eos_id)); + add_token_piece(tokenizer.get(), eos_ids, ""); + + const auto& metadata = metadata_result.get(); + int64_t max_prefill_chunk = 1; + auto max_ctx_it = metadata.find(kMaxContextLen); + if (max_ctx_it != metadata.end() && max_ctx_it->second > 1) { + max_prefill_chunk = max_ctx_it->second - 1; + } + if (auto get_result = meta_module->get("get_max_prefill_chunk"); + get_result.ok()) { + max_prefill_chunk = get_result->toScalar().to(); + } + int64_t min_prefill_chunk = 1; +#ifdef EXECUTORCH_BUILD_CUDA + min_prefill_chunk = 5; + if (auto get_result = meta_module->get("get_min_prefill_chunk"); + get_result.ok()) { + min_prefill_chunk = get_result->toScalar().to(); + } +#endif + + bool registered_mutable = false; + int mutable_ctx = 0; +#ifdef EXECUTORCH_BUILD_CUDA + if (!config.enable_cuda_graph) { + mutable_ctx = ::executorch::backends::cuda::mutable_state_create_context(); + if (register_mutable_fqns(meta_module.get(), mutable_ctx) == Error::Ok) { + registered_mutable = true; + ::executorch::backends::cuda::mutable_state_begin_load(mutable_ctx); + } else { + ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx); + mutable_ctx = 0; + } + } +#endif + + auto module_res = build_gemma_module(config); +#ifdef EXECUTORCH_BUILD_CUDA + if (registered_mutable) { + ::executorch::backends::cuda::mutable_state_end_load(); + } +#endif + if (module_res.error() != Error::Ok) { +#ifdef EXECUTORCH_BUILD_CUDA + if (mutable_ctx != 0) { + ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx); + } +#endif + return module_res.error(); + } + + bool rebind_available = false; +#ifdef EXECUTORCH_BUILD_CUDA + if (mutable_ctx != 0) { + rebind_available = + ::executorch::backends::cuda::mutable_state_available(mutable_ctx); + if (rebind_available && + ::executorch::backends::cuda::mutable_state_validate_coverage( + mutable_ctx) != Error::Ok) { + ET_LOG( + Error, + "Gemma4_31BEngine: mutable-buffer coverage check failed; disabling " + "multi-session"); + rebind_available = false; + } + } +#endif + + return std::unique_ptr(new Gemma4_31BEngine( + config, + std::move(tokenizer), + metadata, + std::move(eos_ids), + std::move(module_res.get()), + max_prefill_chunk, + min_prefill_chunk, + rebind_available, + mutable_ctx)); +} + +Gemma4_31BEngine::~Gemma4_31BEngine() { +#ifdef EXECUTORCH_BUILD_CUDA + if (mutable_ctx_ != 0) { + ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx_); + } +#endif +} + +Result> Gemma4_31BEngine::create_session() { + const int cap = + serving_capacity().max_physical_sessions_without_weight_duplication; + { + std::lock_guard g(exec_mutex_); + if (live_sessions_.load() >= cap) { + return Error::InvalidState; + } + live_sessions_.fetch_add(1); + } + + int token = -1; +#ifdef EXECUTORCH_BUILD_CUDA + if (rebind_available_) { + auto t = ::executorch::backends::cuda::mutable_state_create_session( + mutable_ctx_); + if (t.error() != Error::Ok) { + live_sessions_.fetch_sub(1); + return t.error(); + } + token = t.get(); + } +#endif + return std::unique_ptr(new Gemma4_31BSession( + shared_module_.get(), + &exec_mutex_, + mutable_ctx_, + token, + &live_sessions_, + tokenizer_.get(), + metadata_, + eos_ids_, + max_prefill_chunk_, + min_prefill_chunk_)); +} + +LLMServingCapacity Gemma4_31BEngine::serving_capacity() const { + LLMServingCapacity cap; +#ifdef EXECUTORCH_BUILD_CUDA + if (rebind_available_) { + cap.max_physical_sessions_without_weight_duplication = + config_.max_sessions > 1 ? config_.max_sessions : 1; + cap.estimated_bytes_per_session = + ::executorch::backends::cuda::mutable_state_bytes_per_session( + mutable_ctx_); + } +#endif + return cap; +} + +} // namespace executorch::extension::llm diff --git a/examples/models/gemma4_31b/gemma4_31b_engine.h b/examples/models/gemma4_31b/gemma4_31b_engine.h new file mode 100644 index 00000000000..92eaf1b02da --- /dev/null +++ b/examples/models/gemma4_31b/gemma4_31b_engine.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +struct Gemma4_31BConfig { + std::string model_path; + std::string data_path; + std::string tokenizer_path; + int32_t max_sessions = 1; + int64_t eos_id = 1; + bool enable_cuda_graph = false; +}; + +class ET_EXPERIMENTAL Gemma4_31BEngine : public LLMEngine { + public: + static ::executorch::runtime::Result> + create(const Gemma4_31BConfig& config); + + ~Gemma4_31BEngine() override; + + ::executorch::runtime::Result> create_session() + override; + + LLMServingCapacity serving_capacity() const override; + + const std::unordered_map& metadata() const override { + return metadata_; + } + + ::tokenizers::Tokenizer* tokenizer() const { + return tokenizer_.get(); + } + + Gemma4_31BEngine(const Gemma4_31BEngine&) = delete; + Gemma4_31BEngine& operator=(const Gemma4_31BEngine&) = delete; + + private: + Gemma4_31BEngine( + Gemma4_31BConfig config, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unordered_map metadata, + std::unordered_set eos_ids, + std::unique_ptr shared_module, + int64_t max_prefill_chunk, + int64_t min_prefill_chunk, + bool rebind_available, + int mutable_ctx) + : config_(std::move(config)), + tokenizer_(std::move(tokenizer)), + metadata_(std::move(metadata)), + eos_ids_(std::move(eos_ids)), + shared_module_(std::move(shared_module)), + max_prefill_chunk_(max_prefill_chunk), + min_prefill_chunk_(min_prefill_chunk), + rebind_available_(rebind_available), + mutable_ctx_(mutable_ctx) {} + + Gemma4_31BConfig config_; + std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; + std::unordered_map metadata_; + std::unordered_set eos_ids_; + std::unique_ptr shared_module_; + std::mutex exec_mutex_; + int64_t max_prefill_chunk_ = 0; + int64_t min_prefill_chunk_ = 1; + bool rebind_available_ = false; + int mutable_ctx_ = 0; + std::atomic live_sessions_{0}; +}; + +} // namespace executorch::extension::llm diff --git a/examples/models/gemma4_31b/gemma4_31b_worker.cpp b/examples/models/gemma4_31b/gemma4_31b_worker.cpp new file mode 100644 index 00000000000..4d979003d9c --- /dev/null +++ b/examples/models/gemma4_31b/gemma4_31b_worker.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +DEFINE_string(model_path, "", "Model .pte file path."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); +DEFINE_string(data_path, "", "Data file (.ptd) for delegated weights."); +DEFINE_int32( + max_sessions, + 1, + "Max physical sessions to host on one weight allocation. CUDA may raise " + "this when per-session mutable rebinding is available."); +DEFINE_bool( + warm_resume, + true, + "Warm append-only resume for named sessions when the engine supports them."); +DEFINE_int32(bos_id, 2, "BOS token id to prepend to server-rendered prompts."); +DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); + +namespace { +namespace llm = ::executorch::extension::llm; +using ::executorch::runtime::Error; +} // namespace + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (FLAGS_model_path.empty() || FLAGS_tokenizer_path.empty()) { + ET_LOG( + Error, "gemma4_31b_worker: --model_path and --tokenizer_path required"); + return 1; + } + + llm::Gemma4_31BConfig config; + config.model_path = FLAGS_model_path; + config.data_path = FLAGS_data_path; + config.tokenizer_path = FLAGS_tokenizer_path; + config.max_sessions = FLAGS_max_sessions; + config.eos_id = FLAGS_eos_id; + + auto engine_result = llm::Gemma4_31BEngine::create(config); + if (engine_result.error() != Error::Ok) { + ET_LOG(Error, "gemma4_31b_worker: failed to create engine"); + return 1; + } + auto engine = std::move(engine_result.get()); + + return llm::run_worker_stdio_loop( + *engine, + *engine->tokenizer(), + engine->metadata(), + FLAGS_warm_resume, + {static_cast(FLAGS_bos_id)}); +} diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 6cf65cc8246..7a3dfcf89ba 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -6,6 +6,174 @@ * LICENSE file in the root directory of this source tree. */ +#ifdef EXECUTORCH_BUILD_CUDA + +// Thin CUDA CLI over Gemma4_31BEngine / LLMSession. The non-CUDA legacy runner +// remains below for the existing MLX target; serving is CUDA-only for now. + +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +DEFINE_string(model_path, "", "Model .pte file path."); +DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); +DEFINE_string(prompt, "Hello", "Prompt text."); +DEFINE_string( + prompt_file, + "", + "Path to file containing prompt text (overrides --prompt)."); +DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); +DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); +DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); +DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); +DEFINE_bool( + raw_prompt, + false, + "Skip chat-template wrapping (use if the prompt is already formatted)."); +DEFINE_bool( + cuda_graph, + false, + "Enable CUDA graph capture for the decode method. CUDA only."); + +namespace llm = ::executorch::extension::llm; +using ::executorch::runtime::Error; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty()) { + ET_LOG(Error, "Must specify --model_path"); + return 1; + } + if (FLAGS_tokenizer_path.empty()) { + ET_LOG(Error, "Must specify --tokenizer_path"); + return 1; + } + + llm::Stats stats; + size_t gpu_free_bytes = 0, gpu_total_bytes = 0; + if (cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes) == cudaSuccess) { + stats.gpu_total_bytes = gpu_total_bytes; + stats.gpu_free_before_load_bytes = gpu_free_bytes; + } + + stats.model_load_start_ms = llm::time_in_ms(); + + llm::Gemma4_31BConfig config; + config.model_path = FLAGS_model_path; + config.data_path = FLAGS_data_path; + config.tokenizer_path = FLAGS_tokenizer_path; + config.eos_id = FLAGS_eos_id; + config.enable_cuda_graph = FLAGS_cuda_graph; + + printf("Loading methods...\n"); + auto engine_result = llm::Gemma4_31BEngine::create(config); + if (engine_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to create Gemma 4 31B engine"); + return 1; + } + auto engine = std::move(engine_result.get()); + + auto session_result = engine->create_session(); + if (session_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to create session"); + return 1; + } + auto session = std::move(session_result.get()); + + stats.model_load_end_ms = llm::time_in_ms(); + if (cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes) == cudaSuccess) { + stats.gpu_free_after_load_bytes = gpu_free_bytes; + } + + std::string prompt_text = FLAGS_prompt; + if (!FLAGS_prompt_file.empty()) { + std::ifstream f(FLAGS_prompt_file); + if (!f.is_open()) { + ET_LOG( + Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str()); + return 1; + } + prompt_text = std::string( + (std::istreambuf_iterator(f)), std::istreambuf_iterator()); + } + + if (!FLAGS_raw_prompt) { + prompt_text = "<|turn>user\n" + prompt_text + + "\n<|turn>model\n<|channel>thought\n"; + } + + auto encode_result = engine->tokenizer()->encode(prompt_text); + if (!encode_result.ok()) { + ET_LOG(Error, "Failed to encode prompt"); + return 1; + } + auto prompt_tokens = std::move(*encode_result); + prompt_tokens.insert( + prompt_tokens.begin(), static_cast(FLAGS_bos_id)); + const int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + stats.num_prompt_tokens = num_prompt_tokens; + + llm::SamplingConfig sampling; + sampling.temperature = static_cast(FLAGS_temperature); + stats.inference_start_ms = llm::time_in_ms(); + if (session->prefill_tokens(prompt_tokens, &sampling) != Error::Ok) { + ET_LOG(Error, "Prefill failed"); + return 1; + } + stats.prompt_eval_end_ms = llm::time_in_ms(); + stats.first_token_ms = stats.prompt_eval_end_ms; + + int64_t num_generated = 0; + for (int32_t step = 0; step < FLAGS_max_new_tokens; ++step) { + auto step_result = session->decode_one(sampling); + if (step_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + const auto& d = step_result.get(); + if (d.is_terminal) { + break; + } + if (step == 0) { + stats.first_token_ms = llm::time_in_ms(); + } + ++num_generated; + if (!d.text_piece.empty()) { + fwrite(d.text_piece.data(), 1, d.text_piece.size(), stdout); + fflush(stdout); + } + } + printf("\n"); + + stats.inference_end_ms = llm::time_in_ms(); + stats.num_generated_tokens = num_generated; + if (cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes) == cudaSuccess) { + stats.gpu_free_after_generate_bytes = gpu_free_bytes; + stats.gpu_peak_usage_mb = + (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0; + } + + llm::print_report(stats); + return 0; +} + +#else + // Gemma 4 31B-IT runner for ExecuTorch. Supports two backends: // CUDA — exports ``prefill`` (T>=2, dynamic) + ``decode`` (T=1, static) // methods sharing KV-cache buffers; on-device Gumbel-max sampling @@ -416,3 +584,5 @@ int main(int argc, char** argv) { llm::print_report(stats); return 0; } + +#endif // EXECUTORCH_BUILD_CUDA diff --git a/examples/models/gemma4_31b/serve.py b/examples/models/gemma4_31b/serve.py new file mode 100644 index 00000000000..1e6f38606ee --- /dev/null +++ b/examples/models/gemma4_31b/serve.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible HTTP server for Gemma 4 31B on CUDA.""" + +import argparse +import logging +import os +import re +from pathlib import Path + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.session_runtime import SessionRuntime +from executorch.extension.llm.server.python.tool_parsers import ( + GemmaToolCallDetector, + HermesDetector, + QwenFunctionCallDetector, +) +from executorch.extension.llm.server.python.worker_client import spawn_worker + +logger = logging.getLogger(__name__) + +_GEMMA_CHANNEL_SPECIALS = {"<|channel>", "", "<|think|>"} +_GEMMA_CHANNEL_BLOCK = re.compile(r"<\|channel>.*?", re.DOTALL) + + +def _strip_gemma_channels(text: str) -> str: + text = _GEMMA_CHANNEL_BLOCK.sub("", text) + open_idx = text.find("<|channel>") + if open_idx != -1: + text = text[:open_idx] + return text.replace("", "").replace("<|think|>", "").strip() + + +def _default_worker_bin() -> str: + repo_root = Path(__file__).resolve().parents[3] + return str( + repo_root + / "cmake-out" + / "examples" + / "models" + / "gemma4_31b" + / "gemma4_31b_worker" + ) + + +def _spawn(args): + env = dict(os.environ) + conda = os.environ.get("CONDA_PREFIX") + if conda: + env["LD_LIBRARY_PATH"] = f"{conda}/lib:" + env.get("LD_LIBRARY_PATH", "") + worker_bin = args.worker_bin or _default_worker_bin() + cmd = [ + worker_bin, + "--model_path", + args.model_path, + "--tokenizer_path", + args.tokenizer_path, + "--max_sessions", + str(args.max_sessions), + f"--warm_resume={'true' if args.warm_resume else 'false'}", + "--bos_id", + str(args.bos_id), + "--eos_id", + str(args.eos_id), + ] + if args.data_path: + cmd += ["--data_path", args.data_path] + logger.info("Starting Gemma4 31B worker subprocess...") + return spawn_worker(cmd, env=env) + + +def _tool_detector(name: str): + if name == "gemma": + return GemmaToolCallDetector + if name == "hermes": + return HermesDetector + if name == "qwen": + return QwenFunctionCallDetector + if name == "none": + return None + raise ValueError(f"unknown tool parser: {name}") + + +def build_app_from_args(args): + template = ChatTemplate( + args.hf_tokenizer, + assistant_header="<|turn>model\n", + # Gemma's HF template starts with bos_token text. Strip that text before + # C++ tokenization; the worker prepends the numeric BOS id. + strip_rendered_bos=True, + ) + worker = _spawn(args) + runtime = SessionRuntime(worker) + serving = ServingChat( + runtime, + template, + args.model_id, + max_context=args.max_context, + tool_detector_cls=_tool_detector(args.tool_parser), + prompt_token_offset=1, + content_filter=_strip_gemma_channels, + content_filter_specials=_GEMMA_CHANNEL_SPECIALS, + ) + + from executorch.extension.llm.server.python.server import build_app + + app = build_app(serving, args.model_id) + + @app.on_event("shutdown") + def _stop_worker(): + runtime.close_worker() + + return app, args.model_id + + +def main() -> None: + p = argparse.ArgumentParser( + description="OpenAI-compatible CUDA LLM server for Gemma 4 31B" + ) + p.add_argument("--model-path", required=True, help="Path to the .pte model") + p.add_argument("--data-path", default=None, help="Path to the .ptd delegate blob") + p.add_argument("--tokenizer-path", required=True, help="Path to the tokenizer.json") + p.add_argument( + "--hf-tokenizer", + required=True, + help="HF tokenizer id/dir for the model's chat template", + ) + p.add_argument("--model-id", default="gemma4-31b") + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, default=8000) + p.add_argument("--max-context", type=int, default=None) + p.add_argument( + "--num-runners", + type=int, + default=1, + help="Worker processes. 1 only; more would duplicate the weights.", + ) + p.add_argument( + "--max-sessions", + type=int, + default=1, + help="Isolated sessions the CUDA worker may host when the export has " + "mutable-buffer metadata.", + ) + p.add_argument( + "--warm-resume", + action=argparse.BooleanOptionalAction, + default=True, + help="Warm append-only resume for named sessions when available.", + ) + p.add_argument( + "--tool-parser", + choices=("gemma", "hermes", "qwen", "none"), + default="gemma", + help="Tool-call format parser to apply to model output.", + ) + p.add_argument( + "--bos-id", + type=int, + default=2, + help="BOS token id to prepend in the worker. The launcher strips the " + "HF template's literal before C++ tokenization.", + ) + p.add_argument("--eos-id", type=int, default=1) + p.add_argument( + "--worker-bin", + default=None, + help="Path to the gemma4_31b_worker binary.", + ) + args = p.parse_args() + logging.basicConfig(level=logging.INFO) + + if args.num_runners != 1: + p.error("Only 1 worker process is supported; more would duplicate weights.") + + app, _ = build_app_from_args(args) + + import uvicorn + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/test_ondevice_serving.py b/examples/models/gemma4_31b/test_ondevice_serving.py new file mode 100644 index 00000000000..f71fb14fa8c --- /dev/null +++ b/examples/models/gemma4_31b/test_ondevice_serving.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import urllib.request + +import pytest + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.protocol import ChatMessage + +_SERVER = os.environ.get("GEMMA_SERVER_URL") +_HF_DIR = os.environ.get( + "GEMMA_HF_DIR", "/home/mnachin/local/scripts/models/gemma-4-31B-it-HQQ-INT4" +) + + +pytestmark = pytest.mark.skipif( + not _SERVER or not os.path.isdir(_HF_DIR), + reason="set GEMMA_SERVER_URL and GEMMA_HF_DIR to run Gemma on-device tests", +) + + +def _post(path: str, payload: dict) -> dict: + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + _SERVER.rstrip("/") + path, + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=120) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def test_prompt_tokens_match_real_template_with_numeric_bos_prefix(): + _ = pytest.importorskip("transformers") + from transformers import AutoTokenizer + + template = ChatTemplate( + _HF_DIR, + assistant_header="<|turn>model\n", + strip_rendered_bos=True, + ) + tok = AutoTokenizer.from_pretrained(_HF_DIR) + messages = [ChatMessage(role="user", content="Say ok.")] + rendered = template.render(messages) + expected_ids = [tok.bos_token_id] + tok.encode(rendered, add_special_tokens=False) + + body = _post( + "/v1/chat/completions", + { + "model": "gemma4_31b", + "messages": [{"role": "user", "content": "Say ok."}], + "max_tokens": 1, + "temperature": 0, + "session_id": "gemma-bos-regression", + }, + ) + assert body["usage"]["prompt_tokens"] == len(expected_ids) diff --git a/examples/models/gemma4_31b/test_serve.py b/examples/models/gemma4_31b/test_serve.py new file mode 100644 index 00000000000..aa1e6c8d318 --- /dev/null +++ b/examples/models/gemma4_31b/test_serve.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pathlib +from types import SimpleNamespace + +import pytest + +from executorch.examples.models.gemma4_31b import serve + +_HERE = pathlib.Path(serve.__file__).resolve().parent +_REPO_ROOT = _HERE.parents[2] + + +def test_generic_server_does_not_reference_gemma4_31b(): + server_dir = _REPO_ROOT / "extension/llm/server" + offenders = [p for p in server_dir.rglob("*.py") if "gemma4_31b" in p.read_text()] + assert offenders == [] + + +def test_control_plane_runs_no_model_code(): + serve_src = (_HERE / "serve.py").read_text() + assert "Gemma4_31BEngine" not in serve_src + worker_src = (_HERE / "gemma4_31b_worker.cpp").read_text() + assert "Gemma4_31BEngine" in worker_src + + +def test_spawn_builds_worker_command(monkeypatch): + captured = {} + + def fake_spawn(cmd, env=None): + captured["cmd"] = cmd + return object() + + monkeypatch.setattr(serve, "spawn_worker", fake_spawn) + serve._spawn( + SimpleNamespace( + worker_bin="/bin/gemma_worker", + model_path="m.pte", + tokenizer_path="t.json", + data_path="d.ptd", + max_sessions=4, + warm_resume=True, + bos_id=2, + eos_id=1, + ) + ) + assert captured["cmd"] == [ + "/bin/gemma_worker", + "--model_path", + "m.pte", + "--tokenizer_path", + "t.json", + "--max_sessions", + "4", + "--warm_resume=true", + "--bos_id", + "2", + "--eos_id", + "1", + "--data_path", + "d.ptd", + ] + + +def test_spawn_defaults_worker_bin_and_omits_empty_data_path(monkeypatch): + captured = {} + monkeypatch.setattr( + serve, "spawn_worker", lambda cmd, env=None: captured.update(cmd=cmd) + ) + serve._spawn( + SimpleNamespace( + worker_bin=None, + model_path="m.pte", + tokenizer_path="t.json", + data_path=None, + max_sessions=1, + warm_resume=False, + bos_id=2, + eos_id=1, + ) + ) + cmd = captured["cmd"] + assert cmd[0].endswith("gemma4_31b_worker") + assert "--data_path" not in cmd + assert "--warm_resume=false" in cmd + + +def test_strip_gemma_channels_returns_visible_answer(): + text = "<|channel>thought\nscratch work\nThe answer." + assert serve._strip_gemma_channels(text) == "The answer." + + +def test_strip_gemma_channels_cuts_unclosed_channel(): + assert serve._strip_gemma_channels("Lead <|channel>thought") == "Lead" + + +def test_strip_gemma_channels_removes_stray_close(): + assert serve._strip_gemma_channels("Visible") == "Visible" + + +def test_rejects_multiple_runners(monkeypatch): + import sys + + monkeypatch.setattr( + sys, + "argv", + [ + "serve.py", + "--model-path", + "m.pte", + "--tokenizer-path", + "t.json", + "--hf-tokenizer", + "hf", + "--num-runners", + "2", + ], + ) + with pytest.raises(SystemExit): + serve.main() diff --git a/extension/llm/server/cpp/test_worker_loop.cpp b/extension/llm/server/cpp/test_worker_loop.cpp index a14d6518681..7c3a5d95ff9 100644 --- a/extension/llm/server/cpp/test_worker_loop.cpp +++ b/extension/llm/server/cpp/test_worker_loop.cpp @@ -58,6 +58,7 @@ class FakeSession : public LLMSession { int prefill_calls = 0; std::vector prefill_sizes; // size of each prefill_tokens() call + std::vector> prefill_batches; int fail_prefill_on = -1; // 0-based call index to fail (-1 = never) int decode_calls = 0; int fail_decode_on = -1; @@ -68,6 +69,7 @@ class FakeSession : public LLMSession { std::vector tokens, const SamplingConfig* /*initial_sampling*/ = nullptr) override { prefill_sizes.push_back(tokens.size()); + prefill_batches.push_back(tokens); if (prefill_calls++ == fail_prefill_on) { return ETError::Internal; // failed AFTER (notionally) mutating state } @@ -173,13 +175,14 @@ Emitted run( WorkerSessionState& st, bool warm, const nlohmann::json& req, - const std::unordered_map& md = {}) { + const std::unordered_map& md = {}, + const std::vector& prefix = {}) { static FakeTokenizer tok; std::ostringstream cap; std::streambuf* old = std::cout.rdbuf(cap.rdbuf()); Emitted em; try { - worker_handle_request(st, warm, tok, md, req); + worker_handle_request(st, warm, tok, md, req, prefix); } catch (const std::exception&) { em.threw = true; } @@ -325,6 +328,23 @@ void test_generated_token_ids_excludes_terminal() { st.resident_token_ids.size() == (size_t)st.session->position()); } +void test_prompt_prefix_ids_prepend_text_prompt_once() { + auto st = makeState(); + fake(st).steps = {{0, "", true, true}}; + auto em = + run(st, + /*warm=*/true, + {{"max_new_tokens", 1}, {"prompt", "ab"}}, + {}, + {2}); + check("prefix: prompt_tokens includes prefix", em.done["prompt_tokens"] == 3); + check( + "prefix: prefilled ids == [2,'a','b']", + fake(st).prefill_batches == + std::vector>{ + {2, static_cast('a'), static_cast('b')}}); +} + void test_stop_string_marks_dirty_and_omits_ids() { auto st = makeState(); fake(st).steps = { @@ -430,6 +450,7 @@ int main() { test_equal_prompt_no_empty_prefill(); test_anonymous_never_warm(); test_generated_token_ids_excludes_terminal(); + test_prompt_prefix_ids_prepend_text_prompt_once(); test_stop_string_marks_dirty_and_omits_ids(); test_prefill_failure_marks_dirty(); test_decode_failure_marks_dirty(); diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h index 3cf4541a4e2..47be1320029 100644 --- a/extension/llm/server/cpp/worker_loop.h +++ b/extension/llm/server/cpp/worker_loop.h @@ -105,7 +105,8 @@ inline void worker_handle_request( bool warm, ::tokenizers::Tokenizer& tokenizer, const std::unordered_map& metadata, - const nlohmann::json& req) { + const nlohmann::json& req, + const std::vector& prompt_prefix_ids = {}) { LLMSession& session = *st.session; int64_t max_new = req.value("max_new_tokens", static_cast(-1)); const float temperature = req.value("temperature", 0.0f); @@ -129,7 +130,7 @@ inline void worker_handle_request( throw std::runtime_error( "exactly one of prompt / prompt_segments is required"); } - std::vector ids; + std::vector ids = prompt_prefix_ids; auto encode_text = [&](const std::string& text) { auto enc = tokenizer.encode(text, /*bos=*/0, /*eos=*/0); if (!enc.ok()) { @@ -397,7 +398,8 @@ inline int run_worker_stdio_loop( LLMEngine& engine, ::tokenizers::Tokenizer& tokenizer, const std::unordered_map& metadata, - bool enable_warm_resume = true) { + bool enable_warm_resume = true, + const std::vector& prompt_prefix_ids = {}) { WorkerSessions sessions(engine); worker_emit( {{"ready", true}, @@ -465,7 +467,8 @@ inline int run_worker_stdio_loop( } warm = enable_warm_resume; } - worker_handle_request(*st, warm, tokenizer, metadata, req); + worker_handle_request( + *st, warm, tokenizer, metadata, req, prompt_prefix_ids); } catch (const std::exception& e) { // report and keep serving worker_emit({{"error", std::string(e.what())}}); } diff --git a/extension/llm/server/python/chat_template.py b/extension/llm/server/python/chat_template.py index 1235f6fcf2c..92d869e6898 100644 --- a/extension/llm/server/python/chat_template.py +++ b/extension/llm/server/python/chat_template.py @@ -37,9 +37,16 @@ "<|end|>", "<|end_of_text|>", "", + "", "", ) +_TOOL_RESPONSE_GENERATION_PROMPT_MARKERS = ( + "", + "\n", + "\n", +) + def _content_text(content) -> str: """Best-effort text for the ChatML fallback: a str as-is, or the concatenated @@ -87,10 +94,26 @@ def __init__( hf_tokenizer_path: Optional[str] = None, default_template_kwargs: Optional[dict[str, Any]] = None, allow_fallback: bool = False, + assistant_header: str = "<|im_start|>assistant\n", + strip_rendered_prefix: str = "", + strip_rendered_bos: bool = False, + append_generation_prompt_after_tool_response: bool = False, + tool_response_generation_prompt_markers: Optional[tuple[str, ...]] = None, ): + if strip_rendered_prefix and strip_rendered_bos: + raise ValueError("use either strip_rendered_prefix or strip_rendered_bos") # Server-level defaults (e.g. {"enable_thinking": False}); per-request # chat_template_kwargs override these. self._defaults = default_template_kwargs or {} + self._assistant_header = assistant_header + self._strip_rendered_prefix = strip_rendered_prefix + self._append_generation_prompt_after_tool_response = ( + append_generation_prompt_after_tool_response + ) + self._tool_response_generation_prompt_markers = ( + tool_response_generation_prompt_markers + or _TOOL_RESPONSE_GENERATION_PROMPT_MARKERS + ) # Cache of the (deterministic) generation scaffold per resolved mode, so # warm-resume bookkeeping doesn't re-render a probe prompt every request. self._preamble_cache: dict[tuple, str] = {} @@ -110,11 +133,20 @@ def __init__( "No chat_template at %s; using approximate ChatML.", hf_tokenizer_path, ) + if strip_rendered_bos: + bos = getattr(self._hf, "bos_token", None) + if not isinstance(bos, str) or not bos: + raise ValueError( + "strip_rendered_bos requires a tokenizer with bos_token" + ) + self._strip_rendered_prefix = bos elif not allow_fallback: raise ValueError( "A chat template is required: pass --hf-tokenizer for the model's own " "template, or opt into approximate ChatML with --allow-chatml-fallback." ) + elif strip_rendered_bos: + raise ValueError("strip_rendered_bos requires --hf-tokenizer") else: logger.warning( "No --hf-tokenizer; using approximate ChatML (no thinking control)." @@ -130,15 +162,56 @@ def render( if self._hf is not None: dumped = [m.model_dump(exclude_none=True) for m in messages] _decode_tool_call_arguments(dumped) - return self._hf.apply_chat_template( + rendered = self._hf.apply_chat_template( dumped, tools=tools, add_generation_prompt=True, tokenize=False, **kwargs, ) + if self._strip_rendered_prefix and rendered.startswith( + self._strip_rendered_prefix + ): + rendered = rendered[len(self._strip_rendered_prefix) :] + elif self._strip_rendered_prefix: + raise ValueError( + "rendered prompt did not start with configured strip prefix " + f"{self._strip_rendered_prefix!r}" + ) + return self._maybe_append_tool_response_generation_prompt( + rendered, messages, tools, kwargs + ) return self._fallback(messages) + def _maybe_append_tool_response_generation_prompt( + self, + rendered: str, + messages: list[ChatMessage], + tools: Optional[list[dict[str, Any]]], + template_kwargs: dict[str, Any], + ) -> str: + if ( + not self._append_generation_prompt_after_tool_response + or not messages + or messages[-1].role != "tool" + ): + return rendered + has_prompt_marker = any( + rendered.endswith(marker) + for marker in self._tool_response_generation_prompt_markers + ) + if not has_prompt_marker and not rendered.endswith(self._assistant_header): + return rendered + + prompt = self._assistant_header + self.generation_preamble( + template_kwargs=template_kwargs, tools=tools + ) + if not prompt or rendered.endswith(prompt): + return rendered + if rendered.endswith(self._assistant_header): + return rendered + prompt[len(self._assistant_header) :] + return rendered + prompt + def generation_preamble( self, template_kwargs: Optional[dict[str, Any]] = None, @@ -177,12 +250,15 @@ def generation_preamble( tools=tools, template_kwargs=template_kwargs, ) - marker = "<|im_start|>assistant\n" + marker = self._assistant_header idx = rendered.rfind(marker) preamble = rendered[idx + len(marker) :] if idx != -1 else "" self._preamble_cache[key] = preamble return preamble + def assistant_header(self) -> str: + return self._assistant_header + def chat_template_str(self) -> Optional[str]: """Raw chat-template string (for tool-format auto-detection), if available.""" return ( diff --git a/extension/llm/server/python/openai_transcript.py b/extension/llm/server/python/openai_transcript.py index d6e614822ec..78d0569bb33 100644 --- a/extension/llm/server/python/openai_transcript.py +++ b/extension/llm/server/python/openai_transcript.py @@ -39,6 +39,9 @@ # last user; the open form is the think-mode generation preamble). Anything else # in that region is unrecognized -> the splice falls back to plain text. _THINK_SCAFFOLD_RE = re.compile(r"\A(?:\n\n\n\n|\n)?\Z") +_GEMMA_TOOL_CALL_START = "<|tool_call>" +_GEMMA_TOOL_CALL_END = "" +_GEMMA_QUOTE = '<|"|>' def _normalize_tool_args(args): @@ -57,9 +60,43 @@ def _normalize_tool_args(args): return args +def _find_gemma_tool_call_span(rendered: str, search_pos: int): + start = rendered.find(_GEMMA_TOOL_CALL_START, search_pos) + if start == -1: + return None + i = start + len(_GEMMA_TOOL_CALL_START) + in_string = False + depth = 0 + saw_object = False + while i < len(rendered): + if rendered.startswith(_GEMMA_QUOTE, i): + in_string = not in_string + i += len(_GEMMA_QUOTE) + continue + if not in_string: + if rendered.startswith(_GEMMA_TOOL_CALL_END, i): + if saw_object and depth == 0: + return start, i + len(_GEMMA_TOOL_CALL_END) + ch = rendered[i] + if ch in "{[": + depth += 1 + saw_object = saw_object or ch == "{" + elif ch in "}]": + if depth == 0: + return None + depth -= 1 + i += 1 + return None + + class OpenAITranscriptState: def __init__(self, template: ChatTemplate): self._template = template + self._assist_hdr = ( + template.assistant_header() + if hasattr(template, "assistant_header") + else _ASSIST_HDR + ) # session_id -> [{"fp": str, "ids": list[int] | None}, ...] (one per # assistant turn we produced, in order). Cleared on reset/close. self._turns: dict[str, list[dict]] = {} @@ -84,8 +121,7 @@ def _assistant_fingerprint(content, tool_calls) -> str: blob = json.dumps([content or "", norm], sort_keys=True, ensure_ascii=False) return hashlib.sha1(blob.encode("utf-8")).hexdigest() - @staticmethod - def _normalize_scaffold(text_chunk: str, preamble: str) -> Optional[str]: + def _normalize_scaffold(self, text_chunk: str, preamble: str) -> Optional[str]: """Force the scaffold region (between the last assistant header in `text_chunk` and its end) to equal `preamble`, so the worker re-tokenizes the exact resident scaffold. The region is empty (history stripped it -> @@ -97,10 +133,10 @@ def _normalize_scaffold(text_chunk: str, preamble: str) -> Optional[str]: # assistant header (the fix stays a true no-op for non-think models). if not preamble: return text_chunk - h = text_chunk.rfind(_ASSIST_HDR) + h = text_chunk.rfind(self._assist_hdr) if h == -1: return None - base = h + len(_ASSIST_HDR) + base = h + len(self._assist_hdr) region = text_chunk[base:] if region == preamble: return text_chunk @@ -108,9 +144,8 @@ def _normalize_scaffold(text_chunk: str, preamble: str) -> Optional[str]: return None return text_chunk[:base] + preamble - @staticmethod def _split_on_sentinels( - rendered: str, sub: dict[str, dict] + self, rendered: str, sub: dict[str, dict] ) -> Optional[list[dict]]: """Split `rendered` on the sentinels into alternating {"text"} chunks and {"ids"} runs (each sentinel -> sub[sentinel] = {"ids", "preamble"}). The @@ -121,7 +156,7 @@ def _split_on_sentinels( segments: list[dict] = [] pos = 0 for mobj in pattern.finditer(rendered): - norm = OpenAITranscriptState._normalize_scaffold( + norm = self._normalize_scaffold( rendered[pos : mobj.start()], sub[mobj.group()]["preamble"] ) if norm is None: @@ -134,6 +169,50 @@ def _split_on_sentinels( segments.append({"text": rendered[pos:]}) return segments + def _split_gemma_tool_call_spans( + self, + rendered: str, + messages: list[ChatMessage], + splice: dict[int, dict], + ) -> Optional[list[dict]]: + segments: list[dict] = [] + pos = 0 + + for i, msg in enumerate(messages): + tool_calls = msg.tool_calls if msg.role == "assistant" else None + if not tool_calls: + continue + + start = None + end = pos + for _ in tool_calls: + span = _find_gemma_tool_call_span(rendered, end) + if span is None: + return None + span_start, span_end = span + if start is None: + start = span_start + end = span_end + if start is None: + return None + + if i in splice: + norm = self._normalize_scaffold( + rendered[pos:start], splice[i]["preamble"] + ) + if norm is None: + return None + if norm: + segments.append({"text": norm}) + segments.append({"ids": splice[i]["ids"]}) + else: + segments.append({"text": rendered[pos:end]}) + pos = end + + if pos < len(rendered): + segments.append({"text": rendered[pos:]}) + return segments + def build_prompt_input( self, *, @@ -182,6 +261,20 @@ def build_prompt_input( del stored[diverged_at:] if not splice: return PromptInput(text=rendered_prompt) + tool_splice = { + pos: data + for pos, data in splice.items() + if messages[pos].tool_calls and not (messages[pos].content or "") + } + if tool_splice and _GEMMA_TOOL_CALL_START in rendered_prompt: + segments = self._split_gemma_tool_call_spans( + rendered_prompt, messages, tool_splice + ) + return ( + PromptInput(segments=segments) + if segments is not None + else PromptInput(text=rendered_prompt) + ) token = uuid.uuid4().hex sentinel_at = {pos: f"<>" for j, pos in enumerate(splice)} sub = {sentinel_at[pos]: splice[pos] for pos in splice} diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index 1b85f8fba3d..3b1716aaee9 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -12,7 +12,7 @@ import json import logging import math -from typing import AsyncIterator, Optional +from typing import AsyncIterator, Callable, Optional from .chat_template import ChatTemplate from .errors import ( @@ -63,11 +63,16 @@ def __init__( model_id: str, max_context: Optional[int] = None, tool_detector_cls: Optional[type[HermesDetector]] = None, + prompt_token_offset: int = 0, + content_filter: Optional[Callable[[str], str]] = None, + content_filter_specials: Optional[set[str]] = None, ): self._runtime = runtime self._template = template self._model_id = model_id self._max_context = max_context + self._prompt_token_offset = prompt_token_offset + self._content_filter = content_filter # Detector CLASS; a fresh instance is created per request so streaming # state is never shared across concurrent requests. self._tool_detector_cls = tool_detector_cls @@ -82,7 +87,10 @@ def __init__( # never reaches the client, AND it backs _strip_specials for final # cleanup of already-parsed visible content. self._stops = template.turn_stop_sequences() - self._content_specials = template.special_tokens() + handled = content_filter_specials or set() + self._content_specials = [ + t for t in template.special_tokens() if t not in handled + ] # OpenAI/chat-template token-ID warm-resume state. Adapter-side, # not runtime; kept in lockstep with the worker's session state by # clearing both on reset/close. @@ -108,6 +116,11 @@ def _strip_specials(self, text: str) -> str: cut = _earliest_stop(text, self._content_specials) return text[:cut] if cut is not None else text + def _visible_content(self, text: str) -> str: + if self._content_filter is not None: + text = self._content_filter(text) + return self._strip_specials(text) + @staticmethod def _to_openai_tool_call(item: ToolCallItem) -> ToolCall: return ToolCall( @@ -165,10 +178,10 @@ def _extract_tools(self, req: ChatCompletionRequest, text: str): text, self._tool_schemas(req) ) if parsed.calls: - content = self._strip_specials(parsed.normal_text) or None + content = self._visible_content(parsed.normal_text) or None return [self._to_openai_tool_call(c) for c in parsed.calls], content text = parsed.normal_text - return None, self._strip_specials(text) + return None, self._visible_content(text) async def _clean( self, stream: AsyncIterator[str], stops: list[str], on_stop=None @@ -347,8 +360,9 @@ def _count_prompt_tokens(self, prompt: PromptInput) -> Optional[int]: tokenized length of {text} chunks. None when no tokenizer is available to count text (the worker still enforces the real context limit).""" if prompt.text is not None: - return self._template.count_tokens(prompt.text) - total = 0 + count = self._template.count_tokens(prompt.text) + return None if count is None else count + self._prompt_token_offset + total = self._prompt_token_offset for seg in prompt.segments: if "ids" in seg: total += len(seg["ids"]) @@ -478,6 +492,36 @@ async def _complete( ), ) + async def _stream_plain_content( + self, + req: ChatCompletionRequest, + prompt: PromptInput, + options: GenerationOptions, + stats: GenStats, + stops: list[str], + stop_hit: list[bool], + ) -> AsyncIterator[str]: + if self._content_filter is not None: + raw, stop_hit[0] = await self._collect_until_stop( + self._runtime.generate_stream(req.session_id, prompt, options, stats), + stops, + ) + content = self._visible_content(self._apply_stop(raw, stops)) + if content: + yield content + return + + def on_stop(): + stop_hit[0] = True + self._runtime.stop() + + async for token in self._clean( + self._runtime.generate_stream(req.session_id, prompt, options, stats), + stops, + on_stop=on_stop, + ): + yield token + async def _stream( self, req: ChatCompletionRequest, @@ -527,19 +571,9 @@ def chunk(delta: DeltaMessage, finish=None) -> str: req, self._truncate_raw(raw, req) ) else: - # Plain chat: stream tokens live (best UX), cutting at special - # tokens or request stop sequences and halting early on a hit. - def on_stop(): - stop_hit[0] = True - self._runtime.stop() - streamed: list[str] = [] - async for token in self._clean( - self._runtime.generate_stream( - req.session_id, prompt, options, stats - ), - stops, - on_stop=on_stop, + async for token in self._stream_plain_content( + req, prompt, options, stats, stops, stop_hit ): streamed.append(token) yield chunk(DeltaMessage(content=token)) diff --git a/extension/llm/server/python/tests/test_gemma_tool_parser.py b/extension/llm/server/python/tests/test_gemma_tool_parser.py new file mode 100644 index 00000000000..6dfa766d999 --- /dev/null +++ b/extension/llm/server/python/tests/test_gemma_tool_parser.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for GemmaToolCallDetector.""" + +import json + +from executorch.extension.llm.server.python.tool_parsers import GemmaToolCallDetector + +_TOOLS = { + "get_weather": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "units": {"type": "string"}, + }, + }, + "add": { + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + }, + "set_alarm": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "labels": {"type": "array", "items": {"type": "string"}}, + "meta": { + "type": "object", + "properties": {"priority": {"type": "integer"}}, + }, + }, + }, +} + + +def _parse(text, tools=_TOOLS): + return GemmaToolCallDetector().detect_and_parse(text, tools) + + +def test_basic_call(): + text = ( + '<|tool_call>call:get_weather{city:<|"|>Paris<|"|>,' + 'units:<|"|>celsius<|"|>}' + ) + r = _parse(text) + assert len(r.calls) == 1 + assert r.calls[0].name == "get_weather" + assert json.loads(r.calls[0].arguments) == { + "city": "Paris", + "units": "celsius", + } + assert r.normal_text == "" + + +def test_multiple_calls_and_indices(): + text = ( + "<|tool_call>call:add{a:1,b:2}" + "<|tool_call>call:add{a:3,b:4}" + ) + r = _parse(text) + assert [c.tool_index for c in r.calls] == [0, 1] + assert [json.loads(c.arguments) for c in r.calls] == [ + {"a": 1, "b": 2}, + {"a": 3, "b": 4}, + ] + + +def test_nested_values(): + text = ( + '<|tool_call>call:set_alarm{enabled:true,labels:[<|"|>wake<|"|>,' + '<|"|>work<|"|>],meta:{<|"|>priority<|"|>:5}}' + ) + r = _parse(text) + assert json.loads(r.calls[0].arguments) == { + "enabled": True, + "labels": ["wake", "work"], + "meta": {"priority": 5}, + } + + +def test_leading_text_preserved(): + r = _parse("Checking.<|tool_call>call:add{a:1,b:2}") + assert r.normal_text == "Checking." + assert len(r.calls) == 1 + + +def test_no_tool_call_is_plain_text(): + text = "The weather is nice." + r = _parse(text) + assert r.calls == [] + assert r.normal_text == text + + +def test_undefined_tool_degrades_to_full_text(): + text = "<|tool_call>call:delete_everything{x:1}" + r = _parse(text) + assert r.calls == [] + assert r.normal_text == text + + +def test_truncated_call_degrades_without_leaking_markup(): + r = _parse('Sure <|tool_call>call:get_weather{city:<|"|>Par') + assert r.calls == [] + assert r.normal_text == "Sure" + + +def test_malformed_call_degrades_without_leaking_markup(): + r = _parse('Lead <|tool_call>call:get_weather{city:<|"|>Paris') + assert r.calls == [] + assert r.normal_text == "Lead" + + +def test_string_typed_bare_value_preserves_raw(): + tools = {"f": {"type": "object", "properties": {"code": {"type": "string"}}}} + r = _parse("<|tool_call>call:f{code:007}", tools) + assert json.loads(r.calls[0].arguments) == {"code": "007"} + + +def test_integer_typed_bare_value_is_int(): + tools = {"f": {"type": "object", "properties": {"n": {"type": "integer"}}}} + r = _parse("<|tool_call>call:f{n:007}", tools) + assert json.loads(r.calls[0].arguments) == {"n": 7} + + +def test_untyped_bare_vs_quoted_distinction(): + tools = {"f": {"type": "object", "properties": {}}} + r = _parse('<|tool_call>call:f{a:5,b:<|"|>5<|"|>}', tools) + assert json.loads(r.calls[0].arguments) == {"a": 5, "b": "5"} diff --git a/extension/llm/server/python/tests/test_streaming_stops.py b/extension/llm/server/python/tests/test_streaming_stops.py index ca926d92421..285b21c5c45 100644 --- a/extension/llm/server/python/tests/test_streaming_stops.py +++ b/extension/llm/server/python/tests/test_streaming_stops.py @@ -25,6 +25,10 @@ from fastapi.testclient import TestClient FIM = "<|fim_pad|>" # a broad content special that is NOT a turn terminator +CHANNEL_OPEN = "<|channel>" +CHANNEL_CLOSE = "" +THINK = "<|think|>" +CHANNEL_SPECIALS = {CHANNEL_OPEN, CHANNEL_CLOSE, THINK} WEATHER_TOOLS = [ {"type": "function", "function": {"name": "get_weather", "parameters": {}}} ] @@ -35,7 +39,7 @@ class _SpecialTok: eos=<|im_end|> (a terminator) plus <|fim_pad|> (broad-only).""" eos_token = "<|im_end|>" - all_special_tokens = ["<|im_end|>", FIM] + all_special_tokens = ["<|im_end|>", FIM, CHANNEL_OPEN, CHANNEL_CLOSE, THINK] def encode(self, text, add_special_tokens=False): return [0] * 5 @@ -100,12 +104,23 @@ def generate(self, prompt, config, token_callback=None, stats_callback=None): stats_callback(stats) -def _serving(tokens, honor_stops=False, gen_ids=None): +def _serving( + tokens, + honor_stops=False, + gen_ids=None, + content_filter=None, + content_filter_specials=None, +): runner = _Runner(tokens, gen_ids=gen_ids, honor_stops=honor_stops) template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) template._hf = _SpecialTok() serving = ServingChat( - SessionRuntime(runner), template, "test-model", tool_detector_cls=HermesDetector + SessionRuntime(runner), + template, + "test-model", + tool_detector_cls=HermesDetector, + content_filter=content_filter, + content_filter_specials=content_filter_specials, ) return serving, runner @@ -165,6 +180,29 @@ def test_plain_chat_streaming_does_not_leak_broad_special(): assert finish == "stop" +def test_plain_chat_streaming_applies_content_filter(): + def strip_channel(text): + return text.replace(f"{CHANNEL_OPEN}secret{CHANNEL_CLOSE}", "").strip() + + serving, runner = _serving( + ["Visible ", CHANNEL_OPEN, "secret", CHANNEL_CLOSE], + content_filter=strip_channel, + content_filter_specials=CHANNEL_SPECIALS, + ) + r = _client(serving).post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + content, _ = _sse_content(r.text) + assert content == "Visible" + assert CHANNEL_OPEN not in content and CHANNEL_CLOSE not in content + assert CHANNEL_OPEN not in (runner.captured_config.stop or []) + + def test_plain_chat_nonstreaming_matches_streaming_visible(): serving, _ = _serving(["Hi", FIM, "leak"]) r = _client(serving).post( diff --git a/extension/llm/server/python/tests/test_template.py b/extension/llm/server/python/tests/test_template.py index c8a53f1dec8..7e1f75e5dd2 100644 --- a/extension/llm/server/python/tests/test_template.py +++ b/extension/llm/server/python/tests/test_template.py @@ -6,6 +6,8 @@ """Contract tests for chat-template kwargs (e.g. enable_thinking) passthrough.""" +import pytest + from executorch.extension.llm.server.python.chat_template import ChatTemplate from executorch.extension.llm.server.python.protocol import ( ChatMessage, @@ -19,6 +21,8 @@ def __init__(self): self.seen_kwargs = None self.seen_messages = None self.encode_add_special = None + self.chat_template = "fake" + self.bos_token = "" def apply_chat_template( self, messages, tools, add_generation_prompt, tokenize, **kwargs @@ -34,6 +38,26 @@ def encode(self, text, add_special_tokens=True): return list(range(len(text))) # 1 id per char, deterministic +class _GemmaToolResponseHF(_FakeHF): + def __init__( + self, + tool_prompt="<|turn>model\n<|tool_response>response:x", + preamble="<|channel>thought\n", + ): + super().__init__() + self.tool_prompt = tool_prompt + self.preamble = preamble + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + self.seen_kwargs = kwargs + self.seen_messages = messages + if messages and messages[-1]["role"] == "tool": + return self.tool_prompt + return "<|turn>user\n\n<|turn>model\n" + self.preamble + + def _template_with_fake(defaults=None): t = ChatTemplate( hf_tokenizer_path=None, allow_fallback=True, default_template_kwargs=defaults @@ -43,6 +67,22 @@ def _template_with_fake(defaults=None): return t, fake +def _template_with_gemma_tool_response_fake( + tool_prompt="<|turn>model\n<|tool_response>response:x", + preamble="<|channel>thought\n", + append=False, +): + t = ChatTemplate( + hf_tokenizer_path=None, + allow_fallback=True, + assistant_header="<|turn>model\n", + append_generation_prompt_after_tool_response=append, + ) + fake = _GemmaToolResponseHF(tool_prompt=tool_prompt, preamble=preamble) + t._hf = fake + return t, fake + + def test_count_tokens_excludes_special_tokens(): # The rendered prompt already carries control tokens, so count_tokens must # encode with add_special_tokens=False (matching the session/prefix-cache @@ -75,6 +115,42 @@ def test_no_kwargs_when_none(): assert fake.seen_kwargs == {} +def test_strip_rendered_prefix_before_worker_tokenization(): + t = ChatTemplate( + hf_tokenizer_path=None, allow_fallback=True, strip_rendered_prefix="PRO" + ) + fake = _FakeHF() + t._hf = fake + out = t.render([ChatMessage(role="user", content="hi")]) + assert out == "MPT" + assert fake.seen_messages[0]["content"] == "hi" + + +def test_strip_rendered_prefix_fails_if_missing(): + t = ChatTemplate( + hf_tokenizer_path=None, allow_fallback=True, strip_rendered_prefix="MISS" + ) + t._hf = _FakeHF() + with pytest.raises(ValueError, match="strip prefix"): + t.render([ChatMessage(role="user", content="hi")]) + + +def test_strip_rendered_bos_uses_tokenizer_bos(monkeypatch): + transformers = pytest.importorskip("transformers") + fake = _FakeHF() + + def apply_chat_template(messages, tools, add_generation_prompt, tokenize, **kwargs): + fake.seen_messages = messages + return "PROMPT" + + fake.apply_chat_template = apply_chat_template + monkeypatch.setattr( + transformers.AutoTokenizer, "from_pretrained", lambda _path: fake + ) + t = ChatTemplate("unused", strip_rendered_bos=True) + assert t.render([ChatMessage(role="user", content="hi")]) == "PROMPT" + + def test_fallback_ignores_kwargs_without_hf(): # No HF tokenizer → ChatML fallback, must not raise on kwargs. t = ChatTemplate( @@ -86,6 +162,61 @@ def test_fallback_ignores_kwargs_without_hf(): assert "<|im_start|>user" in out and out.endswith("<|im_start|>assistant\n") +def test_tool_response_generation_prompt_disabled_by_default(): + t, _ = _template_with_gemma_tool_response_fake(append=False) + out = t.render([ChatMessage(role="tool", tool_call_id="c1", content="ok")]) + assert out.endswith("") + assert "<|channel>thought" not in out + + +def test_tool_response_generation_prompt_appended_for_gemma(): + t, _ = _template_with_gemma_tool_response_fake(append=True) + out = t.render([ChatMessage(role="tool", tool_call_id="c1", content="ok")]) + assert out.endswith("<|turn>model\n<|channel>thought\n") + + +def test_tool_response_generation_prompt_handles_turn_end_case(): + t, _ = _template_with_gemma_tool_response_fake( + tool_prompt="<|turn>model\nLet me check.\n", + preamble="", + append=True, + ) + out = t.render([ChatMessage(role="tool", tool_call_id="c1", content="ok")]) + assert out.endswith("\n<|turn>model\n") + + +def test_tool_response_generation_prompt_not_double_inserted(): + t, _ = _template_with_gemma_tool_response_fake( + tool_prompt=( + "<|turn>model\n<|tool_response>response:x" + "<|turn>model\n<|channel>thought\n" + ), + append=True, + ) + out = t.render([ChatMessage(role="tool", tool_call_id="c1", content="ok")]) + assert out.count("<|turn>model\n") == 2 + assert out.count("<|channel>thought\n") == 1 + + +def test_tool_response_generation_prompt_completes_header_only_prompt(): + t, _ = _template_with_gemma_tool_response_fake( + tool_prompt="<|turn>model\n<|tool_response>response:x<|turn>model\n", + append=True, + ) + out = t.render([ChatMessage(role="tool", tool_call_id="c1", content="ok")]) + assert out.endswith("<|turn>model\n<|channel>thought\n") + assert out.count("<|turn>model\n") == 2 + + +def test_tool_response_generation_prompt_requires_final_tool_message(): + t, _ = _template_with_gemma_tool_response_fake( + tool_prompt="<|turn>model\n<|tool_response>response:x", + append=True, + ) + out = t.render([ChatMessage(role="user", content="ok")]) + assert out == "<|turn>user\n\n<|turn>model\n<|channel>thought\n" + + # (5) Chat-template behaviors: multi-turn ordering, system message, roles. def test_multi_turn_order_preserved(): t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) diff --git a/extension/llm/server/python/tests/test_warm_resume_scaffold.py b/extension/llm/server/python/tests/test_warm_resume_scaffold.py index 5f01c553379..501791dd6e4 100644 --- a/extension/llm/server/python/tests/test_warm_resume_scaffold.py +++ b/extension/llm/server/python/tests/test_warm_resume_scaffold.py @@ -32,6 +32,8 @@ HDR = "<|im_start|>assistant\n" NOTHINK = "\n\n\n\n" # no-think generation preamble / preserved block THINK = "\n" # think-mode generation preamble +GEMMA_HDR = "<|turn>model\n" +GEMMA_PREAMBLE = "<|channel>thought\n" def _msgs(*pairs): @@ -100,6 +102,20 @@ def render(self, messages, tools=None, template_kwargs=None): return "".join(out) +class _FakeGemma: + def assistant_header(self): + return GEMMA_HDR + + def render(self, messages, tools=None, template_kwargs=None): + out = [""] + for m in messages: + role = "model" if m.role == "assistant" else m.role + content = m.content if isinstance(m.content, str) else "" + out.append(f"<|turn>{role}\n{content}\n") + out.append(GEMMA_HDR + GEMMA_PREAMBLE) + return "".join(out) + + def _ids_index(segs, ids): for i, s in enumerate(segs): if s.get("ids") == ids: @@ -245,6 +261,29 @@ def test_non_qwen_header_no_scaffold_still_splices(): assert any(s.get("ids") == [9, 9] for s in pi.segments) # ids actually spliced +def test_custom_assistant_header_inserts_scaffold(): + st = OpenAITranscriptState(_FakeGemma()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[4, 5], + prior_turns=0, + preamble=GEMMA_PREAMBLE, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs), + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None + before = _text_before_ids(pi.segments, [4, 5]) + assert before.rsplit(GEMMA_HDR, 1)[-1] == GEMMA_PREAMBLE + + def test_stop_trimmed_turn_falls_back_to_text(): st = OpenAITranscriptState(_FakeQwen()) st.record_assistant_turn( @@ -414,6 +453,49 @@ def test_tool_turn_splices_despite_reserialized_args(): assert any(s.get("ids") == [1, 2, 3] for s in pi.segments) +def test_gemma_tool_span_ignores_close_marker_inside_string(): + st = OpenAITranscriptState(_FakeGemma()) + call = ToolCall( + id="call_1", + function=FunctionCall( + name="bash", arguments='{"command":"printf ok"}' + ), + ) + st.record_assistant_turn( + session_id="s", + content="", + tool_calls=[call], + generated_token_ids=[101, 102, 103], + prior_turns=0, + preamble="", + ) + msgs = [ + ChatMessage(role="user", content="u1"), + ChatMessage(role="assistant", content="", tool_calls=[call]), + ChatMessage(role="tool", tool_call_id="call_1", content="done"), + ] + rendered = ( + "<|turn>user\nu1\n" + "<|turn>model\n" + '<|tool_call>call:bash{command:<|"|>printf ok<|"|>}' + "<|tool_response>done" + ) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=rendered, + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None + assert any(s.get("ids") == [101, 102, 103] for s in pi.segments) + suffix = "".join( + s.get("text", "") + for s in pi.segments[_ids_index(pi.segments, [101, 102, 103]) + 1 :] + ) + assert suffix == "<|tool_response>done" + + # --- 5b. Token-level fidelity against the real tokenizer (gated/skipped) ----- _MODEL = os.environ.get( @@ -444,6 +526,128 @@ def _assemble(segs, enc): return out +_GEMMA_MODEL = os.environ.get( + "GEMMA_HF_DIR", "/home/mnachin/local/scripts/models/gemma-4-31B-it-HQQ-INT4" +) +_HAVE_GEMMA = os.path.isdir(_GEMMA_MODEL) +_skip_gemma = pytest.mark.skipif( + not _HAVE_GEMMA, reason=f"real Gemma tokenizer dir not present: {_GEMMA_MODEL}" +) + + +def _real_gemma_template_and_enc(strip_bos=True): + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tmpl = ChatTemplate( + hf_tokenizer_path=_GEMMA_MODEL, + assistant_header=GEMMA_HDR, + strip_rendered_bos=strip_bos, + ) + tok = AutoTokenizer.from_pretrained(_GEMMA_MODEL) + return tmpl, tok, (lambda s: tok.encode(s, add_special_tokens=False)) + + +@_skip_gemma +def test_gemma_real_template_bos_strip_matches_full_render_ids(): + stripped, tok, enc = _real_gemma_template_and_enc(strip_bos=True) + full, _, _ = _real_gemma_template_and_enc(strip_bos=False) + msgs = _msgs(("user", "What is the capital of France?")) + full_render = full.render(msgs) + stripped_render = stripped.render(msgs) + + assert full_render.startswith(tok.bos_token) + assert not stripped_render.startswith(tok.bos_token) + assert [tok.bos_token_id] + enc(stripped_render) == enc(full_render) + + +@_skip_gemma +def test_gemma_real_template_warm_resume_prefix_with_bos_prefix(): + tmpl, tok, enc = _real_gemma_template_and_enc(strip_bos=True) + st = OpenAITranscriptState(tmpl) + content = "The capital is Paris." + gen_ids = enc(content) + first_prompt = tmpl.render(_msgs(("user", "u1"))) + resident = [tok.bos_token_id] + enc(first_prompt) + gen_ids + st.record_assistant_turn( + session_id="s", + content=content, + tool_calls=None, + generated_token_ids=gen_ids, + prior_turns=0, + preamble=tmpl.generation_preamble(), + ) + msgs = _msgs(("user", "u1"), ("assistant", content), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=tmpl.render(msgs), + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None + assembled = [tok.bos_token_id] + _assemble(pi.segments, enc) + assert assembled[: len(resident)] == resident + + +@_skip_gemma +def test_gemma_real_template_post_tool_splices_call_and_keeps_tool_response(): + tmpl, tok, enc = _real_gemma_template_and_enc(strip_bos=True) + st = OpenAITranscriptState(tmpl) + tools = [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run bash", + "parameters": { + "type": "object", + "properties": {"command": {"type": "string"}}, + "required": ["command"], + }, + }, + } + ] + call = ToolCall( + id="call_1", + function=FunctionCall(name="bash", arguments='{"command":"echo hello42"}'), + ) + raw_call = '<|tool_call>call:bash{command:<|"|>echo hello42<|"|>}' + gen_ids = enc(raw_call) + first_prompt = tmpl.render(_msgs(("user", "Run echo hello42")), tools=tools) + resident = [tok.bos_token_id] + enc(first_prompt) + gen_ids + st.record_assistant_turn( + session_id="s", + content="", + tool_calls=[call], + generated_token_ids=gen_ids, + prior_turns=0, + preamble=tmpl.generation_preamble(tools=tools), + ) + msgs = [ + ChatMessage(role="user", content="Run echo hello42"), + ChatMessage(role="assistant", content="", tool_calls=[call]), + ChatMessage(role="tool", tool_call_id="call_1", content="hello42"), + ] + rendered = tmpl.render(msgs, tools=tools) + assert rendered.endswith("") + + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=rendered, + tools=tools, + template_kwargs=None, + ) + assert pi.segments is not None + assembled = [tok.bos_token_id] + _assemble(pi.segments, enc) + assert assembled[: len(resident)] == resident + + suffix_text = "".join(seg.get("text", "") for seg in pi.segments) + assert "<|tool_response>" in suffix_text + assert "hello42" in suffix_text + + @_skip @pytest.mark.parametrize("thinking", [False, True]) def test_token_level_exact_prefix_ordinary(thinking): diff --git a/extension/llm/server/python/tool_parsers/__init__.py b/extension/llm/server/python/tool_parsers/__init__.py index c890dec3888..7d72030b40f 100644 --- a/extension/llm/server/python/tool_parsers/__init__.py +++ b/extension/llm/server/python/tool_parsers/__init__.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Tool-call parsing. Two formats, pick the one matching your model: +"""Tool-call parsing. Pick the parser matching your model: - HermesDetector: JSON inside (Qwen2.5/3, Hermes). +- GemmaToolCallDetector: Gemma <|tool_call>call:NAME{...}. - QwenFunctionCallDetector: Qwen XML (Qwen3.5-MoE / Qwen3-Coder). @@ -14,12 +15,14 @@ OpenAI tool_calls; parse failures degrade to visible text. """ +from .gemma import GemmaToolCallDetector from .hermes import HermesDetector from .qwen import QwenFunctionCallDetector from .types import ParseResult, ToolCallItem __all__ = [ "HermesDetector", + "GemmaToolCallDetector", "QwenFunctionCallDetector", "ParseResult", "ToolCallItem", diff --git a/extension/llm/server/python/tool_parsers/gemma.py b/extension/llm/server/python/tool_parsers/gemma.py new file mode 100644 index 00000000000..1f36b95f852 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/gemma.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Gemma tool calls: <|tool_call>call:NAME{key:value}.""" + +import json +import logging +import math +import re +from typing import Any, Optional + +from .types import ParseResult, ToolCallItem + +logger = logging.getLogger(__name__) + +_BOT = "<|tool_call>" +_EOT = "" +_QUOTE = '<|"|>' +_INT_RE = re.compile(r"[+-]?[0-9]+$") +_NUM_RE = re.compile(r"[+-]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][+-]?[0-9]+)?$") + + +class _UndefinedToolCall(Exception): + pass + + +class _Bare(str): + __slots__ = () + + +class _Parser: + def __init__(self, text: str): + self.text = text + self.i = 0 + + def _skip_ws(self) -> None: + while self.i < len(self.text) and self.text[self.i].isspace(): + self.i += 1 + + def _expect(self, token: str) -> None: + self._skip_ws() + if not self.text.startswith(token, self.i): + raise ValueError(f"expected {token!r}") + self.i += len(token) + + def parse_call(self) -> tuple[str, dict[str, Any]]: + self._expect("call:") + start = self.i + while self.i < len(self.text) and self.text[self.i] not in "{ \t\r\n": + self.i += 1 + name = self.text[start : self.i].strip() + if not name: + raise ValueError("missing tool name") + args = self._parse_object() + self._skip_ws() + if self.i != len(self.text): + raise ValueError("trailing garbage") + return name, args + + def _parse_object(self) -> dict[str, Any]: + self._expect("{") + out: dict[str, Any] = {} + self._skip_ws() + if self.i < len(self.text) and self.text[self.i] == "}": + self.i += 1 + return out + while True: + key = self._parse_key() + self._expect(":") + out[key] = self._parse_value() + self._skip_ws() + if self.i >= len(self.text): + raise ValueError("unclosed object") + if self.text[self.i] == "}": + self.i += 1 + return out + self._expect(",") + + def _parse_key(self) -> str: + self._skip_ws() + if self.text.startswith(_QUOTE, self.i): + return self._parse_string() + start = self.i + while self.i < len(self.text) and self.text[self.i] not in ": \t\r\n": + self.i += 1 + key = self.text[start : self.i].strip() + if not key: + raise ValueError("missing key") + return key + + def _parse_value(self) -> Any: + self._skip_ws() + if self.text.startswith(_QUOTE, self.i): + return self._parse_string() + if self.i < len(self.text) and self.text[self.i] == "{": + return self._parse_object() + if self.i < len(self.text) and self.text[self.i] == "[": + return self._parse_array() + return self._parse_bare() + + def _parse_string(self) -> str: + self._expect(_QUOTE) + end = self.text.find(_QUOTE, self.i) + if end == -1: + raise ValueError("unclosed string") + value = self.text[self.i : end] + self.i = end + len(_QUOTE) + return value + + def _parse_array(self) -> list[Any]: + self._expect("[") + out = [] + self._skip_ws() + if self.i < len(self.text) and self.text[self.i] == "]": + self.i += 1 + return out + while True: + out.append(self._parse_value()) + self._skip_ws() + if self.i >= len(self.text): + raise ValueError("unclosed array") + if self.text[self.i] == "]": + self.i += 1 + return out + self._expect(",") + + def _parse_bare(self) -> _Bare: + start = self.i + while self.i < len(self.text) and self.text[self.i] not in ",]}": + self.i += 1 + return _Bare(self.text[start : self.i].strip()) + + +def _guess_scalar(raw: str) -> Any: + low = raw.lower() + if low == "true": + return True + if low == "false": + return False + if low == "null": + return None + if _INT_RE.match(raw): + return int(raw) + if _NUM_RE.match(raw): + value = float(raw) + if math.isfinite(value): + return value + return str(raw) + + +def _schema_type(schema: dict[str, Any]) -> Optional[str]: + typ = schema.get("type") + if isinstance(typ, list): + typ = next((t for t in typ if t != "null"), typ[0] if typ else None) + return typ + + +def _coerce_string(value: Any) -> Any: + return str(value) + + +def _coerce_bool(value: Any) -> Any: + low = str(value).strip().lower() + if low == "true": + return True + if low == "false": + return False + return str(value) + + +def _coerce_int(value: Any) -> Any: + s = str(value).strip() + return int(s) if _INT_RE.match(s) else str(value) + + +def _coerce_number(value: Any) -> Any: + s = str(value).strip() + if _NUM_RE.match(s): + parsed = float(s) + if math.isfinite(parsed): + return parsed + return str(value) + + +_SCALAR_COERCERS = { + "string": _coerce_string, + "boolean": _coerce_bool, + "integer": _coerce_int, + "number": _coerce_number, +} + + +def _coerce(value: Any, schema: Optional[dict[str, Any]]) -> Any: + if isinstance(value, dict): + props = (schema or {}).get("properties") or {} + return {k: _coerce(v, props.get(k)) for k, v in value.items()} + if isinstance(value, list): + items = (schema or {}).get("items") + item_schema = items if isinstance(items, dict) else None + return [_coerce(v, item_schema) for v in value] + typ = _schema_type(schema) if schema else None + coercer = _SCALAR_COERCERS.get(typ) + if coercer is not None: + return coercer(value) + return _guess_scalar(value) if isinstance(value, _Bare) else value + + +class GemmaToolCallDetector: + bot_token = _BOT + + def __init__(self): + self._next_index = 0 + + def detect_and_parse(self, text: str, tools: dict[str, dict]) -> ParseResult: + if _BOT not in text: + return ParseResult(normal_text=text) + normal = text[: text.find(_BOT)].strip() + try: + calls = self._parse_calls(text, tools) + except _UndefinedToolCall as e: + logger.debug("undefined tool %s; returning raw text (no partial calls)", e) + return ParseResult(normal_text=text) + except Exception as e: # noqa: BLE001 - never crash + logger.debug("malformed Gemma tool call (%s); degrading to leading text", e) + return ParseResult(normal_text=normal) + return ( + ParseResult(normal_text=normal, calls=calls) if calls else ParseResult(text) + ) + + def _parse_calls(self, text: str, tools: dict[str, dict]) -> list[ToolCallItem]: + calls = [] + pos = 0 + while True: + start = text.find(_BOT, pos) + if start == -1: + break + body_start = start + len(_BOT) + end = text.find(_EOT, body_start) + if end == -1: + raise ValueError("unclosed Gemma tool call") + name, args = _Parser(text[body_start:end]).parse_call() + if name not in tools: + raise _UndefinedToolCall(repr(name)) + schema = tools.get(name) or {} + item = ToolCallItem( + tool_index=self._next_index, + name=name, + arguments=json.dumps( + _coerce(args, schema), ensure_ascii=False, allow_nan=False + ), + ) + self._next_index += 1 + calls.append(item) + pos = end + len(_EOT) + return calls