diff --git a/docs/guides/datasets.md b/docs/guides/datasets.md index 5bb511364..0a5f25511 100644 --- a/docs/guides/datasets.md +++ b/docs/guides/datasets.md @@ -59,6 +59,21 @@ guidellm benchmark \ --data '{"prompt_tokens": 256, "output_tokens": 128}' ``` +For embeddings endpoints, you need to specify `output_tokens=1` (a current limitation of the synthetic data generator): + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile concurrent \ + --rate 32 \ + --max-requests 500 \ + --data "prompt_tokens=256,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" +``` + +For more details on embeddings benchmarking, see the [Embeddings Guide](./embeddings.md). + #### Configuration Options - `prompt_tokens`: Average number of tokens in prompts. If nothing else is specified, all requests will have this number of tokens. diff --git a/docs/guides/embeddings.md b/docs/guides/embeddings.md new file mode 100644 index 000000000..68d958b52 --- /dev/null +++ b/docs/guides/embeddings.md @@ -0,0 +1,284 @@ +# Embeddings Benchmarking + +GuideLLM supports benchmarking embedding models through the `/v1/embeddings` endpoint. This guide covers how to set up and run benchmarks for text embedding models, which are commonly used for semantic search, clustering, and other ML tasks. + +## Overview + +Embedding models convert text into dense vector representations that capture semantic meaning. Benchmarking these models helps you: + +- Measure throughput and latency for embedding generation +- Test performance under different load conditions +- Compare different embedding model deployments +- Optimize your embedding service configuration + +## Supported Backends + +### vLLM + +vLLM supports embedding models starting from version 0.4.0. To serve an embedding model with vLLM: + +```bash +vllm serve "BAAI/bge-small-en-v1.5" +``` + +Popular embedding models supported by vLLM: + +- **BAAI/bge-small-en-v1.5**: Lightweight English embedding model (384 dimensions) +- **BAAI/bge-base-en-v1.5**: Base English embedding model (768 dimensions) +- **BAAI/bge-large-en-v1.5**: Large English embedding model (1024 dimensions) +- **sentence-transformers/all-MiniLM-L6-v2**: Compact multilingual model (384 dimensions) +- **intfloat/e5-large-v2**: High-performance English model (1024 dimensions) + +For the latest list of supported models, see the [vLLM documentation](https://docs.vllm.ai/en/latest/models/supported_models.html). + +### OpenAI API + +GuideLLM can also benchmark OpenAI's embedding endpoints: + +```bash +guidellm benchmark \ + --target "https://api.openai.com" \ + --request-type embeddings \ + --model "text-embedding-3-small" \ + --rate 5 \ + --max-requests 50 \ + --data "prompt_tokens=256,output_tokens=1" \ + --processor "gpt2" +``` + +Note: You'll need to set your OpenAI API key as an environment variable or in the request headers. + +## Basic Benchmarking + +### Simple Concurrent Benchmark (Recommended) + +For embeddings, concurrent testing is the most relevant approach. 
To run a basic concurrent benchmark with synthetic data: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile concurrent \ + --rate 32 \ + --max-requests 100 \ + --data "prompt_tokens=256,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" +``` + +This command: + +- Tests with 32 concurrent requests (parallel processing) +- Stops after 100 total requests +- Uses synthetic text with ~256 tokens per request +- Uses the bge-small tokenizer for token counting +- **Note**: `output_tokens=1` is required when using synthetic data, even though embeddings don't generate output. This is a current limitation of the synthetic data generator. + +## Benchmark Profiles for Embeddings + +Different benchmark profiles serve different purposes when testing embedding models: + +- **Concurrent** (Recommended): Tests parallel request handling - the most common production pattern for embeddings +- **Throughput**: Finds maximum sustainable request rate - useful for capacity planning +- **Synchronous**: Sequential baseline testing - useful for measuring per-request latency without concurrency effects +- **Constant**: Fixed-rate testing - less relevant for embeddings since they have predictable processing times +- **Sweep**: Not recommended for embeddings (designed for optimizing generative model parameters) + +For most embedding benchmarks, use **concurrent** or **throughput** profiles. + +## Advanced Usage + +### Variable Input Lengths + +Test performance across different input lengths: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --rate 10 \ + --max-requests 200 \ + --data "prompt_tokens=256,prompt_tokens_min=128,prompt_tokens_max=500,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" +``` + +This creates requests with uniformly distributed lengths between 128 and 500 tokens. + +### Using Real Data + +Benchmark with actual text data from a file or Hugging Face dataset: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --rate 10 \ + --max-requests 100 \ + --data "path/to/your/data.jsonl" \ + --data-args '{"prompt_column": "text"}' \ + --processor "BAAI/bge-small-en-v1.5" +``` + +Or from Hugging Face: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --rate 10 \ + --max-requests 100 \ + --data "sentence-transformers/stsb" \ + --data-args '{"prompt_column": "sentence1", "split": "test"}' \ + --processor "BAAI/bge-small-en-v1.5" +``` + +### Load Testing Scenarios + +#### Testing Concurrent Request Handling (Recommended) + +The concurrent profile is the most relevant for embeddings, as it simulates how production systems typically use embedding models (parallel batch processing): + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile concurrent \ + --rate 32 \ + --max-requests 500 \ + --data "prompt_tokens=256,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" +``` + +The `--rate` parameter specifies the number of concurrent streams (e.g., 32 parallel requests). 
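+
+To see how latency scales with concurrency, you can run the same benchmark at several `--rate` values and compare the resulting reports (a minimal sketch; the target, model, and output file names below are placeholders for your deployment):
+
+```bash
+# Minimal sketch: compare a few concurrency levels, saving one report per run.
+# The target URL, model, and output file names are placeholders.
+for rate in 8 16 32 64; do
+  guidellm benchmark \
+    --target "http://localhost:8000" \
+    --request-type embeddings \
+    --profile concurrent \
+    --rate "$rate" \
+    --max-requests 500 \
+    --data "prompt_tokens=256,output_tokens=1" \
+    --processor "BAAI/bge-small-en-v1.5" \
+    --output-path "embeddings_concurrent_${rate}.json"
+done
+```
+
+Comparing request latency percentiles across these reports shows where additional concurrency stops improving throughput.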
+ +#### Finding Maximum Throughput + +Use the throughput profile to find the maximum sustainable request rate for capacity planning: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile throughput \ + --max-requests 500 \ + --data "prompt_tokens=256,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" +``` + +## Metrics and Analysis + +When benchmarking embeddings, GuideLLM tracks: + +- **Request Latency**: Time from request start to completion +- **Time to First Token (TTFT)**: For embeddings, this is effectively the processing time +- **Throughput**: Requests processed per second +- **Token Throughput**: Input tokens processed per second +- **Success Rate**: Percentage of successful requests +- **Error Rate**: Percentage of failed requests + +### Example Output + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --rate 10 \ + --max-requests 100 \ + --data "prompt_tokens=256,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" \ + --output-path embeddings_report.json +``` + +The JSON report will include: + +- Per-request timing and token counts +- Aggregate statistics (mean, median, percentiles) +- Request success/failure breakdown +- Overall benchmark metadata + +## Best Practices + +1. **Match the Processor**: Use the same tokenizer as your embedding model for accurate token counting + +2. **Account for Model Context Length**: + + - **Check your model's limit**: Query the models endpoint to find `max_model_len`: + + ```bash + curl -s http://localhost:8000/v1/models | python3 -m json.tool | grep "max_model_len" + ``` + + This will show something like: `"max_model_len": 512` + + - **Synthetic data overhead**: The generator adds 2-5 tokens per request to ensure uniqueness + + - **Leave headroom**: Subtract ~10 tokens from `max_model_len` for safety + + - **Examples**: + + - 512-token model → use `prompt_tokens=500` or `prompt_tokens_max=500` + - 8192-token model → use up to `prompt_tokens=8180` + + - **Error symptom**: "maximum context length exceeded" errors mean your tokens + prefix > model limit + +3. **Start with Low Rates**: Begin with conservative request rates and gradually increase + +4. **Use Realistic Data**: Test with data similar to your production workload + +5. **Test Multiple Scenarios**: Vary input lengths, batch sizes, and request patterns + +6. **Monitor System Resources**: Watch CPU, memory, and GPU utilization during benchmarks + +7. **Run Multiple Iterations**: Execute benchmarks several times to account for variance + +## Examples + +### Short Context Embeddings (128-512 tokens) + +Typical for BERT-style models with concurrent processing: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile concurrent \ + --rate 32 \ + --max-requests 500 \ + --data "prompt_tokens=256,prompt_tokens_min=128,prompt_tokens_max=500,output_tokens=1" \ + --processor "BAAI/bge-small-en-v1.5" +``` + +This tests with 32 concurrent streams, which matches common production patterns. Using `prompt_tokens_max=500` instead of 512 leaves headroom for the synthetic data generator's unique request prefix. 
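+
+Before scaling up to the longer-context runs below, it can help to sanity-check the endpoint with a single request and confirm it returns vectors and a prompt token count (a minimal sketch; the target URL and model name are placeholders for your deployment):
+
+```bash
+# Minimal sanity check: send one embeddings request and show the reported token usage.
+# Target URL and model name are placeholders.
+curl -s http://localhost:8000/v1/embeddings \
+  -H "Content-Type: application/json" \
+  -d '{"model": "BAAI/bge-small-en-v1.5", "input": "hello world"}' \
+  | python3 -m json.tool | grep "prompt_tokens"
+```
+
+A non-empty `prompt_tokens` value confirms the server is counting input tokens for embeddings requests.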
+ +### Long Context Embeddings (8k-32k tokens) + +For newer long-context embedding models (lower concurrency due to larger context): + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile concurrent \ + --rate 8 \ + --max-requests 100 \ + --data "prompt_tokens=16384,prompt_tokens_min=8192,prompt_tokens_max=32768,output_tokens=1" \ + --processor "jinaai/jina-embeddings-v3" +``` + +### Production Simulation + +Simulate realistic production workload with variable input lengths: + +```bash +guidellm benchmark \ + --target "http://localhost:8000" \ + --request-type embeddings \ + --profile concurrent \ + --rate 16 \ + --max-requests 1000 \ + --data "prompt_tokens=256,prompt_tokens_stdev=100,output_tokens=1,samples=1000" \ + --data-sampler random \ + --processor "BAAI/bge-base-en-v1.5" \ + --output-path production_simulation.json +``` + +This runs a comprehensive benchmark with 1000 requests and variable-length inputs (using standard deviation), closely mimicking real-world usage patterns. diff --git a/src/guidellm/backends/__init__.py b/src/guidellm/backends/__init__.py index 6577fa728..94f9c7cc5 100644 --- a/src/guidellm/backends/__init__.py +++ b/src/guidellm/backends/__init__.py @@ -16,6 +16,7 @@ from .response_handlers import ( AudioResponseHandler, ChatCompletionsResponseHandler, + EmbeddingsResponseHandler, GenerationResponseHandler, GenerationResponseHandlerFactory, TextCompletionsResponseHandler, @@ -26,6 +27,7 @@ "Backend", "BackendType", "ChatCompletionsResponseHandler", + "EmbeddingsResponseHandler", "GenerationResponseHandler", "GenerationResponseHandlerFactory", "OpenAIHTTPBackend", diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index 224a21234..fb9ef310c 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -87,6 +87,7 @@ def __init__( "chat_completions": "v1/chat/completions", "audio_transcriptions": "v1/audio/transcriptions", "audio_translations": "v1/audio/translations", + "embeddings": "v1/embeddings", } self.response_handlers = response_handlers self.timeout = timeout diff --git a/src/guidellm/backends/response_handlers.py b/src/guidellm/backends/response_handlers.py index e8087e058..7c73dcaf5 100644 --- a/src/guidellm/backends/response_handlers.py +++ b/src/guidellm/backends/response_handlers.py @@ -17,6 +17,7 @@ __all__ = [ "AudioResponseHandler", "ChatCompletionsResponseHandler", + "EmbeddingsResponseHandler", "GenerationResponseHandler", "GenerationResponseHandlerFactory", "TextCompletionsResponseHandler", @@ -525,3 +526,152 @@ def extract_metrics( text_words=len(text.split()) if text else 0, text_characters=len(text) if text else 0, ) + + +@GenerationResponseHandlerFactory.register("embeddings") +class EmbeddingsResponseHandler: + """ + Response handler for embeddings API endpoints. + + Processes responses from embeddings APIs that convert text into vector + representations. Unlike other handlers, embeddings typically don't support + streaming, so this handler primarily focuses on non-streaming responses + and extracts embedding vectors along with token usage metrics. + + Example: + :: + handler = EmbeddingsResponseHandler() + response = handler.compile_non_streaming(request, api_response) + """ + + def __init__(self): + """ + Initialize the embeddings response handler. + + Sets up internal state for storing embeddings data and usage metrics. + While embeddings don't typically stream, we maintain streaming state + for protocol compatibility. 
+ """ + self.streaming_embeddings: list[list[float]] = [] + self.streaming_usage: dict[str, int | dict[str, int]] | None = None + self.streaming_response_id: str | None = None + + def compile_non_streaming( + self, request: GenerationRequest, response: dict + ) -> GenerationResponse: + """ + Process a complete embeddings response. + + Extracts embedding vectors and usage metrics from the response. + Converts the embedding vectors to a string representation for storage + in the GenerationResponse. + + :param request: Original generation request + :param response: Complete API response containing embeddings and usage data + :return: Standardized GenerationResponse with embeddings as text and metrics + """ + usage: dict[str, int | dict[str, int]] = response.get("usage", {}) + + embeddings_data = response.get("data", []) + embeddings = [item.get("embedding", []) for item in embeddings_data] + + text_data = json.dumps({"embeddings": embeddings}) + text = text_data.decode() if isinstance(text_data, bytes) else text_data + + input_metrics, output_metrics = self.extract_metrics(usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + response_id=response.get("id"), + text=text, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + def add_streaming_line(self, line: str) -> int | None: + """ + Process a single line from an embeddings streaming response. + + Note: Embeddings APIs typically don't support streaming, but this method + is implemented for protocol compatibility. It will handle hypothetical + streaming scenarios if they exist. + + :param line: Raw line from the streaming response + :return: 1 if embeddings were extracted, 0 if line ignored, None if done + """ + if line == "data: [DONE]": + return None + + if not line or not (line := line.strip()) or not line.startswith("data:"): + return 0 + + line = line[len("data:") :].strip() + data: dict[str, Any] = json.loads(line) + updated = False + + if "id" in data and self.streaming_response_id is None: + self.streaming_response_id = data["id"] + + if embeddings_data := data.get("data"): + for item in embeddings_data: + if embedding := item.get("embedding"): + self.streaming_embeddings.append(embedding) + updated = True + + if usage := data.get("usage"): + self.streaming_usage = usage + + return 1 if updated else 0 + + def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + """ + Compile accumulated streaming embeddings into a final response. + + Note: Embeddings APIs typically don't support streaming, but this method + is implemented for protocol compatibility. + + :param request: Original generation request + :return: Standardized GenerationResponse with embeddings and metrics + """ + text_data = json.dumps({"embeddings": self.streaming_embeddings}) + text = text_data.decode() if isinstance(text_data, bytes) else text_data + input_metrics, output_metrics = self.extract_metrics(self.streaming_usage) + + return GenerationResponse( + request_id=request.request_id, + request_args=str( + request.arguments.model_dump() if request.arguments else None + ), + response_id=self.streaming_response_id, + text=text, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) + + def extract_metrics( + self, usage: dict[str, int | dict[str, int]] | None + ) -> tuple[UsageMetrics, UsageMetrics]: + """ + Extract input and output usage metrics from embeddings API response. 
+ + For embeddings, we primarily track input tokens (the text being embedded). + There are no output tokens generated since embeddings produce vectors, + not text. + + :param usage: Usage data dictionary from embeddings API response + :return: Tuple of input_metrics and output_metrics as UsageMetrics objects + """ + if not usage: + return UsageMetrics(), UsageMetrics(text_tokens=0) + + usage_metrics: dict[str, int] = cast("dict[str, int]", usage) + + return UsageMetrics( + text_tokens=usage_metrics.get("prompt_tokens", 0), + ), UsageMetrics( + # Embeddings don't generate text tokens, only consume them + text_tokens=0, + ) diff --git a/src/guidellm/data/builders.py b/src/guidellm/data/builders.py index 7ff584b68..ff96a46a6 100644 --- a/src/guidellm/data/builders.py +++ b/src/guidellm/data/builders.py @@ -219,9 +219,7 @@ def process_dataset( Main method to process and save a dataset with sampled prompt/output token counts. """ _validate_output_suffix(output_path) - logger.info( - f"Starting dataset conversion | Input: {data} | Output: {output_path}" - ) + logger.info(f"Starting dataset conversion | Input: {data} | Output: {output_path}") # Parse config config_obj = parse_synthetic_config(config) @@ -320,9 +318,7 @@ def _extract_column_names( output_mappings = column_mapper.datasets_column_mappings.get( "output_tokens_count_column", [] ) - output_column = ( - output_mappings[0][1] if output_mappings else "output_tokens_count" - ) + output_column = output_mappings[0][1] if output_mappings else "output_tokens_count" return prompt_column, prefix_column, output_column @@ -436,9 +432,7 @@ def _process_single_row( if prefix_tokens_max is not None: prefix_tokens_list = tokenizer.encode(prefix_text) if len(prefix_tokens_list) > prefix_tokens_max: - prefix_text = tokenizer.decode( - prefix_tokens_list[:prefix_tokens_max] - ) + prefix_text = tokenizer.decode(prefix_tokens_list[:prefix_tokens_max]) # Count prefix tokens toward prompt if enabled if include_prefix_in_token_count: @@ -450,9 +444,11 @@ def _process_single_row( elif count_adjustment > 0: adjusted_prompt_len = target_prompt_len - count_adjustment if adjusted_prompt_len <= 0: - logger.warning("The prefix exceeds target output length with " - "--include-prefix-in-token-count enabled; Using prompt size" - "of 1; skipping row") + logger.warning( + "The prefix exceeds target output length with " + "--include-prefix-in-token-count enabled; Using prompt size" + "of 1; skipping row" + ) return None target_prompt_len = adjusted_prompt_len diff --git a/src/guidellm/data/config.py b/src/guidellm/data/config.py index 2b0b2133a..401b5db2f 100644 --- a/src/guidellm/data/config.py +++ b/src/guidellm/data/config.py @@ -48,9 +48,7 @@ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None: if Path(data).is_file() and data_path.suffix.lower() == ".json": try: - return config_class.model_validate_json( - data_path.read_text() - ) + return config_class.model_validate_json(data_path.read_text()) except Exception as err: # noqa: BLE001 error = err @@ -60,9 +58,7 @@ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None: ".config", }: try: - return config_class.model_validate( - yaml.safe_load(data_path.read_text()) - ) + return config_class.model_validate(yaml.safe_load(data_path.read_text())) except Exception as err: # noqa: BLE001 error = err @@ -101,9 +97,7 @@ def _load_config_str(data: str, config_class: type[ConfigT]) -> ConfigT | None: for item in items: key, value = item.split("=") config_dict[key.strip()] = 
( - int(value.strip()) - if value.strip().isnumeric() - else value.strip() + int(value.strip()) if value.strip().isnumeric() else value.strip() ) return config_class.model_validate(config_dict) diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index 068ec78e2..774bc34e6 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -75,6 +75,11 @@ def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]: prompt_tokens_count = next(prompt_tokens_sampler) output_tokens_count = next(output_tokens_sampler) + # NOTE: The unique prefix (iteration_count and samples_count) ensures + # each request is unique, which is important for caching behavior testing. + # This prefix adds 2-5 tokens overhead. When benchmarking models with + # strict context limits (e.g., 512 tokens), use prompt_tokens=500 or + # similar to leave headroom for this prefix. yield ( samples_count, { diff --git a/src/guidellm/data/entrypoints.py b/src/guidellm/data/entrypoints.py index 1d88f34f2..f39631187 100644 --- a/src/guidellm/data/entrypoints.py +++ b/src/guidellm/data/entrypoints.py @@ -46,7 +46,18 @@ def process_dataset( :raises ValueError: If the output path is invalid or pushing conditions unmet. """ builders.process_dataset( - data, output_path, processor, config, processor_args, data_args, - data_column_mapper, short_prompt_strategy, pad_char, concat_delimiter, - include_prefix_in_token_count, push_to_hub, hub_dataset_id, random_seed, + data, + output_path, + processor, + config, + processor_args, + data_args, + data_column_mapper, + short_prompt_strategy, + pad_char, + concat_delimiter, + include_prefix_in_token_count, + push_to_hub, + hub_dataset_id, + random_seed, ) diff --git a/src/guidellm/data/preprocessors/__init__.py b/src/guidellm/data/preprocessors/__init__.py index 6d6e722d8..a01c5b214 100644 --- a/src/guidellm/data/preprocessors/__init__.py +++ b/src/guidellm/data/preprocessors/__init__.py @@ -1,4 +1,5 @@ from .formatters import ( + EmbeddingsRequestFormatter, GenerativeAudioTranscriptionRequestFormatter, GenerativeAudioTranslationRequestFormatter, GenerativeChatCompletionsRequestFormatter, @@ -17,6 +18,7 @@ "ColumnMapperRegistry", "DataDependentPreprocessor", "DatasetPreprocessor", + "EmbeddingsRequestFormatter", "GenerativeAudioTranscriptionRequestFormatter", "GenerativeAudioTranslationRequestFormatter", "GenerativeChatCompletionsRequestFormatter", diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py index 608128a64..2479c8090 100644 --- a/src/guidellm/data/preprocessors/formatters.py +++ b/src/guidellm/data/preprocessors/formatters.py @@ -9,6 +9,7 @@ from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics __all__ = [ + "EmbeddingsRequestFormatter", "GenerativeAudioTranscriptionRequestFormatter", "GenerativeAudioTranslationRequestFormatter", "GenerativeChatCompletionsRequestFormatter", @@ -402,3 +403,88 @@ def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest: result = super().__call__(columns) result.request_type = "audio_translations" return result + + +@PreprocessorRegistry.register("embeddings") +class EmbeddingsRequestFormatter(RequestFormatter): + """ + Request formatter for embeddings API endpoints. + + Formats requests for embedding models that convert text into vector representations. 
+ Unlike generative models, embeddings only process input text and return vectors, + so there are no output tokens or streaming. + """ + + def __init__( + self, + model: str, + extras: dict[str, Any] | GenerationRequestArguments | None = None, + encoding_format: str | None = None, + dimensions: int | None = None, + ): + """ + Initialize the embeddings request formatter. + + :param model: The model name to use for embeddings + :param extras: Additional request arguments + :param encoding_format: Format for the embedding vectors + (e.g., 'float', 'base64') + :param dimensions: Number of dimensions for the embedding vectors + """ + self.model = model + self.extras = ( + GenerationRequestArguments(**extras) + if extras and isinstance(extras, dict) + else extras + ) + self.encoding_format = encoding_format + self.dimensions = dimensions + + def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest: + """ + Format a request for the embeddings endpoint. + + :param columns: A dict of column types to values + :return: A GenerationRequest configured for embeddings + """ + arguments = GenerationRequestArguments() + arguments.body = {} + input_metrics = UsageMetrics() + output_metrics = UsageMetrics() + + if self.model is not None: + arguments.body["model"] = self.model + + # Embeddings don't support streaming + arguments.stream = False + + if self.encoding_format is not None: + arguments.body["encoding_format"] = self.encoding_format + + if self.dimensions is not None: + arguments.body["dimensions"] = self.dimensions + + if prompt_tokens := sum( + count for count in columns.get("prompt_tokens_count_column", []) if count + ): + input_metrics.text_tokens = prompt_tokens + + if self.extras: + arguments.model_combine(self.extras) + + prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) + text = "".join(txt for txt in columns.get("text_column", []) if txt) + if prefix or text: + input_text = prefix + text + arguments.body["input"] = input_text + input_metrics.add_text_metrics(input_text) + + # Embeddings don't have output tokens, only input processing + # The output is a vector, not text + + return GenerationRequest( + request_type="embeddings", + arguments=arguments, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py index 16af56dff..763f18073 100644 --- a/src/guidellm/data/schemas.py +++ b/src/guidellm/data/schemas.py @@ -25,26 +25,28 @@ "audio_column", ] + class DataNotSupportedError(Exception): """ Exception raised when the data format is not supported by deserializer or config. """ + class DataConfig(StandardBaseModel): """ A generic parent class for various configs for the data package that can be passed in as key-value pairs or JSON. 
""" -class PreprocessDatasetConfig(DataConfig): +class PreprocessDatasetConfig(DataConfig): prompt_tokens: int = Field( description="The average number of text tokens retained or added to prompts.", gt=0, ) prompt_tokens_stdev: int | None = Field( description="The standard deviation of the number of tokens retained in or " - "added to prompts.", + "added to prompts.", gt=0, default=None, ) @@ -64,7 +66,7 @@ class PreprocessDatasetConfig(DataConfig): ) output_tokens_stdev: int | None = Field( description="The standard deviation of the number of tokens retained or " - "added to outputs.", + "added to outputs.", gt=0, default=None, ) @@ -84,6 +86,7 @@ class PreprocessDatasetConfig(DataConfig): default=None, ) + class SyntheticTextPrefixBucketConfig(StandardBaseModel): bucket_weight: int = Field( description="Weight of this bucket in the overall distribution.", @@ -151,7 +154,6 @@ class SyntheticTextDatasetConfig(DataConfig): default=None, ) - @model_validator(mode="after") def check_prefix_options(self) -> SyntheticTextDatasetConfig: if self.__pydantic_extra__ is not None: diff --git a/src/guidellm/mock_server/handlers/__init__.py b/src/guidellm/mock_server/handlers/__init__.py index 7dbc209ff..f4a34f75e 100644 --- a/src/guidellm/mock_server/handlers/__init__.py +++ b/src/guidellm/mock_server/handlers/__init__.py @@ -12,6 +12,12 @@ from .chat_completions import ChatCompletionsHandler from .completions import CompletionsHandler +from .embeddings import EmbeddingsHandler from .tokenizer import TokenizerHandler -__all__ = ["ChatCompletionsHandler", "CompletionsHandler", "TokenizerHandler"] +__all__ = [ + "ChatCompletionsHandler", + "CompletionsHandler", + "EmbeddingsHandler", + "TokenizerHandler", +] diff --git a/src/guidellm/mock_server/handlers/embeddings.py b/src/guidellm/mock_server/handlers/embeddings.py new file mode 100644 index 000000000..57b104a06 --- /dev/null +++ b/src/guidellm/mock_server/handlers/embeddings.py @@ -0,0 +1,174 @@ +""" +Embeddings API handler for the mock server. + +This module provides the EmbeddingsHandler class that implements the /v1/embeddings +endpoint for the guidellm mock server. It generates synthetic embedding vectors +to simulate realistic embedding model behavior for benchmarking and testing purposes. +The handler supports both single and batch text inputs with configurable dimensions +and realistic timing delays. +""" + +from __future__ import annotations + +import asyncio +import json +import random +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + EmbeddingObject, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import MockTokenizer, sample_number + +__all__ = ["EmbeddingsHandler"] + + +class EmbeddingsHandler: + """ + Handler for the OpenAI /v1/embeddings endpoint in the mock server. + + This handler simulates the OpenAI embeddings API by processing incoming + requests and generating synthetic embedding vectors. It applies realistic + timing delays based on input token count to mimic actual embedding model + behavior for benchmarking purposes. 
+ + Example: + :: + config = MockServerConfig(ttft_ms=50) + handler = EmbeddingsHandler(config) + response = await handler.handle(sanic_request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the embeddings handler with configuration settings. + + :param config: Mock server configuration containing timing parameters + and tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process an embeddings request and return the appropriate response. + + Validates the incoming request, generates synthetic embedding vectors, + and returns the response with usage statistics. + + :param request: Sanic request object containing the embeddings request data + :return: HTTP response with embedding data or error information + :raises ValidationError: When request validation fails + :raises json.JSONDecodeError: When request JSON is malformed + """ + try: + req_data = EmbeddingsRequest(**request.json) + except ValidationError as e: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(e)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + return await self._handle_embeddings(req_data) + + async def _handle_embeddings(self, req: EmbeddingsRequest) -> HTTPResponse: + """ + Generate embeddings for the input text(s). + + Creates synthetic embedding vectors with realistic timing delays based on + the number of input tokens. Supports both single text and batch processing. + + :param req: Validated embeddings request containing input text(s) + :return: JSON HTTP response with embedding vectors and usage data + """ + inputs = [req.input] if isinstance(req.input, str) else req.input + + total_tokens = sum(len(self.tokenizer(text)) for text in inputs) + + # Simulate processing delay based on token count + # Use TTFT config as base delay per token + processing_delay = ( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + await asyncio.sleep(processing_delay) + + # Determine embedding dimensions + # Default to 1536 (OpenAI ada-002 dimension) or use requested dimensions + dimensions = req.dimensions if req.dimensions else 1536 + + # Generate synthetic embeddings for each input + embeddings_data = [] + for index, _text in enumerate(inputs): + # Generate random normalized embedding vector + embedding = self._generate_embedding(dimensions) + embeddings_data.append( + EmbeddingObject( + embedding=embedding, + index=index, + ) + ) + + embeddings_response = EmbeddingsResponse( + id=f"embd-{uuid.uuid4().hex[:29]}", + model=req.model, + data=embeddings_data, + usage=Usage( + prompt_tokens=total_tokens, + completion_tokens=0, + ), + ) + + return response.json(embeddings_response.model_dump()) + + def _generate_embedding(self, dimensions: int) -> list[float]: + """ + Generate a random normalized embedding vector. + + Creates a synthetic embedding vector with the specified number of + dimensions, normalized to unit length to mimic real embedding outputs. 
+ + :param dimensions: Number of dimensions in the embedding vector + :return: Normalized embedding vector as a list of floats + """ + # Generate random vector + embedding = [random.gauss(0, 1) for _ in range(dimensions)] + + # Normalize to unit length + magnitude = sum(x * x for x in embedding) ** 0.5 + if magnitude > 0: + embedding = [x / magnitude for x in embedding] + + return embedding diff --git a/src/guidellm/mock_server/models.py b/src/guidellm/mock_server/models.py index cd342f7a9..19a2b86e0 100644 --- a/src/guidellm/mock_server/models.py +++ b/src/guidellm/mock_server/models.py @@ -26,6 +26,9 @@ "CompletionsResponse", "DetokenizeRequest", "DetokenizeResponse", + "EmbeddingObject", + "EmbeddingsRequest", + "EmbeddingsResponse", "ErrorDetail", "ErrorResponse", "StreamOptions", @@ -486,6 +489,60 @@ class DetokenizeResponse(BaseModel): text: str = Field(description="Reconstructed text from tokens") +class EmbeddingsRequest(BaseModel): + """Request parameters for embeddings API endpoints. + + Converts input text into vector embeddings for semantic search, + clustering, and other ML tasks. Supports single or batch text inputs. + """ + + model: str = Field(description="Model identifier to use for embeddings") + input: str | list[str] = Field(description="Input text(s) to embed") + encoding_format: Literal["float", "base64"] | None = Field( + default="float", description="Format for returned embeddings" + ) + dimensions: int | None = Field( + default=None, description="Number of dimensions in output embeddings" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + +class EmbeddingObject(BaseModel): + """A single embedding vector with metadata. + + Represents one embedding result with the vector data, + index position, and object type identifier. + """ + + object: Literal["embedding"] = Field( + default="embedding", description="Object type identifier" + ) + embedding: list[float] = Field(description="The embedding vector") + index: int = Field(description="Index of this embedding in the response") + + +class EmbeddingsResponse(BaseModel): + """Response from embeddings API endpoints. + + Contains embedding vectors for each input text along with + usage statistics and metadata. + """ + + id: str = Field(description="Unique identifier for this embeddings request") + object: Literal["list"] = Field( + default="list", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for embeddings") + data: list[EmbeddingObject] = Field(description="List of embedding objects") + usage: Usage = Field(description="Token usage statistics") + + class ErrorDetail(BaseModel): """Detailed error information for API failures. 
diff --git a/src/guidellm/mock_server/server.py b/src/guidellm/mock_server/server.py index e85c61344..96ecf182a 100644 --- a/src/guidellm/mock_server/server.py +++ b/src/guidellm/mock_server/server.py @@ -23,6 +23,7 @@ from guidellm.mock_server.handlers import ( ChatCompletionsHandler, CompletionsHandler, + EmbeddingsHandler, TokenizerHandler, ) @@ -56,6 +57,7 @@ def __init__(self, config: MockServerConfig) -> None: self.app = Sanic("guidellm-mock-server") self.chat_handler = ChatCompletionsHandler(config) self.completions_handler = CompletionsHandler(config) + self.embeddings_handler = EmbeddingsHandler(config) self.tokenizer_handler = TokenizerHandler(config) self._setup_middleware() @@ -114,6 +116,12 @@ async def completions(request: Request): return response.text("", status=204) return await self.completions_handler.handle(request) + @self.app.route("/v1/embeddings", methods=["POST", "OPTIONS"]) + async def embeddings(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.embeddings_handler.handle(request) + @self.app.route("/tokenize", methods=["POST", "OPTIONS"]) async def tokenize(request: Request): if request.method == "OPTIONS": diff --git a/src/guidellm/schemas/request.py b/src/guidellm/schemas/request.py index a5193474c..5a8948cae 100644 --- a/src/guidellm/schemas/request.py +++ b/src/guidellm/schemas/request.py @@ -29,6 +29,7 @@ "chat_completions", "audio_transcriptions", "audio_translations", + "embeddings", ] diff --git a/tests/e2e/test_embeddings_benchmark.py b/tests/e2e/test_embeddings_benchmark.py new file mode 100644 index 000000000..aa40aaa55 --- /dev/null +++ b/tests/e2e/test_embeddings_benchmark.py @@ -0,0 +1,179 @@ +# E2E tests for embeddings endpoint benchmarking + +from pathlib import Path + +import pytest + +from tests.e2e.utils import ( + GuidellmClient, + assert_constraint_triggered, + assert_no_python_exceptions, + cleanup_report_file, + load_benchmark_report, +) +from tests.e2e.vllm_sim_server import VllmSimServer + + +@pytest.fixture(scope="module") +def embeddings_server(): + """ + Pytest fixture to start and stop the embeddings server for the entire module. + """ + server = VllmSimServer( + port=8001, + model="text-embedding-ada-002", + mode="random", + time_to_first_token=10, # 10ms processing time + inter_token_latency=1, # Not really used for embeddings + ) + try: + server.start() + yield server + finally: + server.stop() + + +@pytest.mark.timeout(30) +def test_embeddings_max_requests_benchmark(embeddings_server: VllmSimServer): + """ + Test that the max requests constraint works properly for embeddings endpoint. 
+ + ## WRITTEN BY AI ## + """ + report_path = Path("tests/e2e/embeddings_max_requests.json") + rate = 5 + max_requests = 20 + + # Create and configure the guidellm client + client = GuidellmClient(target=embeddings_server.get_url(), output_path=report_path) + + try: + # Start the benchmark + client.start_benchmark( + rate=rate, + max_requests=max_requests, + request_type="embeddings", + ) + + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Check that the max requests constraint was triggered + assert_constraint_triggered( + benchmark, "max_requests", {"processed_exceeded": True} + ) + + # Validate successful requests count + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) == max_requests, ( + f"Expected {max_requests} successful requests, " + f"got {len(successful_requests)}" + ) + + finally: + cleanup_report_file(report_path) + + +@pytest.mark.timeout(30) +def test_embeddings_max_seconds_benchmark(embeddings_server: VllmSimServer): + """ + Test that the max seconds constraint works properly for embeddings endpoint. + + ## WRITTEN BY AI ## + """ + report_path = Path("tests/e2e/embeddings_max_seconds.json") + rate = 4 + duration = 5 + max_seconds = duration + + # Create and configure the guidellm client + client = GuidellmClient(target=embeddings_server.get_url(), output_path=report_path) + + try: + # Start the benchmark + client.start_benchmark( + rate=rate, + max_seconds=max_seconds, + request_type="embeddings", + ) + + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Check that the max duration constraint was triggered + assert_constraint_triggered( + benchmark, "max_seconds", {"duration_exceeded": True} + ) + + # Validate that we have successful requests + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) > 0, "Expected at least one successful request" + + finally: + cleanup_report_file(report_path) + + +@pytest.mark.timeout(30) +def test_embeddings_rate_benchmark(embeddings_server: VllmSimServer): + """ + Test basic rate-based benchmarking for embeddings endpoint. 
+ + ## WRITTEN BY AI ## + """ + report_path = Path("tests/e2e/embeddings_rate.json") + rate = 10 + max_requests = 30 + + # Create and configure the guidellm client + client = GuidellmClient(target=embeddings_server.get_url(), output_path=report_path) + + try: + # Start the benchmark + client.start_benchmark( + rate=rate, + max_requests=max_requests, + request_type="embeddings", + ) + + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Validate successful requests + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) == max_requests, ( + f"Expected {max_requests} successful requests, " + f"got {len(successful_requests)}" + ) + + # Validate that all requests have the expected fields + for request in successful_requests: + assert "start_time" in request, "Request missing start_time" + assert "end_time" in request, "Request missing end_time" + # For embeddings, we don't expect output_tokens, only input_tokens + assert "prompt" in request or "input_tokens" in request, ( + "Request missing prompt or input_tokens" + ) + + finally: + cleanup_report_file(report_path) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 55baa89d2..b22b188fb 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -52,6 +52,7 @@ def start_benchmark( over_saturation: dict[str, Any] | None = None, data: str = "prompt_tokens=256,output_tokens=128", processor: str = "gpt2", + request_type: str | None = None, additional_args: str = "", extra_env: dict[str, str] | None = None, ) -> None: @@ -67,6 +68,7 @@ def start_benchmark( Passed as JSON string to --over-saturation CLI argument. 
:param data: Data configuration string :param processor: Processor/tokenizer to use + :param request_type: Type of request (e.g., "embeddings") :param additional_args: Additional command line arguments :param extra_env: Additional environment variables to set """ @@ -116,6 +118,9 @@ def start_benchmark( ] ) + if request_type is not None: + cmd_parts.append(f"--request-type {request_type}") + if additional_args: cmd_parts.append(additional_args) diff --git a/tests/unit/backends/test_response_handlers.py b/tests/unit/backends/test_response_handlers.py index f4be83ff5..5707683d7 100644 --- a/tests/unit/backends/test_response_handlers.py +++ b/tests/unit/backends/test_response_handlers.py @@ -5,6 +5,7 @@ from guidellm.backends import ( AudioResponseHandler, ChatCompletionsResponseHandler, + EmbeddingsResponseHandler, GenerationResponseHandler, GenerationResponseHandlerFactory, TextCompletionsResponseHandler, @@ -53,6 +54,7 @@ def test_class_signatures(self): ("chat_completions", None, ChatCompletionsResponseHandler), ("audio_transcriptions", None, AudioResponseHandler), ("audio_translations", None, AudioResponseHandler), + ("embeddings", None, EmbeddingsResponseHandler), ( "text_completions", {"text_completions": ChatCompletionsResponseHandler}, @@ -64,6 +66,7 @@ def test_class_signatures(self): "chat_completions", "audio_transcriptions", "audio_translations", + "embeddings", "override_text_completions", ], ) @@ -721,3 +724,199 @@ def test_extract_metrics( assert output_metrics.text_tokens == expected_output_tokens assert output_metrics.text_words == (len(text.split()) if text else 0) assert output_metrics.text_characters == len(text) + + +class TestEmbeddingsResponseHandler: + @pytest.fixture( + params=[{}], + ids=["default"], + ) + def valid_instances(self, request): + """Create instance of EmbeddingsResponseHandler.""" + return EmbeddingsResponseHandler() + + @pytest.mark.smoke + def test_class_signatures(self): + """ + Test EmbeddingsResponseHandler class signatures. + + ## WRITTEN BY AI ## + """ + handler = EmbeddingsResponseHandler() + assert hasattr(handler, "compile_non_streaming") + assert hasattr(handler, "add_streaming_line") + assert hasattr(handler, "compile_streaming") + assert hasattr(handler, "extract_metrics") + assert hasattr(handler, "streaming_embeddings") + assert hasattr(handler, "streaming_usage") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """ + Test EmbeddingsResponseHandler initialization. + + ## WRITTEN BY AI ## + """ + instance = valid_instances + assert isinstance(instance, EmbeddingsResponseHandler) + assert instance.streaming_embeddings == [] + assert instance.streaming_usage is None + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "response", + "expected_embeddings_count", + "expected_input_tokens", + ), + [ + ( + { + "data": [ + {"embedding": [0.1, 0.2, 0.3]}, + {"embedding": [0.4, 0.5, 0.6]}, + ], + "usage": {"prompt_tokens": 10}, + }, + 2, + 10, + ), + ( + { + "data": [ + {"embedding": [0.1] * 1536}, + ], + "usage": {"prompt_tokens": 5}, + }, + 1, + 5, + ), + ( + { + "data": [], + "usage": {}, + }, + 0, + None, + ), + ], + ) + def test_non_streaming( + self, + valid_instances, + generation_request, + response, + expected_embeddings_count, + expected_input_tokens, + ): + """ + Test compile_non_streaming method for embeddings. 
+ + ## WRITTEN BY AI ## + """ + instance: EmbeddingsResponseHandler = valid_instances + + result = instance.compile_non_streaming(generation_request, response) + + # Text should be JSON-formatted embeddings + import json + + text_data = json.loads(result.text) + assert "embeddings" in text_data + assert len(text_data["embeddings"]) == expected_embeddings_count + assert result.input_metrics.text_tokens == expected_input_tokens + # Embeddings don't generate tokens + assert result.output_metrics.text_tokens == 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("lines", "expected_embeddings_count", "expected_input_tokens"), + [ + ( + [ + 'data: {"data": [{"embedding": [0.1, 0.2, 0.3]}], "usage": {}}', + ( + 'data: {"data": [{"embedding": [0.4, 0.5, 0.6]}], ' + '"usage": {"prompt_tokens": 10}}' + ), + "data: [DONE]", + ], + 2, + 10, + ), + ( + [ + 'data: {"data": [{"embedding": [0.1]}], "usage": {}}', + "data: [DONE]", + ], + 1, + None, + ), + ( + ["data: [DONE]"], + 0, + None, + ), + ], + ) + def test_streaming( + self, + valid_instances, + generation_request, + lines, + expected_embeddings_count, + expected_input_tokens, + ): + """ + Test streaming pathway for embeddings. + + ## WRITTEN BY AI ## + """ + instance: EmbeddingsResponseHandler = valid_instances + + updated_count = 0 + for line in lines: + result = instance.add_streaming_line(line) + if result == 1: + updated_count += 1 + elif result is None: + break + + response = instance.compile_streaming(generation_request) + + # Text should be JSON-formatted embeddings + import json + + text_data = json.loads(response.text) + assert "embeddings" in text_data + assert len(text_data["embeddings"]) == expected_embeddings_count + assert response.input_metrics.text_tokens == expected_input_tokens + # Embeddings don't generate tokens + assert response.output_metrics.text_tokens == 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("usage", "expected_input_tokens"), + [ + ({"prompt_tokens": 10}, 10), + ({"prompt_tokens": 100}, 100), + (None, None), + ({}, None), + ], + ) + def test_extract_metrics( + self, + valid_instances, + usage, + expected_input_tokens, + ): + """ + Test extract_metrics method for embeddings. 
+ + ## WRITTEN BY AI ## + """ + instance: EmbeddingsResponseHandler = valid_instances + input_metrics, output_metrics = instance.extract_metrics(usage) + + assert input_metrics.text_tokens == expected_input_tokens + assert output_metrics.text_tokens == 0 # Embeddings don't generate tokens diff --git a/tests/unit/data/deserializers/test_synthetic.py b/tests/unit/data/deserializers/test_synthetic.py index eda02ef58..3664470be 100644 --- a/tests/unit/data/deserializers/test_synthetic.py +++ b/tests/unit/data/deserializers/test_synthetic.py @@ -413,7 +413,8 @@ def test_load_config_file_yaml(self): try: loaded_config = config_module._load_config_file( - yaml_path, SyntheticTextDatasetConfig, + yaml_path, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 60 @@ -443,7 +444,8 @@ def test_load_config_file_config_extension(self): try: loaded_config = config_module._load_config_file( - config_path, SyntheticTextDatasetConfig, + config_path, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 90 @@ -460,7 +462,8 @@ def test_load_config_str_json(self): """ json_str = '{"prompt_tokens": 50, "output_tokens": 25}' loaded_config = config_module._load_config_str( - json_str, SyntheticTextDatasetConfig, + json_str, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 50 @@ -474,7 +477,8 @@ def test_load_config_str_key_value(self): """ kv_str = "prompt_tokens=50,output_tokens=25" loaded_config = config_module._load_config_str( - kv_str, SyntheticTextDatasetConfig, + kv_str, + SyntheticTextDatasetConfig, ) assert loaded_config.prompt_tokens == 50 @@ -488,7 +492,8 @@ def test_load_config_str_invalid_format(self): """ with pytest.raises(DataNotSupportedError, match="Unsupported string data"): config_module._load_config_str( - "invalid_format_string", SyntheticTextDatasetConfig, + "invalid_format_string", + SyntheticTextDatasetConfig, ) @pytest.mark.regression @@ -498,7 +503,8 @@ def test_load_config_file_non_existent(self): ### WRITTEN BY AI ### """ loaded_config = config_module._load_config_file( - "/non/existent/path.config", SyntheticTextDatasetConfig, + "/non/existent/path.config", + SyntheticTextDatasetConfig, ) assert loaded_config is None diff --git a/tests/unit/data/test_builders.py b/tests/unit/data/test_builders.py index 946b9cd1b..d0626a739 100644 --- a/tests/unit/data/test_builders.py +++ b/tests/unit/data/test_builders.py @@ -52,60 +52,72 @@ def decode_side_effect(tokens, skip_special_tokens=False): @pytest.fixture def sample_dataset_default_columns(): """Sample dataset with default column names.""" - return Dataset.from_dict({ - "prompt": [ - ( - "This is a very long prompt that should be sufficient for " - "testing purposes. " - ) * 10, - "Short.", - ( - "Another very long prompt for testing the dataset processing " - "functionality. " - ) * 10, - ], - }) + return Dataset.from_dict( + { + "prompt": [ + ( + "This is a very long prompt that should be sufficient for " + "testing purposes. " + ) + * 10, + "Short.", + ( + "Another very long prompt for testing the dataset processing " + "functionality. " + ) + * 10, + ], + } + ) @pytest.fixture def sample_dataset_custom_columns(): """Sample dataset with custom column names requiring mapping.""" - return Dataset.from_dict({ - "question": [ - ( - "What is the meaning of life? This is a longer question that " - "should work for testing. " - ) * 10, - ( - "How does this work? Let me explain in detail how this system " - "functions. " - ) * 10, - ( - "Tell me about machine learning. 
Machine learning is a " - "fascinating field. " - ) * 10, - ], - }) + return Dataset.from_dict( + { + "question": [ + ( + "What is the meaning of life? This is a longer question that " + "should work for testing. " + ) + * 10, + ( + "How does this work? Let me explain in detail how this system " + "functions. " + ) + * 10, + ( + "Tell me about machine learning. Machine learning is a " + "fascinating field. " + ) + * 10, + ], + } + ) @pytest.fixture def sample_dataset_with_prefix(): """Sample dataset with prefix column.""" - return Dataset.from_dict({ - "prompt": [ - ( - "This is a long prompt that should be sufficient for testing " - "purposes. " - ) * 10, - "Another long prompt here that will work for testing. " * 10, - "Yet another long prompt for testing purposes. " * 10, - ], - "system_prompt": [ - "You are a helpful assistant.", - "You are a helpful assistant.", - "You are a helpful assistant.", - ], - }) + return Dataset.from_dict( + { + "prompt": [ + ( + "This is a long prompt that should be sufficient for testing " + "purposes. " + ) + * 10, + "Another long prompt here that will work for testing. " * 10, + "Yet another long prompt for testing purposes. " * 10, + ], + "system_prompt": [ + "You are a helpful assistant.", + "You are a helpful assistant.", + "You are a helpful assistant.", + ], + } + ) @pytest.fixture @@ -192,30 +204,32 @@ def test_process_dataset_concatenate_strategy( # Create a dataset with short prompts that can be concatenated to reach target # Use a lower target (15 tokens) so concatenation is achievable short_config = '{"prompt_tokens": 15, "output_tokens": 10}' - short_prompts_dataset = Dataset.from_dict({ - "prompt": [ - "A", # 1 char = 1 token - "B", # 1 char = 1 token - "C", # 1 char = 1 token - "D", # 1 char = 1 token - "E", # 1 char = 1 token - "F", # 1 char = 1 token - "G", # 1 char = 1 token - "H", # 1 char = 1 token - "I", # 1 char = 1 token - "J", # 1 char = 1 token - "K", # 1 char = 1 token - "L", # 1 char = 1 token - "M", # 1 char = 1 token - "N", # 1 char = 1 token - "O", # 1 char = 1 token - "P", # 1 char = 1 token - "Q", # 1 char = 1 token - "R", # 1 char = 1 token - "S", # 1 char = 1 token - "T", # 1 char = 1 token - ], - }) + short_prompts_dataset = Dataset.from_dict( + { + "prompt": [ + "A", # 1 char = 1 token + "B", # 1 char = 1 token + "C", # 1 char = 1 token + "D", # 1 char = 1 token + "E", # 1 char = 1 token + "F", # 1 char = 1 token + "G", # 1 char = 1 token + "H", # 1 char = 1 token + "I", # 1 char = 1 token + "J", # 1 char = 1 token + "K", # 1 char = 1 token + "L", # 1 char = 1 token + "M", # 1 char = 1 token + "N", # 1 char = 1 token + "O", # 1 char = 1 token + "P", # 1 char = 1 token + "Q", # 1 char = 1 token + "R", # 1 char = 1 token + "S", # 1 char = 1 token + "T", # 1 char = 1 token + ], + } + ) # Setup mocks mock_check_processor.return_value = tokenizer_mock @@ -323,8 +337,9 @@ def test_process_dataset_pad_strategy( # Verify that prompts meet minimum token count requirements actual_tokens = len(tokenizer_mock.encode(row["prompt"])) - assert actual_tokens >= 50, \ + assert actual_tokens >= 50, ( f"Padded prompt should have at least 50 tokens, got {actual_tokens}" + ) assert row["prompt_tokens_count"] == actual_tokens # For the "Short." prompt (index 1), verify it was padded @@ -527,12 +542,14 @@ def test_process_dataset_with_instruction_column( """ # Create dataset with 'instruction' column (one of the default # text_column names) - dataset = Dataset.from_dict({ - "instruction": [ - "Follow these instructions carefully. 
" * 20, - "Complete the task as described. " * 20, - ], - }) + dataset = Dataset.from_dict( + { + "instruction": [ + "Follow these instructions carefully. " * 20, + "Complete the task as described. " * 20, + ], + } + ) # Setup mocks mock_check_processor.return_value = tokenizer_mock @@ -823,10 +840,12 @@ def test_process_dataset_empty_after_filtering( ## WRITTEN BY AI ## """ # Create dataset with only very short prompts that will be filtered out - dataset = Dataset.from_dict({ - # Very short prompts (1 char each, less than 50 tokens) - "prompt": ["A", "B", "C"], - }) + dataset = Dataset.from_dict( + { + # Very short prompts (1 char each, less than 50 tokens) + "prompt": ["A", "B", "C"], + } + ) # Setup mocks mock_check_processor.return_value = tokenizer_mock @@ -1462,8 +1481,9 @@ def test_prompt_trimming_accuracy( # Verify all prompts are trimmed to exactly 50 tokens for row in saved_dataset: actual_tokens = len(tokenizer_mock.encode(row["prompt"])) - assert actual_tokens == 50, \ + assert actual_tokens == 50, ( f"Prompt not trimmed correctly: expected 50 tokens, got {actual_tokens}" + ) @pytest.mark.sanity @patch("guidellm.data.builders.save_dataset_to_file") @@ -1515,8 +1535,9 @@ def test_prompt_padding_accuracy( for row in saved_dataset: prompt_text = row["prompt"] actual_tokens = len(tokenizer_mock.encode(prompt_text)) - assert actual_tokens == 100, \ + assert actual_tokens == 100, ( f"Prompt not padded correctly: expected 100 tokens, got {actual_tokens}" + ) assert row["prompt_tokens_count"] == 100 # Verify that pad_char "X" appears in the padded prompts @@ -1813,9 +1834,11 @@ def test_process_dataset_push_to_hub_called( ): """Test that push_to_hub is called when push_to_hub=True.""" # Create a dataset with prompts long enough to be processed - sample_dataset = Dataset.from_dict({ - "prompt": ["abc " * 50], # Long enough - }) + sample_dataset = Dataset.from_dict( + { + "prompt": ["abc " * 50], # Long enough + } + ) mock_check_processor.return_value = tokenizer_mock mock_deserializer_factory_class.deserialize.return_value = sample_dataset @@ -1854,9 +1877,11 @@ def test_process_dataset_push_to_hub_not_called( ): """Test that push_to_hub is not called when push_to_hub=False.""" # Create a dataset with prompts long enough to be processed - sample_dataset = Dataset.from_dict({ - "prompt": ["abc " * 50], # Long enough - }) + sample_dataset = Dataset.from_dict( + { + "prompt": ["abc " * 50], # Long enough + } + ) mock_check_processor.return_value = tokenizer_mock mock_deserializer_factory_class.deserialize.return_value = sample_dataset @@ -1918,15 +1943,18 @@ def test_strategy_handler_called( ): """Test that strategy handlers are called during dataset processing.""" from guidellm.data.builders import STRATEGY_HANDLERS + mock_handler = MagicMock(return_value="processed_prompt") with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}): # Create a dataset with prompts that need processing - sample_dataset = Dataset.from_dict({ - "prompt": [ - "abc" * 20, # Long enough to pass - "def" * 20, # Long enough to pass - ], - }) + sample_dataset = Dataset.from_dict( + { + "prompt": [ + "abc" * 20, # Long enough to pass + "def" * 20, # Long enough to pass + ], + } + ) mock_check_processor.return_value = tokenizer_mock mock_deserializer_factory_class.deserialize.return_value = sample_dataset diff --git a/tests/unit/schemas/test_request.py b/tests/unit/schemas/test_request.py index 91fb979b9..1b41da2ca 100644 --- a/tests/unit/schemas/test_request.py +++ 
b/tests/unit/schemas/test_request.py @@ -25,11 +25,12 @@ def test_generative_request_type(): """Test that GenerativeRequestType is defined correctly.""" assert hasattr(typing, "get_args") args = typing.get_args(GenerativeRequestType) - assert len(args) == 4 + assert len(args) == 5 assert "text_completions" in args assert "chat_completions" in args assert "audio_transcriptions" in args assert "audio_translations" in args + assert "embeddings" in args class TestGenerationRequestArguments: