diff --git a/assets/training/endpoint_evaluation/component/asset.yaml b/assets/training/endpoint_evaluation/component/asset.yaml new file mode 100644 index 0000000000..3277efbab0 --- /dev/null +++ b/assets/training/endpoint_evaluation/component/asset.yaml @@ -0,0 +1,6 @@ +type: component +spec: spec.yaml +categories: ["Benchmark", "Speculative Decoding"] +test: + pytest: + enabled: false \ No newline at end of file diff --git a/assets/training/endpoint_evaluation/component/spec.yaml b/assets/training/endpoint_evaluation/component/spec.yaml new file mode 100644 index 0000000000..6f67011171 --- /dev/null +++ b/assets/training/endpoint_evaluation/component/spec.yaml @@ -0,0 +1,106 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +type: command + +name: endpoint_benchmarking +display_name: Endpoint Benchmarking Component +description: Runs benchmark on AzureML online endpoints. +version: 0.0.1 +is_deterministic: True + +inputs: + base_scoring_url: + type: string + optional: False + description: The URL of the base endpoint. + base_connection_name: + type: string + optional: False + description: The name of the connection to fetch the API_KEY for the base endpoint authentication. + target_scoring_url: + type: string + optional: False + description: The URL of the target endpoint. + target_connection_name: + type: string + optional: False + description: The name of the connection to fetch the API_KEY for the target endpoint authentication. + base_model: + type: string + optional: False + default: nvidia/Llama-3.1-8B-Instruct-FP8 + description: HuggingFace repo ID of the model for the base endpoint. + target_model: + type: string + optional: False + default: nvidia/Llama-3.1-8B-Instruct-FP8 + description: HuggingFace repo ID of the model for the target endpoint. + base_backend: + type: string + optional: True + default: sglang + description: LLM Inference Engine for base endpoint. + enum: + - sglang + - vllm + target_backend: + type: string + optional: True + default: sglang + description: LLM Inference Engine for target endpoint. + enum: + - sglang + - vllm + dataset_name: + type: string + optional: True + default: sharegpt + description: Depending on the LLM Inference Engine. + enum: + - sharegpt + request_rate: + type: integer + optional: True + default: 10 + description: The request rate per second for sending requests to the endpoint. + num_prompts: + type: integer + optional: True + default: 2500 + description: The total number of prompts to send to the endpoint. + disable_shuffle: + type: boolean + optional: True + default: True + description: Disable shuffling the dataset before sending requests. + trials: + type: integer + optional: True + default: 5 + description: Number of trials to run the benchmark, result will be averaged over all trials. + +outputs: + metrics: + type: uri_folder + description: The output folder containing the benchmarking metrics. 
+ +environment: azureml://registries/azureml/environments/acft-rft-training/versions/1 +resources: + instance_count: 1 + +code: ../src +command: >- + python main.py + --output-file ${{outputs.metrics}} + --base-url ${{inputs.base_scoring_url}} + --connection-name ${{inputs.base_connection_name}} + --base-model ${{inputs.base_model}} + --target-url ${{inputs.target_scoring_url}} + --target-connection-name ${{inputs.target_connection_name}} + --target-model ${{inputs.target_model}} + $[[--base-backend ${{inputs.base_backend}}]] + $[[--target-backend ${{inputs.target_backend}}]] + $[[--trials ${{inputs.trials}}]] + $[[--dataset-name ${{inputs.dataset_name}}]] + $[[--request-rate ${{inputs.request_rate}}]] + $[[--num-prompts ${{inputs.num_prompts}}]] + $[[--disable-shuffle ${{inputs.disable_shuffle}}]] diff --git a/assets/training/endpoint_evaluation/src/bench_serving.py b/assets/training/endpoint_evaluation/src/bench_serving.py new file mode 100644 index 0000000000..414ac46c1d --- /dev/null +++ b/assets/training/endpoint_evaluation/src/bench_serving.py @@ -0,0 +1,1090 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Benchmark online serving with dynamic requests. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 + --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range + 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple + +import aiohttp +import numpy as np +import requests +from data_processing import MsgContent, SampleOutput, get_dataset +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from sglang.bench_serving import get_tokenizer, remove_prefix, set_ulimit + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + """Input data structure for benchmark request functions. + + Attributes: + prompts: List of tuples containing (message_content, input_length, output_length) + api_url: URL endpoint for the API + model: Model identifier + lora_name: LoRA adapter name if applicable + extra_request_body: Additional parameters for the request payload + prev_messages: Previous messages in conversation for multiturn chat + finished_prompts: Number of completed prompts in the conversation + """ + + prompts: List[Tuple[MsgContent, int, int]] + api_url: str + model: str + lora_name: str + extra_request_body: Dict[str, Any] + + # For multiturn chat, store the context + prev_messages: List = field(default_factory=list) + finished_prompts: int = 0 + + +@dataclass +class RequestFuncOutput: + """Output data structure for benchmark request functions. 
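+
+    Each list field accumulates one entry per completed prompt in a
+    (possibly multiturn) conversation, except itl, which records one
+    entry per streamed chunk after the first token.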
+ + Attributes: + generated_text: List of generated text responses + prompt_len: List of prompt lengths in tokens + output_len: List of output lengths in tokens + latency: List of end-to-end latencies in seconds + ttft: List of time-to-first-token values in seconds + itl: List of inter-token latencies in seconds + success: Whether the request completed successfully + error: Error message if request failed + """ + + generated_text: List[str] = field(default_factory=list) + prompt_len: List[int] = field(default_factory=list) + output_len: List[int] = field(default_factory=list) + latency: List[float] = field(default_factory=list) + ttft: List[float] = field(default_factory=list) + itl: List[float] = field(default_factory=list) # List of inter-token latencies + + success: bool = False + error: str = "" + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + queue: asyncio.Queue, + tokenizer: PreTrainedTokenizerBase, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Make asynchronous requests to OpenAI-compatible completions API. + + Args: + request_func_input: Input configuration for the request + queue: Async queue for managing concurrent requests + tokenizer: Tokenizer for processing text + pbar: Optional progress bar for tracking completion + + Returns: + RequestFuncOutput: Results from the API request including generated text, + latency metrics, and success/error status + """ + api_url = request_func_input.api_url + assert api_url.endswith("completions"), "OpenAI Completions API URL must end with 'completions'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "temperature": 0.0, + "best_of": 1, + "stream": not args.disable_stream, + "stream_options": {"include_usage": True}, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + + prompt_idx = request_func_input.finished_prompts + messages = request_func_input.prev_messages + prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx] + prompt_len = sum( + prompt[1] + prompt[2] for prompt in request_func_input.prompts[:prompt_idx] # input_len + output_len + ) + prompt_len += input_len + + # Messages + messages.append( + { + "role": "user", + "content": prompt, + } + ) + payload["messages"] = messages + payload["max_tokens"] = max_tokens + + # output.prompt_len = request_func_input.prompt_len + # print(payload) + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, headers=headers) as response: + if response.status == 200: + actual_prompt_len = prompt_len - 1 + actual_output_len = 0 + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + timestamp = time.perf_counter() + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if "usage" in data and data["usage"] is not None and len(data["usage"]) > 0: + actual_prompt_len = data["usage"]["prompt_tokens"] + actual_output_len = 
data["usage"]["completion_tokens"] + continue + delta = data["choices"][0]["delta"] + + if delta.get("content", None): + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft.append(ttft) + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + generated_text += delta["content"] + most_recent_timestamp = timestamp + + output.prompt_len.append(actual_prompt_len) # truncate + output.output_len.append(actual_output_len) + output.generated_text.append(generated_text) + output.success = True + output.latency.append(latency) + + # Prepare for the new request + request_func_input.prompts[prompt_idx] = ( + prompt, + input_len, + actual_output_len, # changes from max_tokens to output_len + ) + prompt_idx += 1 + messages.append( + { + "role": "assistant", + "content": generated_text, + } + ) + + # Move the new request to the end of the queue + if prompt_idx < len(request_func_input.prompts): + request_func_input.finished_prompts = prompt_idx + request_func_input.prev_messages = messages + await queue.put(request_func_input) + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_profile(api_url: str) -> RequestFuncOutput: + """Send profiling control request to the API endpoint. + + Args: + api_url: URL for the profiling control endpoint + + Returns: + RequestFuncOutput: Response indicating success or failure of profiling command + """ + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + output = RequestFuncOutput() + try: + async with session.post(url=api_url) as response: + if response.status == 200: + output.success = True + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + return output + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, +} + + +@dataclass +class BenchmarkMetrics: + """Comprehensive metrics collected during benchmark execution. 
+ + Attributes: + completed: Number of successfully completed requests + total_input: Total number of input tokens processed + total_output: Total number of output tokens generated + total_output_retokenized: Total output tokens after retokenization + request_throughput: Requests processed per second + input_throughput: Input tokens processed per second + output_throughput: Output tokens generated per second + output_throughput_retokenized: Retokenized output tokens per second + total_throughput: Combined input/output tokens per second + total_throughput_retokenized: Combined throughput with retokenized output + mean_ttft_ms: Mean time-to-first-token in milliseconds + median_ttft_ms: Median time-to-first-token in milliseconds + std_ttft_ms: Standard deviation of TTFT in milliseconds + p90_ttft_ms: 90th percentile TTFT in milliseconds + p99_ttft_ms: 99th percentile TTFT in milliseconds + mean_tpot_ms: Mean time-per-output-token in milliseconds + median_tpot_ms: Median time-per-output-token in milliseconds + std_tpot_ms: Standard deviation of TPOT in milliseconds + p90_tpot_ms: 90th percentile TPOT in milliseconds + p99_tpot_ms: 99th percentile TPOT in milliseconds + mean_itl_ms: Mean inter-token latency in milliseconds + median_itl_ms: Median inter-token latency in milliseconds + std_itl_ms: Standard deviation of ITL in milliseconds + p90_itl_ms: 90th percentile ITL in milliseconds + p99_itl_ms: 99th percentile ITL in milliseconds + mean_e2e_latency_ms: Mean end-to-end latency in milliseconds + median_e2e_latency_ms: Median end-to-end latency in milliseconds + std_e2e_latency_ms: Standard deviation of E2E latency in milliseconds + p99_e2e_latency_ms: 99th percentile E2E latency in milliseconds + concurrency: Average concurrency level during benchmark + """ + + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + total_throughput: float + total_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p90_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p90_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p90_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + std_e2e_latency_ms: float + p99_e2e_latency_ms: float + concurrency: float + + +async def get_requests( + input_requests_queue: asyncio.Queue, + request_rate: float, + num_actual_requests: int, +) -> AsyncGenerator[RequestFuncInput, None]: + """Generate requests at specified rate using Poisson process. 
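+
+    Inter-arrival gaps are sampled from an exponential distribution with
+    mean 1 / request_rate, which makes arrivals a Poisson process; a rate
+    of inf dispatches every request immediately with no gap.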
+ + Args: + input_requests_queue: Queue containing prepared request inputs + request_rate: Target requests per second (inf for immediate) + num_actual_requests: Total number of requests to generate + + Yields: + RequestFuncInput: Individual request configurations at controlled intervals + """ + for _ in range(num_actual_requests): + try: + request = await asyncio.wait_for(input_requests_queue.get(), timeout=300) # Wait for 5 minutes then abort + except Exception as e: + print(f"exception: {e}") + break + + yield request + + if request_rate == float("inf"): + continue + + interval = np.random.exponential(1.0 / request_rate) + await asyncio.sleep(interval) + + +def calculate_metrics( + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + """Calculate comprehensive benchmark metrics from request outputs. + + Args: + outputs: List of completed request outputs + dur_s: Total benchmark duration in seconds + tokenizer: Tokenizer for retokenizing generated text + backend: Backend identifier for metric calculation + + Returns: + Tuple containing: + - BenchmarkMetrics: Comprehensive performance metrics + - List[int]: Output lengths for each successful request + """ + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + output_success = 0 + for i in range(len(outputs)): + if outputs[i].success: + output_success += 1 + assert len(outputs[i].generated_text) == len(outputs[i].latency) + assert len(outputs[i].generated_text) == len(outputs[i].ttft) + for j in range(len(outputs[i].generated_text)): + output_len = outputs[i].output_len[j] + output_lens.append(output_len) + retokenized_output_len = len(tokenizer.encode(outputs[i].generated_text[j], add_special_tokens=False)) + retokenized_output_lens.append(retokenized_output_len) + total_input += outputs[i].prompt_len[j] + if output_len > 1: + tpots.append((outputs[i].latency[j] - outputs[i].ttft[j]) / (output_len - 1)) + + completed += 1 + itls += outputs[i].itl + ttfts += outputs[i].ttft + e2e_latencies += outputs[i].latency + + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + total_throughput=(total_input + sum(output_lens)) / dur_s, + total_throughput_retokenized=(total_input + sum(retokenized_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p90_ttft_ms=np.percentile(ttfts or 0, 90) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p90_tpot_ms=np.percentile(tpots or 0, 90) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p90_itl_ms=np.percentile(itls or 0, 90) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + std_e2e_latency_ms=np.std(e2e_latencies) * 1000, + p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000, + concurrency=np.sum(e2e_latencies) / dur_s, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + base_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: SampleOutput, + request_rate: float, + max_concurrency: Optional[int], + disable_tqdm: bool, + lora_name: str, + extra_request_body: Dict[str, Any], + profile: bool, + enable_shared_prefix: bool, +): + """Execute the main benchmark against a language model endpoint. + + Args: + backend: Backend type (sglang, vllm, etc.) 
+ api_url: Complete API endpoint URL + base_url: Base URL for the service + model_id: Model identifier + tokenizer: Tokenizer for text processing + input_requests: Prepared benchmark requests + request_rate: Target requests per second + max_concurrency: Maximum concurrent requests (optional) + disable_tqdm: Whether to disable progress bar + lora_name: LoRA adapter name (optional) + extra_request_body: Additional request parameters + profile: Whether to enable profiling + enable_shared_prefix: Whether to enable shared prefix optimization + + Returns: + dict: Comprehensive benchmark results including all metrics + """ + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + # Limit concurrency + # From https://github.com/vllm-project/vllm/pull/9390 + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None + + async def limited_request_func(request_func_input, queue, tokenizer, pbar): + if semaphore is None: + return await request_func( + request_func_input=request_func_input, + queue=queue, + tokenizer=tokenizer, + pbar=pbar, + ) + async with semaphore: + return await request_func( + request_func_input=request_func_input, + queue=queue, + tokenizer=tokenizer, + pbar=pbar, + ) + + num_actual_requests = sum(len(r) for r in input_requests) + print(f"Num of shared prefixes or conversations: {len(input_requests)}") + print(f"Num of total requests: {num_actual_requests}") + + # flatten the requests for shared prefix + if enable_shared_prefix: + input_requests = [[r] for requests in input_requests for r in requests] + inputs_requests_queue = asyncio.Queue(maxsize=len(input_requests)) + print("Starting initial single prompt test run...") + # NOTE: Just use the first request of the first conversation for warmup + test_input = RequestFuncInput( + model=model_id, + prompts=input_requests[0][:1], + api_url=api_url, + lora_name=lora_name, + extra_request_body=extra_request_body, + ) + test_output = await request_func(request_func_input=test_input, queue=inputs_requests_queue, tokenizer=tokenizer) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}" + ) + else: + print("Initial test run completed. 
Starting main benchmark run...") + + # Check the states + assert inputs_requests_queue.empty() + + # Flush cache + if "sglang" in backend: + requests.post(base_url + "/flush_cache") + + time.sleep(1.0) + + # Start profiler + if profile: + print("Starting profiler...") + profile_output = await async_request_profile(api_url=base_url + "/start_profile") + if profile_output.success: + print("Profiler started") + + for request in input_requests: + request_func_input = RequestFuncInput( + model=model_id, + prompts=request, + api_url=api_url, + lora_name=lora_name, + extra_request_body=extra_request_body, + ) + inputs_requests_queue.put_nowait(request_func_input) + if ( + not args.enable_multiturn + and not args.enable_shared_prefix + and not args.dataset_name == "generated-shared-prefix" + ): + assert len(input_requests) == num_actual_requests + + pbar = None if disable_tqdm else tqdm(total=num_actual_requests) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_requests(inputs_requests_queue, request_rate, num_actual_requests): + tasks.append( + asyncio.create_task( + limited_request_func( + request_func_input=request, + queue=inputs_requests_queue, + tokenizer=tokenizer, + pbar=pbar, + ) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + # Stop profiler + if profile: + print("Stopping profiler...") + profile_output = await async_request_profile(api_url=base_url + "/stop_profile") + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + # Compute metrics and print results + benchmark_duration = time.perf_counter() - benchmark_start_time + metrics, output_lens = calculate_metrics( + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print( + "{:<40} {:<10}".format( + "Max request concurrency:", + max_concurrency if max_concurrency else "not set", + ) + ) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print("{:<40} {:<10}".format("Total generated tokens (retokenized):", metrics.total_output_retokenized)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) + print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", metrics.input_throughput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total token throughput (tok/s):", metrics.total_throughput)) + print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)) + print("{:<40} {:<10.2f}".format("Median E2E Latency (ms):", metrics.median_e2e_latency_ms)) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P90 TTFT (ms):", 
metrics.p90_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P90 TPOT (ms):", metrics.p90_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P90 ITL (ms):", metrics.p90_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + # Arguments + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "max_concurrency": max_concurrency, + "fixed_output_len": args.fixed_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + # Results + "duration": benchmark_duration, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + "std_e2e_latency_ms": metrics.std_e2e_latency_ms, + "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "concurrency": metrics.concurrency, + "completed": metrics.completed, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = ( + f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + ) + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": 
metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p90_ttft_ms": metrics.p90_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p90_tpot_ms": metrics.p90_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p90_itl_ms": metrics.p90_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def run_benchmark(args_: argparse.Namespace): + """Run a complete benchmark using the provided arguments. + + Args: + args_: Namespace containing all benchmark configuration options + + Returns: + dict: Benchmark results from the async benchmark execution + """ + global args + args = args_ + + # Set default value for max_concurrency if not present + if not hasattr(args, "max_concurrency"): + args.max_concurrency = None + + # Set global environments + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + # Set url + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + }.get(args.backend, 30000) + + model_url = f"{args.base_url}/v1/models" if args.base_url else f"http://{args.host}:{args.port}/v1/models" + + if args.backend in ["sglang", "vllm", "lmdeploy"]: + api_url = ( + f"{args.base_url}/v1/chat/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/chat/completions" + ) + base_url = f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url + + # Get model name + if args.model is None: + if args.backend == "truss": + print( + "Please provide a model with `--model` when using truss backend. e.g. --model " + "meta-llama/Llama-3.1-8B-Instruct" + ) + sys.exit(1) + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print("Please specify the correct host and port using `--host` and `--port`.") + sys.exit(1) + + if args.model is None: + print("No model specified or found. 
Please provide a model using `--model`.") + sys.exit(1) + + # Dataset compatibility check + if args.enable_multiturn: + # TODO: Support multiturn for random + if args.dataset_name not in ["sharegpt", "ultrachat", "loogle", "nextqa"]: + print("Multiturn conversation is only supported for sharegpt, ultrachat, loogle, and nextqa datasets.") + sys.exit(1) + + if args.enable_shared_prefix: + if args.dataset_name not in ["loogle", "nextqa"]: + print("Shared prefix is only supported for loogle and nextqa datasets.") + sys.exit(1) + + print(f"{args}\n") + + # Read dataset + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + input_requests = get_dataset(args, tokenizer) + + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + max_concurrency=args.max_concurrency, + disable_tqdm=args.disable_tqdm, + lora_name=args.lora_name, + extra_request_body=extra_request_body, + profile=args.profile, + enable_shared_prefix=args.enable_shared_prefix, + ) + ) + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0.") + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM " + "Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=[ + "sharegpt", + "random", + "generated-shared-prefix", + "ultrachat", + "loogle", + "nextqa", + ], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument("--dataset-path", type=str, default="", help="Path to the dataset.") + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--chat-template", + type=str, + help="The buliltin chat template name or the path of the chat template file. This is only used for " + "OpenAI-compatible API server.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--fixed-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the dataset.", + ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=None, + help="The context length of the model for the ShareGPT dataset. 
Requests longer than the context " + "length will be dropped.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + default=1024, + type=int, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", + ) + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a " + "list of request rates, requiring the parameters to not equal three.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--enable-multiturn", + action="store_true", + help="Enable multiturn chat for online serving benchmarking. " + "This option is effective on the following datasets: " + "sharegpt, ultrachat, loogle, nextqa", + ) + parser.add_argument( + "--enable-shared-prefix", + action="store_true", + help="Enable shared prefix for online serving benchmarking. " + "This option is effective on the following datasets: " + "loogle, nextqa", + ) + + parser.add_argument( + "--disable-shuffle", + action="store_true", + help="Disable shuffling datasets. This is useful to generate stable output " "in benchmarking", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Return logprob.", + ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--lora-name", + type=str, + default=None, + help="The name of LoRA adapter", + ) + + group = parser.add_argument_group("generated-shared-prefix dataset arguments") + group.add_argument( + "--gsp-num-groups", + type=int, + default=64, + help="Number of system prompt groups for generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-prompts-per-group", + type=int, + default=16, + help="Number of prompts per system prompt group for generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-system-prompt-len", + type=int, + default=2048, + help="Target length in tokens for system prompts in generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-question-len", + type=int, + default=128, + help="Target length in tokens for questions in generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-output-len", + type=int, + default=256, + help="Target length in tokens for outputs in generated-shared-prefix dataset", + ) + # videos specific + parser.add_argument( + "--max-frames", + type=int, + default=sys.maxsize, + help="The maximum number of frames to extract from each video. " + "This option is specific to the nextqa dataset (video benchmark). ", + ) + args = parser.parse_args() + + if args.enable_multiturn and args.enable_shared_prefix: + parser.error("--enable-multiturn and --enable-shared-prefix cannot be set at the same time.") + + run_benchmark(args) diff --git a/assets/training/endpoint_evaluation/src/data_processing.py b/assets/training/endpoint_evaluation/src/data_processing.py new file mode 100644 index 0000000000..2e4af03454 --- /dev/null +++ b/assets/training/endpoint_evaluation/src/data_processing.py @@ -0,0 +1,711 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Data processing utilities for benchmark datasets. + +This module provides functions to load, filter, and prepare various datasets +for benchmarking including ShareGPT, UltraChat, Loogle, NextQA, and synthetic +datasets. It handles tokenization, length filtering, and format conversion. +""" + +import json +import os +import pickle +import random +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +from nextqa import NExTQALoader + +# from nextqa.video import , VideoPrompt +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from sglang.bench_serving import ( + download_and_cache_file, +) +from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path +from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart +from sglang.utils import encode_video_base64 + + +SHAREGPT_URL = ( + "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/" + "ShareGPT_V3_unfiltered_cleaned_split.json" +) + +# type of content fields, can be only prompts or with images/videos +MsgContent = Union[str, List[ChatCompletionMessageContentPart]] + +# A list of all the conversations. Each conversation is a list of +# tuples. If multiturn is not enabled, the length of list is 1, +# containing only the first Q&A pair. 
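+# As an illustration (hypothetical values), a two-turn conversation is one
+# inner list of (prompt, input_len, output_len) tuples:
+#   [[("What is Python?", 5, 42), ("Show me an example.", 6, 87)]]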
+# For the shared prefix workload (synthetic, loogle, nextqa), it +# is a list of conversations sharing the same prefix (synthetic, +# doc, video) +SampleOutput = List[List[Tuple[MsgContent, int, int]]] + + +def common_filter_chat( + num_requests: int, + new_dataset: List, + tokenizer: PreTrainedTokenizerBase, + min_prompt_len: Optional[int], + min_output_len: Optional[int], + max_prompt_len: Optional[int], + max_output_len: Optional[int], + fixed_output_len: Optional[int], +) -> SampleOutput: + """Filter chat dataset based on token length constraints. + + Args: + num_requests: Number of requests to generate + new_dataset: Raw dataset conversations + tokenizer: Tokenizer for length calculation + min_prompt_len: Minimum prompt length in tokens (optional) + min_output_len: Minimum output length in tokens (optional) + max_prompt_len: Maximum prompt length in tokens (optional) + max_output_len: Maximum output length in tokens (optional) + fixed_output_len: Fixed output length override (optional) + + Returns: + SampleOutput: Filtered dataset with conversations meeting length criteria + """ + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = [] + k = 0 + input_tokens = 0 + output_tokens = 0 + while k < num_requests: + for i in range(len(new_dataset)): + if k == num_requests: + break + processed = [] + for j in new_dataset[i]: + # Tokenize the prompts and completions. + prompt = j[0] + prompt_token_ids = tokenizer.encode(prompt) + prompt_len = len(prompt_token_ids) + + completion = j[1] + completion_token_ids = tokenizer.encode(completion) + output_len = len(completion_token_ids) if fixed_output_len is None else fixed_output_len + if ( + min_prompt_len is not None + and prompt_len < min_prompt_len + or min_output_len is not None + and output_len < min_output_len + or max_prompt_len is not None + and prompt_len > max_prompt_len + or max_output_len is not None + and output_len > max_output_len + ): + # Prune too short sequences. + continue + input_tokens += prompt_len + output_tokens += output_len + processed.append((prompt, prompt_len, output_len)) + if len(processed) != 0: + filtered_dataset.append(processed) + k += 1 + + print(f"#Input tokens: {input_tokens}") + print(f"#Output tokens: {output_tokens}") + return filtered_dataset + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + disable_shuffle: bool = False, + enable_multiturn: bool = True, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + """Sample requests from ShareGPT dataset. + + Args: + dataset_path: Path to ShareGPT JSON file + num_requests: Number of conversations to sample + tokenizer: Tokenizer for processing text + disable_shuffle: Whether to disable dataset shuffling + enable_multiturn: Whether to include full conversations or just first turn + fixed_output_len: Fixed output length override (optional) + + Returns: + SampleOutput: Processed ShareGPT conversations ready for benchmarking + """ + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. 
+ dataset = [data for data in dataset if len(data["conversations"]) >= 2] + + # Keep one conversation in one list + new_dataset = [] + for data in dataset: + if len(data["conversations"]) % 2 != 0: + continue + if data["conversations"][0]["from"] != "human": + continue + chat = [] + total_len = 2 + if enable_multiturn: + total_len = len(data["conversations"]) + for i in range(0, total_len, 2): + # One user One Assistant + chat.append( + ( + data["conversations"][i]["value"], + data["conversations"][i + 1]["value"], + ) + ) + new_dataset.append(chat) + + if not disable_shuffle: + # Shuffle the dataset. + random.shuffle(new_dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = common_filter_chat( + num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len + ) + return filtered_dataset + + +def sample_ultrachat_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + disable_shuffle: bool = False, + enable_multiturn: bool = True, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + """Sample requests from UltraChat dataset. + + Args: + dataset_path: Path to UltraChat JSONL file + num_requests: Number of conversations to sample + tokenizer: Tokenizer for processing text + disable_shuffle: Whether to disable dataset shuffling + enable_multiturn: Whether to include full conversations or just first turn + fixed_output_len: Fixed output length override (optional) + + Returns: + SampleOutput: Processed UltraChat conversations ready for benchmarking + """ + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset + dataset = [] + with open(dataset_path) as f: + while True: + line = f.readline() + if not line: + break + dataset.append(json.loads(line)) + + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["data"]) >= 2] + + # Keep one conversation in one list + new_dataset = [] + for data in dataset: + if len(data["data"]) % 2 != 0: + continue + chat = [] + total_len = 2 + if enable_multiturn: + total_len = len(data["data"]) + for i in range(0, total_len, 2): + # One user One Assistant + chat.append((data["data"][i], data["data"][i + 1])) + new_dataset.append(chat) + + # Shuffle the dataset. + if not disable_shuffle: + random.shuffle(new_dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = common_filter_chat( + num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len + ) + return filtered_dataset + + +def sample_loogle_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + disable_shuffle: bool = False, + enable_multiturn: bool = True, + enable_shared_prefix: bool = False, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + """Sample requests from Loogle dataset with document QA pairs. 
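+
+    Documents without QA pairs become a single summarization request. With
+    enable_shared_prefix, every QA pair repeats the full document as its
+    prefix; with enable_multiturn, follow-up questions are appended without
+    repeating the document.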
+ + Args: + dataset_path: Path to Loogle JSONL file + num_requests: Number of conversations to sample + tokenizer: Tokenizer for processing text + disable_shuffle: Whether to disable dataset shuffling + enable_multiturn: Whether to include multiple QA pairs per document + enable_shared_prefix: Whether to use shared document prefix optimization + fixed_output_len: Fixed output length override (optional) + + Returns: + SampleOutput: Processed Loogle conversations with document context + """ + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset + dataset = [] + with open(dataset_path) as f: + while True: + line = f.readline() + if not line: + break + dataset.append(json.loads(line)) + + # Keep one conversation in one list + new_dataset = [] + # TODO: Add shared prefix support for loogle + # NOTE: Now we preprocess it only for chat + for data in dataset: + chat = [] + if "qa_pairs" not in data or data["qa_pairs"] == "none" or len(data["qa_pairs"]) == 0: + # If Q is none (for summarization), + # We add a question for summarization + # And keep the summary up to 1024 words + chat.append( + ( + "Input: " + data["input"] + " Question: " + "Please summarize the input", + data["input"][:1024], + ) + ) + new_dataset.append(chat) + else: + qa_pairs = eval(data["qa_pairs"]) + for i, qa in enumerate(qa_pairs): + if i == 0 or enable_shared_prefix: + # Combine input with the first Q + chat.append(("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])) + elif enable_multiturn: + chat.append((qa["Q"], qa["A"])) + + new_dataset.append(chat) + + # Shuffle the dataset. + if not disable_shuffle: + random.shuffle(new_dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = common_filter_chat( + num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len + ) + return filtered_dataset + + +def sample_nextqa_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + max_frames: int, # Specific for video + model_path: str, + disable_shuffle: bool = False, + enable_multiturn: bool = True, # No multiturn support for now + backend: str = "sglang-oai", + chat_template_name: Optional[str] = None, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + """Sample requests from NextQA video dataset for video question answering. + + Creates multimodal requests with video content and text questions. + Encodes videos as base64 and combines with text prompts. 
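+    Each request's prompt length also counts the extracted video frames as a
+    placeholder for the server-side image-token expansion.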
+ + Example of messages: + message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": base64_data}}, + {"type": "text", "text": video.prompt}, + ], + } + + Args: + dataset_path: Directory containing NextQA video files + num_requests: Number of video QA pairs to sample + tokenizer: Tokenizer for processing text prompts + max_frames: Maximum number of video frames to extract + model_path: Path to model for chat template resolution + disable_shuffle: Whether to disable dataset shuffling + enable_multiturn: Whether to enable multiturn (not currently supported) + backend: Backend type for prompt formatting + chat_template_name: Optional chat template name override + fixed_output_len: Fixed output length for responses + + Returns: + SampleOutput: Processed video QA requests with base64 encoded videos + """ + if fixed_output_len is None: + fixed_output_len = 4096 + + # TODO: Check for multiturn + dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames) + new_dataset = [] + for v in dataset: + new_dataset.append(v) + + if not disable_shuffle: + random.shuffle(new_dataset) + + # TODO: prompt len can get from server side + filtered_dataset = [] + k = 0 + while k < num_requests: + for i in range(len(new_dataset)): + if k == num_requests: + break + + video = new_dataset[i] + + # text prompt + prompt = video.prompt + + # NOTE: Chat Template is a must for video benchmark because we have to + # add special image token for later expansion + if backend == "sglang" or backend == "sglang-native": + if "chat_template" in tokenizer.init_kwargs: + chat_template = get_chat_template(tokenizer.get_chat_template()) + elif chat_template_name is not None: + chat_template = get_chat_template(chat_template_name) + else: + chat_template = get_chat_template_by_model_path(model_path) + prompt = chat_template.image_token + prompt + + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + output_len = fixed_output_len # max output len, not real output len + + # video input + base64_data = encode_video_base64(video.path, video.num_frames) + + # NOTE: This will be replaced by the expanded length from the server + prompt_len += video.num_frames + + # add to content + content = [ + {"type": "image_url", "image_url": {"url": base64_data}}, + {"type": "text", "text": prompt}, + ] + + filtered_dataset.append([(content, prompt_len, output_len)]) + k += 1 + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, + disable_shuffle: bool = False, +) -> SampleOutput: + """Generate random benchmark requests with specified token lengths. 
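+
+    Per-request input lengths are drawn uniformly from
+    [max(input_len * range_ratio, 1), input_len] and output lengths from
+    [output_len * range_ratio, output_len]; prompt text is built by repeating
+    or truncating ShareGPT prompt tokens to the sampled input length.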
+ + Args: + input_len: Target input length in tokens + output_len: Target output length in tokens + num_prompts: Number of prompts to generate + range_ratio: Ratio for length variation (0.0 = exact, 1.0 = full range) + tokenizer: Tokenizer for processing text + dataset_path: Path to source dataset for token sampling + disable_shuffle: Whether to disable shuffling + + Returns: + SampleOutput: Generated random requests with specified characteristics + """ + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] + + if not disable_shuffle: + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: SampleOutput = [] + for data in dataset: + i = len(input_requests) + if i == num_prompts: + break + + # Tokenize the prompts and completions. + prompt = data[0] + prompt_token_ids = tokenizer.encode(prompt) + prompt_len = len(prompt_token_ids) + + # Skip empty prompt + if prompt_len == 0: + continue + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))]) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) + input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))]) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +def gen_prompt(tokenizer, token_num): + """Generate a random prompt of specified token length. + + Args: + tokenizer: Tokenizer to use for text generation + token_num: Number of tokens to generate + + Returns: + str: Generated prompt with approximately the specified token count + """ + """Generate a random prompt of specified token length using tokenizer vocabulary.""" + all_available_tokens = list(tokenizer.get_vocab().values()) + selected_tokens = random.choices(all_available_tokens, k=token_num) + return tokenizer.decode(selected_tokens) + + +def get_gen_prefix_cache_path(args, tokenizer): + """Create cache directory path for generated shared prefix dataset. 
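+
+    Illustrative example (default parameter values, hypothetical tokenizer):
+    64 groups, 16 prompts per group, and 2048/128/256 token lengths with a
+    LlamaTokenizerFast tokenizer resolve to
+    ~/.cache/sglang/benchmark/gsp_prefix_64_16_2048_128_256_LlamaTokenizerFast.pkl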
+ + Args: + args: Arguments containing generation parameters + tokenizer: Tokenizer used for generation + + Returns: + Path: Cache file path based on generation parameters + """ + """Create cache directory under ~/.cache/sglang/benchmark""" + cache_dir = Path.home() / ".cache" / "sglang" / "benchmark" + + # Create a unique cache filename based on the generation parameters + cache_key = ( + f"gsp_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_" + f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_" + f"{tokenizer.__class__.__name__}.pkl" + ) + return cache_dir / cache_key + + +def sample_generated_shared_prefix_requests( + num_groups: int, + prompts_per_group: int, + system_prompt_len: int, + question_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + args, + disable_shuffle: bool = False, +) -> SampleOutput: + """Generate synthetic benchmark requests with shared system prompts. + + Creates groups of requests sharing common system prompts to test + prefix caching and shared context optimization. Uses caching to + avoid regenerating identical datasets. + + Args: + num_groups: Number of system prompt groups to create + prompts_per_group: Number of questions per system prompt + system_prompt_len: Target length in tokens for system prompts + question_len: Target length in tokens for individual questions + output_len: Target length in tokens for expected outputs + tokenizer: Tokenizer for generating and measuring text + args: Full arguments object for cache key generation + disable_shuffle: Whether to disable shuffling of groups + + Returns: + SampleOutput: Generated requests grouped by shared system prompts + """ + cache_path = get_gen_prefix_cache_path(args, tokenizer) + + # Try to load from cache first + if cache_path.exists(): + print(f"\nLoading cached generated input data from {cache_path}") + with open(cache_path, "rb") as f: + return pickle.load(f) + + print("\nGenerating new input data...") + + # Generate system prompts for each group + system_prompts = [] + for _ in range(num_groups): + system_prompt = gen_prompt(tokenizer, system_prompt_len) + system_prompts.append(system_prompt) + + # Generate questions + questions = [] + for _ in range(num_groups * prompts_per_group): + question = gen_prompt(tokenizer, question_len) + questions.append(question) + + # Combine system prompts with questions + input_requests = [] + total_input_tokens = 0 + total_output_tokens = 0 + + for group_idx in tqdm(range(num_groups), desc="Generating system prompt"): + system_prompt = system_prompts[group_idx] + input_requests.append([]) + for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions", leave=False): + question = questions[group_idx * prompts_per_group + prompt_idx] + full_prompt = f"{system_prompt}\n\n{question}" + prompt_len = len(tokenizer.encode(full_prompt)) + input_requests[-1].append((full_prompt, prompt_len, output_len)) + total_input_tokens += prompt_len + total_output_tokens += output_len + + if not disable_shuffle: + # Shuffle questions + random.shuffle(input_requests) + + # Print statistics + print("\nGenerated shared prefix dataset statistics:") + print(f"Number of groups: {num_groups}") + print(f"Prompts per group: {prompts_per_group}") + print(f"Total prompts: {len(input_requests) * prompts_per_group}") + print(f"Total input tokens: {total_input_tokens}") + print(f"Total output tokens: {total_output_tokens}") + avg_system_prompt_len = sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts) + 
avg_question_len = sum(len(tokenizer.encode(q)) for q in questions) / len(questions) + print(f"Average system prompt length: {avg_system_prompt_len:.1f} tokens") + print(f"Average question length: {avg_question_len:.1f} tokens\n") + + # Save to cache + cache_path.parent.mkdir(parents=True, exist_ok=True) + print(f"Caching generated input data to {cache_path}") + with open(cache_path, "wb") as f: + pickle.dump(input_requests, f) + + return input_requests + + +def get_dataset(args, tokenizer): + """Get processed dataset based on the specified dataset name and configuration. + + Args: + args: Arguments containing dataset configuration + tokenizer: Tokenizer for text processing + + Returns: + SampleOutput: Processed dataset ready for benchmarking + + Raises: + ValueError: If dataset_name is not recognized + """ + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "ultrachat": + input_requests = sample_ultrachat_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "loogle": + input_requests = sample_loogle_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + enable_shared_prefix=args.enable_shared_prefix, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "nextqa": + input_requests = sample_nextqa_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + max_frames=args.max_frames, + model_path=args.model, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + backend=args.backend, + chat_template_name=args.chat_template, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + elif args.dataset_name == "generated-shared-prefix": + input_requests = sample_generated_shared_prefix_requests( + num_groups=args.gsp_num_groups, + prompts_per_group=args.gsp_prompts_per_group, + system_prompt_len=args.gsp_system_prompt_len, + question_len=args.gsp_question_len, + output_len=args.gsp_output_len, + args=args, + tokenizer=tokenizer, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + return input_requests diff --git a/assets/training/endpoint_evaluation/src/helper.py b/assets/training/endpoint_evaluation/src/helper.py new file mode 100644 index 0000000000..717c375e3e --- /dev/null +++ b/assets/training/endpoint_evaluation/src/helper.py @@ -0,0 +1,212 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Helper utilities for Azure ML integration and API key management. 
+ +This module provides functions for: +- Retrieving API keys from Azure ML workspace connections +- Creating HTTP sessions with retry policies +- Logging metrics to Azure ML runs +- Managing authentication and workspace connections +""" + +from typing import Optional, Tuple +import json +import requests +from requests.adapters import HTTPAdapter +from urllib3 import Retry + +from azureml.core import Run, Workspace +from azureml.core.run import _OfflineRun + + +RETRIABLE_STATUS_CODES = {413, 429, 500, 502, 503, 504, None} +# Define loggable metrics - only mean metrics for TPOT, TTFT, E2E latency, and ITL +LOGGABLE_METRIC_NAMES = { + # Core throughput metrics (lowercase from JSON files) + "request_throughput", + "input_throughput", + "output_throughput", + # Mean latency metrics only (no median) + "mean_e2e_latency_ms", + "mean_ttft_ms", + "mean_tpot_ms", + "mean_itl_ms", + "total_input_tokens", + "total_output_tokens", + # Uppercase versions for AML logging + "REQUEST_THROUGHPUT", + "INPUT_THROUGHPUT", + "OUTPUT_THROUGHPUT", + "MEAN_E2E_LATENCY_MS", + "MEAN_TTFT_MS", + "MEAN_TPOT_MS", + "MEAN_ITL_MS", + "TOTAL_INPUT_TOKENS", + "TOTAL_OUTPUT_TOKENS", +} + + +def _get_retry_policy(num_retry: int = 3) -> Retry: + """Create HTTP retry policy with exponential backoff. + + Args: + num_retry: Maximum number of retry attempts + + Returns: + Retry: Configured retry policy for HTTP requests + """ + """ + Request retry policy with increasing backoff. + + :return: Returns the msrest or requests REST client retry policy. + :rtype: urllib3.Retry + """ + backoff_factor = 0.4 + retry_policy = Retry( + total=num_retry, + read=num_retry, + connect=num_retry, + backoff_factor=backoff_factor, + status_forcelist=RETRIABLE_STATUS_CODES, + # By default this is True. We set it to false to get the full error trace, including url and + # status code of the last retry. Otherwise, the error message is too many 500 error responses', + # which is not useful. + raise_on_status=False, + ) + return retry_policy + + +def _create_session_with_retry(retry: int = 3) -> requests.Session: + """Create HTTP session with retry capability. + + Args: + retry: Number of retry attempts for failed requests + + Returns: + requests.Session: Configured session with retry adapters + """ + """ + Create requests.session with retry. + + :type retry: int + rtype: Response + """ + retry_policy = _get_retry_policy(num_retry=retry) + + session = requests.Session() + session.mount("https://", HTTPAdapter(max_retries=retry_policy)) + session.mount("http://", HTTPAdapter(max_retries=retry_policy)) + return session + + +def _send_post_request(url: str, headers: dict, payload: dict): + """Send HTTP POST request with retry handling. + + Args: + url: Target URL for the request + headers: HTTP headers dictionary + payload: JSON payload dictionary + + Returns: + requests.Response: HTTP response object + + Raises: + requests.exceptions.HTTPError: If request fails after retries + """ + """Send a POST request.""" + try: + with _create_session_with_retry() as session: + response = session.post(url, data=json.dumps(payload), headers=headers) + # Raise an exception if the response contains an HTTP error status code + response.raise_for_status() + except requests.exceptions.HTTPError: + raise + return response + + +def get_api_key_from_connection(connections_name: str) -> Tuple[str, Optional[str]]: + """ + Get api_key from connections_name. + + :param connections_name: Name of the connection. + :return: api_key, api_version. 
+ """ + run = Run.get_context() + if isinstance(run, _OfflineRun): + curr_ws = Workspace.from_config() + else: + curr_ws = run.experiment.workspace + + if hasattr(curr_ws._auth, "get_token"): + bearer_token = curr_ws._auth.get_token("https://management.azure.com/.default").token + else: + bearer_token = curr_ws._auth.token + + endpoint = curr_ws.service_context._get_endpoint("api") + url_list = [ + endpoint, + "rp/workspaces/subscriptions", + curr_ws.subscription_id, + "resourcegroups", + curr_ws.resource_group, + "providers", + "Microsoft.MachineLearningServices", + "workspaces", + curr_ws.name, + "connections", + connections_name, + "listsecrets?api-version=2023-02-01-preview", + ] + + resp = _send_post_request( + "/".join(url_list), + {"Authorization": f"Bearer {bearer_token}", "content-type": "application/json"}, + {}, + ) + + credentials = resp.json()["properties"]["credentials"] + metadata = resp.json()["properties"].get("metadata", {}) + if "key" in credentials: + return credentials["key"], metadata.get("ApiVersion") + else: + if "secretAccessKey" not in credentials and "keys" in credentials: + credentials = credentials["keys"] + return credentials["secretAccessKey"], None + + +def _get_azureml_run(): + """Get active Azure ML run context if available. + + Returns: + Run or None: Azure ML run object if available, None otherwise + """ + """Get AzureML Run context if available.""" + try: + azureml_run = Run.get_context() + if azureml_run and "OfflineRun" not in azureml_run.id: + return azureml_run + except Exception as e: + print(f"Warning: Failed to get AzureML run context: {e}") + return None + + +def log_metrics(metrics: dict): + """Log metrics to Azure ML run if available. + + Args: + metrics: Dictionary of metric names and values to log + """ + """Log metrics to AzureML Run if available.""" + azureml_run = _get_azureml_run() + if azureml_run: + for key, value in metrics.items(): + # Check if the key (or key without prefix) is a loggable metric + # Support both prefixed (base_/target_) and non-prefixed keys + key_to_check = key + if key.startswith("base_") or key.startswith("target_"): + key_to_check = key.split("_", 1)[1] # Remove prefix + + if key_to_check in LOGGABLE_METRIC_NAMES: + azureml_run.log(key, value) + azureml_run.flush() diff --git a/assets/training/endpoint_evaluation/src/main.py b/assets/training/endpoint_evaluation/src/main.py new file mode 100644 index 0000000000..7e2b362c6f --- /dev/null +++ b/assets/training/endpoint_evaluation/src/main.py @@ -0,0 +1,516 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Main entry point for endpoint evaluation benchmarking. + +This module provides command-line interface and orchestration for running +performance benchmarks against both base and target endpoints using SGLang. +It supports various datasets and backend configurations, and logs metrics +to Azure ML. +""" + +import argparse +import os +import sys +import json +from copy import deepcopy + +from bench_serving import run_benchmark +from helper import get_api_key_from_connection, log_metrics + + +def parse_args(): + """Parse command-line arguments for endpoint evaluation. + + Returns: + argparse.Namespace: Parsed command-line arguments containing all + configuration options for the benchmark including endpoint URLs, + dataset settings, and benchmark parameters. 
+ """ + parser = argparse.ArgumentParser(description="SGLang Benchmarking") + # parser.add_argument( + # "--metrics_path", + # type=str, + # required=True, + # help="Output JSON file to store the benchmarking metrics.", + # ) + parser.add_argument( + "--connection-name", + type=str, + required=True, + help="The name of the workspace connection used to fetch API key for base endpoint.", + ) + parser.add_argument( + "--target-url", + type=str, + required=True, + help="Server or API base url for target endpoint.", + ) + parser.add_argument( + "--target-connection-name", + type=str, + required=True, + help="The name of the workspace connection used to fetch API key for target endpoint.", + ) + parser.add_argument( + "--trials", + type=int, + default=10, + help="Number of trials to run the benchmark, result will be averaged over all trials.", + ) + parser.add_argument( + "--base-backend", + type=str, + default="sglang", + help="Backend for base endpoint, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--target-backend", + type=str, + default="sglang", + help="Backend for target endpoint, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0.") + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for" + " different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=[ + "sharegpt", + "random", + "generated-shared-prefix", + "ultrachat", + "loogle", + "nextqa", + ], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument("--dataset-path", type=str, default="", help="Path to the dataset.") + parser.add_argument( + "--base-model", + type=str, + help="Name or path of the model for base endpoint. If not set, the default model will request " + "/v1/models for conf.", + default=None, + ) + parser.add_argument( + "--target-model", + type=str, + help="Name or path of the model for target endpoint. If not set, the default model will request " + "/v1/models for conf.", + default=None, + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--chat-template", + type=str, + help="The buliltin chat template name or the path of the chat template file. This is only used " + "for OpenAI-compatible API server.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--fixed-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the dataset.", + ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=None, + help="The context length of the model for the ShareGPT dataset. 
Requests longer than the " + "context length will be dropped.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + default=1024, + type=int, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", + ) + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports " + "a list of request rates, requiring the parameters to not equal three.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--enable-multiturn", + action="store_true", + help="Enable multiturn chat for online serving benchmarking. " + "This option is effective on the following datasets: " + "sharegpt, ultrachat, loogle, nextqa", + ) + parser.add_argument( + "--enable-shared-prefix", + action="store_true", + help="Enable shared prefix for online serving benchmarking. " + "This option is effective on the following datasets: " + "loogle, nextqa", + ) + parser.add_argument( + "--disable-shuffle", + type=lambda x: x.lower() in ("true", "1", "yes"), + default=False, + help="Disable shuffling datasets. Accepts true/false.", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Return logprob.", + ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--lora-name", + type=str, + default=None, + help="The name of LoRA adapter", + ) + + group = parser.add_argument_group("generated-shared-prefix dataset arguments") + group.add_argument( + "--gsp-num-groups", + type=int, + default=64, + help="Number of system prompt groups for generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-prompts-per-group", + type=int, + default=16, + help="Number of prompts per system prompt group for generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-system-prompt-len", + type=int, + default=2048, + help="Target length in tokens for system prompts in generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-question-len", + type=int, + default=128, + help="Target length in tokens for questions in generated-shared-prefix dataset", + ) + group.add_argument( + "--gsp-output-len", + type=int, + default=256, + help="Target length in tokens for outputs in generated-shared-prefix dataset", + ) + # videos specific + parser.add_argument( + "--max-frames", + type=int, + default=sys.maxsize, + help="The maximum number of frames to extract from each video. " + "This option is specific to the nextqa dataset (video benchmark). ", + ) + args = parser.parse_args() + return args + + +def _generate_avg_metrics(metrics_file: str, prefix: str = "", log_to_aml: bool = True): + """Generate average metrics from multiple trial results. + + Args: + metrics_file (str): Path to JSONL file containing metrics from each trial. + prefix (str, optional): Prefix to add to metric names. Defaults to "". + log_to_aml (bool, optional): Whether to log metrics to Azure ML. Defaults to True. + + Returns: + dict: Dictionary containing averaged metrics across all trials. 
+ """ + metrics_list = [] + with open(metrics_file, "r") as f: + for line in f: + line = line.strip() + if line: + metrics_list.append(json.loads(line)) + + # Compute average metrics + avg_metrics = {} + count = len(metrics_list) + + for key in metrics_list[0].keys(): + if isinstance(metrics_list[0][key], (int, float)): + avg_metrics[key] = sum(result[key] for result in metrics_list) / count + else: + avg_metrics[key] = metrics_list[0][key] + + output_file = os.path.join(os.path.dirname(metrics_file), f"{prefix}metrics_avg.json") + with open(output_file, "w") as f: + json.dump(avg_metrics, f, indent=4) + + # Log metrics with prefix only if requested (not in child process) + if log_to_aml: + # Convert keys to uppercase for consistent bold display + prefixed_metrics = {} + for k, v in avg_metrics.items(): + metric_key = k.upper() + prefixed_metrics[f"{prefix}{metric_key}"] = v + log_metrics(prefixed_metrics) + + return avg_metrics + + +def run_endpoint_benchmark( + args, + endpoint_name: str, + url: str, + connection_name: str, + model: str, + backend: str, + output_dir: str, +): + """Run benchmark for a specific endpoint.""" + print(f"\n{'='*60}") + print(f"Starting benchmark for {endpoint_name} endpoint") + print(f"{'='*60}\n") + + # Create a copy of args for this endpoint + endpoint_args = deepcopy(args) + endpoint_args.base_url = url + endpoint_args.model = model + endpoint_args.backend = backend + + # Remove last slash if exists + if endpoint_args.base_url and endpoint_args.base_url.endswith("/"): + endpoint_args.base_url = endpoint_args.base_url[:-1] + + # Get API key for this endpoint + api_key, _ = get_api_key_from_connection(connection_name) + os.environ["OPENAI_API_KEY"] = api_key + + # Set output file for this endpoint + endpoint_args.output_file = os.path.join(output_dir, f"{endpoint_name}_metrics_each_trial.jsonl") + + trials = endpoint_args.trials + del endpoint_args.trials + + # Run trials + for trial in range(trials): + print(f"[{endpoint_name}] Starting trial {trial + 1} of {trials}...") + try: + run_benchmark(endpoint_args) + except Exception as e: + print(f"[{endpoint_name}] Trial {trial + 1} failed with error: {e}") + import traceback + + traceback.print_exc() + raise + + # Generate average metrics with prefix (don't log in child process) + _generate_avg_metrics(endpoint_args.output_file, prefix=f"{endpoint_name}_", log_to_aml=False) + print(f"\n[{endpoint_name}] Benchmark completed!\n") + + +def main(): + """Execute the endpoint evaluation benchmark. + + Orchestrates the complete benchmark process: + 1. Parses command-line arguments + 2. Runs benchmarks sequentially for base and target endpoints + 3. Logs final metrics to Azure ML + 4. 
Handles errors and cleanup + """ + args = parse_args() + print("> Parsed arguments:", args) + + # Store endpoint configurations + base_url = args.base_url + base_connection = args.connection_name + base_model = args.base_model + base_backend = args.base_backend + target_url = args.target_url + target_connection = args.target_connection_name + target_model = args.target_model + target_backend = args.target_backend + output_dir = args.output_file + + # Remove endpoint-specific args from the base args + del args.connection_name + del args.base_model + del args.base_backend + del args.target_url + del args.target_connection_name + del args.target_model + del args.target_backend + del args.output_file + del args.base_url + + # Run benchmarks sequentially + print("\n" + "=" * 60) + print("Starting sequential benchmarks for base and target endpoints") + print("=" * 60 + "\n") + + # Run base endpoint benchmark first + print("Running base endpoint benchmark...") + try: + run_endpoint_benchmark( + args, + "base", + base_url, + base_connection, + base_model, + base_backend, + output_dir, + ) + print("Base endpoint benchmark completed.\n") + except Exception as e: + print(f"[ERROR] Base endpoint benchmark failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + # Run target endpoint benchmark second + print("Running target endpoint benchmark...") + try: + run_endpoint_benchmark( + args, + "target", + target_url, + target_connection, + target_model, + target_backend, + output_dir, + ) + print("Target endpoint benchmark completed.\n") + except Exception as e: + print(f"[ERROR] Target endpoint benchmark failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + print("\n" + "=" * 60) + print("All benchmarks completed successfully!") + print("=" * 60 + "\n") + + # Log metrics to AML from main process + print("Logging metrics to AzureML...") + try: + base_metrics_file = os.path.join(output_dir, "base_metrics_avg.json") + target_metrics_file = os.path.join(output_dir, "target_metrics_avg.json") + + if os.path.exists(base_metrics_file): + with open(base_metrics_file, "r") as f: + base_metrics = json.load(f) + # Create prefixed metrics with uppercase keys for bold display + prefixed_base = {} + for k, v in base_metrics.items(): + metric_key = k.upper() # Convert to uppercase + prefixed_base[f"base_{metric_key}"] = v + log_metrics(prefixed_base) + print(f" ✓ Logged {len(base_metrics)} base metrics") + + if os.path.exists(target_metrics_file): + with open(target_metrics_file, "r") as f: + target_metrics = json.load(f) + # Create prefixed metrics with uppercase keys for bold display + prefixed_target = {} + for k, v in target_metrics.items(): + metric_key = k.upper() # Convert to uppercase + prefixed_target[f"target_{metric_key}"] = v + log_metrics(prefixed_target) + print(f" ✓ Logged {len(target_metrics)} target metrics") + + print("Metrics logging completed!") + except Exception as e: + print(f"Warning: Failed to log metrics to AML: {e}") + + +if __name__ == "__main__": + main() diff --git a/assets/training/endpoint_evaluation/src/nextqa.py b/assets/training/endpoint_evaluation/src/nextqa.py new file mode 100644 index 0000000000..4f19b512b0 --- /dev/null +++ b/assets/training/endpoint_evaluation/src/nextqa.py @@ -0,0 +1,256 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""NextQA dataset loading utilities for video question answering benchmarks. 
+
+This module provides classes and functions for loading and processing
+video files and NextQA dataset entries for benchmarking video-language models.
+Supports both file-based video loading and HuggingFace dataset integration.
+"""
+
+import os
+import sys
+from typing import List
+
+import av
+from datasets import load_dataset
+
+
+def find_video_files(video_dir) -> List[str]:
+    """Recursively find all video files in a directory.
+
+    Args:
+        video_dir: Directory path to search for video files
+
+    Returns:
+        List[str]: List of video file paths found
+    """
+    if os.path.isfile(video_dir):
+        return [video_dir]
+
+    video_files = []
+    # os.walk already descends into subdirectories, so checking plain files is enough.
+    for root, dirs, files in os.walk(video_dir):
+        for file in files:
+            if file.endswith((".mp4", ".avi", ".mov")):
+                video_files.append(os.path.join(root, file))
+    return video_files
+
+
+def video_frames(video_path, max_frames) -> int:
+    """Get the number of frames to extract from a video.
+
+    Args:
+        video_path: Path to the video file
+        max_frames: Maximum number of frames to extract
+
+    Returns:
+        int: Actual number of frames to extract (min of total frames and max_frames)
+    """
+    container = av.open(video_path)
+    total_frames = container.streams.video[0].frames
+    return min(total_frames, max_frames)
+
+
+class Video:
+    """Represents a video file with frame count information.
+
+    Attributes:
+        path (str): Path to the video file
+        num_frames (int): Number of frames in the video
+    """
+
+    def __init__(self, video_path, num_frames):
+        """Initialize a Video object.
+
+        Args:
+            video_path (str): Path to the video file
+            num_frames (int): Number of frames in the video
+        """
+        self.path = video_path
+        self.num_frames = num_frames
+
+    def __str__(self):
+        """Return string representation of the Video object."""
+        return f"Video({self.path}, {self.num_frames})"
+
+    def __iter__(self):
+        """Return iterator over video path and frame count."""
+        return iter((self.path, self.num_frames))
+
+
+class VideoPrompt(Video):
+    """Represents a video with an associated text prompt for QA tasks.
+
+    Inherits from Video and adds prompt functionality for question-answering.
+
+    Attributes:
+        path (str): Path to the video file
+        num_frames (int): Number of frames in the video
+        prompt (str): Text prompt/question associated with the video
+    """
+
+    def __init__(self, video_path, num_frames, prompt):
+        """Initialize a VideoPrompt object.
+
+        Args:
+            video_path (str): Path to the video file
+            num_frames (int): Number of frames in the video
+            prompt (str): Text prompt/question for the video
+        """
+        super().__init__(video_path, num_frames)
+        self.prompt = prompt
+
+    def __str__(self):
+        """Return string representation of the VideoPrompt object."""
+        return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
+
+    def __iter__(self):
+        """Return iterator over video path, frame count, and prompt."""
+        return iter((self.path, self.num_frames, self.prompt))
+
+
+class VideoLoader:
+    """Base class for video loading implementations.
+
+    Provides a common interface for different video loading strategies.
+    """
+
+    pass
+
+
+class VideoFileLoader(VideoLoader):
+    """Load videos from a filesystem directory.
+
+    Scans a directory for video files and provides an iteration interface
+    for processing videos in batches.
+    """
+
+    def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
+        """Initialize video file loader.
+
+        Args:
+            video_dir (str): Directory containing video files
+            batch_size (int): Number of videos per batch
+            max_frames (int): Maximum frames to extract per video
+        """
+        super().__init__()
+        self.video_dir = video_dir
+        self.video_files = find_video_files(video_dir)
+        self.batch_size = batch_size
+        self.max_frames = max_frames
+        print(f"batch_size: {batch_size}, max_frames: {max_frames}")
+
+    def __iter__(self):
+        """Iterate over video files in the directory.
+
+        Yields:
+            Video or List[Video]: Individual videos or batches of videos
+        """
+        if self.batch_size == 1:
+            for video_file in self.video_files:
+                yield Video(video_file, video_frames(video_file, self.max_frames))
+        else:
+            batch = []
+            for video_file in self.video_files:
+                video = Video(video_file, video_frames(video_file, self.max_frames))
+                batch.append(video)
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+            # NOTE: a final batch smaller than batch_size is not yielded.
+
+
+class NExTQALoader(VideoLoader):
+    """Load videos and prompts from the NextQA dataset.
+
+    Integrates with HuggingFace datasets to load the NextQA dataset and
+    combines it with local video files for benchmarking.
+    """
+
+    def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"):
+        """Initialize NextQA dataset loader.
+
+        Args:
+            video_dir (str): Directory containing NextQA video files
+            batch_size (int): Number of video prompts per batch
+            max_frames (int): Maximum frames to extract per video
+            dset (str): Dataset split ('train', 'test', 'validation')
+            task (str): Task type ('MC' for multiple choice, 'OE' for open-ended)
+        """
+        super().__init__()
+        self.task = task
+        print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
+        self.ds = load_dataset("lmms-lab/NExTQA", task)
+        self.ds = self.ds[dset]
+
+        # self.n = ds.num_rows
+        self.video_dir = video_dir
+        self.video_files = find_video_files(video_dir)
+        self.video_to_path = dict()
+        for video_file in self.video_files:
+            video_id = video_file.split("/")[-1].split(".")[0]
+            self.video_to_path[video_id] = video_file
+
+        self.batch_size = batch_size
+        self.max_frames = max_frames
+
+    def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
+        """Create VideoPrompt object from dataset entry.
+
+        Args:
+            entry: Dataset entry containing video and question information
+            max_frames (int): Maximum number of frames to extract
+
+        Returns:
+            VideoPrompt: Video object with associated question prompt
+        """
+        # Get video
+        video_id = entry["video"]
+        video_path = self.video_to_path[video_id]
+        assert os.path.exists(video_path), f"Video not found: {video_path}"
+        num_frames = min(entry["frame_count"], max_frames)
+        # video = Video(video_path, num_frames)
+        prompt = entry["question"] + "?"
+        if self.task == "MC":  # add choices
+            prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
+        return VideoPrompt(video_path, num_frames, prompt)
+
+    def __iter__(self):
+        """Iterate over NextQA dataset entries.
+
+        Yields:
+            VideoPrompt or List[VideoPrompt]: Individual video prompts or batches
+        """
+        if self.batch_size == 1:
+            for entry in self.ds:
+                yield self.get_video_prompt(entry, self.max_frames)
+        else:
+            batch = []
+            for entry in self.ds:
+                video = self.get_video_prompt(entry, self.max_frames)
+                batch.append(video)
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+            # NOTE: a final batch smaller than batch_size is not yielded.
+
+
+# main
+if __name__ == "__main__":
+    video_dir = "./videos"
+    # video_loader = VideoFileLoader(video_dir, batch_size=16)
+    # for batch in video_loader:
+    #     print(f"Number of videos in batch: {len(batch)}")
+    #     for video_file, num_frames in batch:
+    #         print(f"Video: {video_file} number of frames: {num_frames}")
+
+    video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
+    for batch in video_loader:
+        print(f"Number of videos in batch: {len(batch)}")
+        for video_file, num_frames, prompt in batch:
+            print(f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}")
+        # break
+        # for video_file, prompt in batch:
+        #     print(f"Video: {video_file} prompt: {prompt}")
+        # break
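For reference, the benchmarking component introduced above can be submitted from the Azure ML Python SDK v2 roughly as follows. This is only a usage sketch: the subscription, resource group, workspace, compute cluster, connection names, and endpoint URLs are placeholders, and the inputs simply mirror the component spec.

from azure.ai.ml import MLClient, dsl, load_component
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

# Load the benchmarking component from its spec file.
benchmark = load_component(source="assets/training/endpoint_evaluation/component/spec.yaml")


@dsl.pipeline(description="Compare base vs. target endpoint performance", default_compute="<cpu-cluster>")
def endpoint_eval_pipeline():
    step = benchmark(
        base_scoring_url="https://<base-endpoint>.inference.ml.azure.com/score",
        base_connection_name="<base-endpoint-connection>",
        target_scoring_url="https://<target-endpoint>.inference.ml.azure.com/score",
        target_connection_name="<target-endpoint-connection>",
        base_model="nvidia/Llama-3.1-8B-Instruct-FP8",
        target_model="nvidia/Llama-3.1-8B-Instruct-FP8",
        trials=5,
        num_prompts=2500,
        request_rate=10,
    )
    # The job writes the averaged base/target metrics JSON files into this folder.
    return {"metrics": step.outputs.metrics}


job = ml_client.jobs.create_or_update(endpoint_eval_pipeline())
print(job.studio_url)

The two connection names must refer to workspace connections whose keys authenticate the respective endpoints, since the job resolves them via the connection listsecrets call at runtime.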