Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
498 changes: 498 additions & 0 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/logitsPenaltyKernels.cu

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"

#include <cuda_runtime.h>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

//! Host-callable entry points for the speculative-decoding logits-penalty routines.
//! Every function takes the CUDA stream to use as its last parameter and returns void.
//! Only declarations appear in this header.
//! NOTE(review): the definitions, and the explicit instantiations the templated entry points need,
//! are presumably in kernels/speculativeDecoding/logitsPenaltyKernels.cu - confirm.
namespace kernels
{

//! Token-penalty entry point, declared for any logits element type T and token-id type TokenT.
//! `logits` is the only mutable buffer; `tokenIds` and `penaltyValues` are const inputs.
//! NOTE(review): presumably `tokenIds`/`penaltyValues` hold `width` entries per row and index into
//! a [numRows, vocabSize] logits matrix - confirm the layout, how out-of-range ids are treated and
//! whether the penalty is additive or multiplicative against the implementation.
template <typename T, typename TokenT>
void invokeApplySpeculativeTokenPenalties(T* logits, TokenT const* tokenIds, float const* penaltyValues, int32_t numRows,
int32_t width, int32_t vocabSize, cudaStream_t stream);

//! History-based frequency penalty. Unlike the templated variants in this header, `logits` is
//! fixed to float here. `logits` is the only mutable buffer; all other pointers are const inputs.
//! NOTE(review): presumably `rowSlots` maps each logits row to a slot of the history buffers and
//! `historyCapacity` is the per-slot stride of `historyTokens` - confirm.
void invokeApplySpeculativeHistoryFrequencyPenalty(float* logits, int32_t const* historyTokens,
int32_t const* historyLens, int32_t const* rowSlots, float const* frequencyPenalties, int32_t numRows,
int32_t historyCapacity, int32_t vocabSize, cudaStream_t stream);

//! Append step for the token history: `historyTokens` and `historyLens` are mutable, while
//! `seqSlots`, `acceptedTokens` and `acceptedLens` are const inputs.
//! NOTE(review): presumably `acceptedStride` is the row pitch of `acceptedTokens`; behaviour when
//! a slot would exceed `historyCapacity` is not visible from the declaration - confirm.
void invokeAppendSpeculativeAcceptedTokens(int32_t* historyTokens, int32_t* historyLens, int32_t const* seqSlots,
int32_t const* acceptedTokens, int32_t const* acceptedLens, int32_t numRows, int32_t acceptedStride,
int32_t historyCapacity, cudaStream_t stream);

//! Count-based frequency penalty, declared for any logits element type T.
//! `logits` is the only mutable buffer; `tokenCounts`, `rowSlots` and `frequencyPenalties` are
//! const inputs. No capacity parameter is taken.
//! NOTE(review): presumably `tokenCounts` is dense with `vocabSize` entries per slot - confirm.
template <typename T>
void invokeApplySpeculativeCountFrequencyPenalty(T* logits, int32_t const* tokenCounts,
int32_t const* rowSlots, float const* frequencyPenalties, int32_t numRows, int32_t vocabSize, cudaStream_t stream);

//! Update step for the count table: `tokenCounts` is the only mutable buffer; `seqSlots`,
//! `acceptedTokens` and `acceptedLens` are const inputs.
//! NOTE(review): presumably increments the per-slot count of each accepted token - confirm.
void invokeAppendSpeculativeAcceptedTokenCounts(int32_t* tokenCounts, int32_t const* seqSlots,
int32_t const* acceptedTokens, int32_t const* acceptedLens, int32_t numRows, int32_t acceptedStride,
int32_t vocabSize, cudaStream_t stream);

//! Sparse-count frequency penalty, declared for any logits element type T.
//! `logits` is the only mutable buffer; the (tokenIds, tokenCounts, countLens) triple, `rowSlots`
//! and `frequencyPenalties` are const inputs.
//! NOTE(review): presumably each slot stores up to `countCapacity` (id, count) pairs with the
//! valid length in `countLens` - confirm.
template <typename T>
void invokeApplySpeculativeSparseCountFrequencyPenalty(T* logits, int32_t const* tokenIds,
int32_t const* tokenCounts, int32_t const* countLens, int32_t const* rowSlots, float const* frequencyPenalties,
int32_t numRows, int32_t countCapacity, int32_t vocabSize, cudaStream_t stream);

//! Update step for the sparse count table: `tokenIds`, `tokenCounts` and `countLens` are all
//! mutable; `seqSlots`, `acceptedTokens` and `acceptedLens` are const inputs.
//! NOTE(review): what happens when a slot already holds `countCapacity` distinct ids is not
//! visible from the declaration - confirm.
void invokeAppendSpeculativeSparseTokenCounts(int32_t* tokenIds, int32_t* tokenCounts, int32_t* countLens,
int32_t const* seqSlots, int32_t const* acceptedTokens, int32_t const* acceptedLens, int32_t numRows,
int32_t acceptedStride, int32_t countCapacity, int32_t vocabSize, cudaStream_t stream);

//! Initialisation step for the sparse count table: `tokenIds`, `tokenCounts` and `countLens` are
//! mutable; the prompt* buffers and `seqSlots` are const inputs.
//! NOTE(review): presumably seeds each listed slot from its prompt (id, count) pairs, with
//! `promptCapacity` as the source pitch and `countCapacity` as the destination pitch - confirm.
void invokeInitSpeculativeSparseTokenCounts(int32_t* tokenIds, int32_t* tokenCounts, int32_t* countLens,
int32_t const* promptTokenIds, int32_t const* promptTokenCounts, int32_t const* promptLens,
int32_t const* seqSlots, int32_t numRows, int32_t promptCapacity, int32_t countCapacity, int32_t vocabSize,
cudaStream_t stream);

} // namespace kernels

TRTLLM_NAMESPACE_END
6 changes: 5 additions & 1 deletion cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
if (!weight_warp)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}

for (int ki = 0; ki < K_LOOPS_DMA; ki++)
Expand Down Expand Up @@ -422,6 +421,11 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,

__syncthreads();

if (threadIdx.x == 0) // one thread per block suffices according to official code examples
{
cudaTriggerProgrammaticLaunchCompletion();
}

if (warp_id == 0)
{

Expand Down
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ add_library(
weightOnlyQuantGemm.cpp
weightOnlyQuantOp.cpp
specDecOp.cpp
speculativeLogitsPenaltyOp.cpp
loraOp.cpp
finegrained_mixed_dtype_gemm_thop.cpp
tinygemm2.cpp
Expand Down
473 changes: 473 additions & 0 deletions cpp/tensorrt_llm/thop/speculativeLogitsPenaltyOp.cpp

Large diffs are not rendered by default.

114 changes: 108 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ def __init__(

self.is_warmup = False
self.previous_request_ids = []
self.previous_device_sampled_request_ids: set[int] = set()
self.debug_spec_device_draft_guard = (
os.environ.get("TRTLLM_SPEC_DRAFT_GUARD_DEBUG", "0") == "1")
self.has_previous_device_draft = False
self.previous_accepted_tokens_cuda = torch.empty((self.batch_size, ),
dtype=torch.int,
Expand Down Expand Up @@ -1410,7 +1413,8 @@ def _set_up_spec_metadata(
max_num_tokens=self.max_num_tokens,
spec_resource_manager=spec_resource_manager,
is_draft_model=self.is_draft_model,
max_seq_len=self.max_seq_len)
max_seq_len=self.max_seq_len,
max_num_sequence_slots=self.get_max_num_sequences())

if self.spec_metadata is not None:
return self.spec_metadata
Expand All @@ -1421,7 +1425,8 @@ def _set_up_spec_metadata(
max_num_tokens=self.max_num_tokens,
spec_resource_manager=spec_resource_manager,
is_draft_model=self.is_draft_model,
max_seq_len=self.max_seq_len)
max_seq_len=self.max_seq_len,
max_num_sequence_slots=self.get_max_num_sequences())
return self.spec_metadata

def __del__(self) -> None:
Expand Down Expand Up @@ -2374,13 +2379,27 @@ def _prepare_tp_inputs(
extend_dummy_requests = []
generation_requests = []
first_draft_requests = []
previous_device_sampled_request_ids = self.previous_device_sampled_request_ids
# Collect generation request IDs during categorization to avoid
# a separate iteration over scheduled_requests.generation_requests later.
all_gen_request_ids = []
for request in scheduled_requests.generation_requests:
all_gen_request_ids.append(request.py_request_id)
if get_draft_token_length(
request) > 0 or next_draft_tokens_device is not None:
has_previous_device_draft = (
next_draft_tokens_device is not None
and request.py_batch_idx is not None
and request.py_request_id in previous_device_sampled_request_ids)
if (self.debug_spec_device_draft_guard
and next_draft_tokens_device is not None
and request.py_batch_idx is not None
and request.py_request_id
not in previous_device_sampled_request_ids):
logger.info(
"Ignoring stale speculative device draft for request_id=%s "
"prev_seq_slot=%s current_seq_slot=%s",
request.py_request_id, request.py_batch_idx,
request.py_seq_slot)
if get_draft_token_length(request) > 0 or has_previous_device_draft:
if request.is_dummy:
extend_dummy_requests.append(request)
else:
Expand Down Expand Up @@ -2416,7 +2435,11 @@ def _prepare_tp_inputs(
# (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
# (2) a dummy request; or
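# (3) the first step in the generation server of disaggregated serving; or
# (4) the request id is not in previous_device_sampled_request_ids, so any device draft is stale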
if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
has_previous_device_draft = (
next_draft_tokens_device is not None
and request.py_batch_idx is not None
and request.py_request_id in previous_device_sampled_request_ids)
if not has_previous_device_draft or request.is_dummy:
# get token ids, including input token ids and draft token ids. For these dummy requests,
# no need to copy the token ids.
if not (request.is_attention_dp_dummy
Expand Down Expand Up @@ -3069,6 +3092,9 @@ def previous_seq_slots_device():

if spec_metadata is not None:
total_draft_lens = sum(draft_lens)
spec_sampling_requests = (
scheduled_requests.context_requests + extend_requests +
first_draft_requests + generation_requests)
spec_metadata.draft_tokens = self.draft_tokens_cuda[:
total_draft_lens]
spec_metadata.request_ids = request_ids
Expand All @@ -3077,13 +3103,23 @@ def previous_seq_slots_device():
scheduled_requests.generation_requests)
spec_metadata.num_tokens = total_num_tokens
spec_metadata.seq_lens = sequence_lengths
spec_metadata.sampling_request_ids = [
int(request.py_request_id) for request in spec_sampling_requests
]
spec_metadata.sampling_seq_slots = [
int(request.py_seq_slot)
if request.py_seq_slot is not None else -1
for request in spec_sampling_requests
]
spec_metadata.num_accepted_draft_tokens = self.num_accepted_draft_tokens_cuda[:len(
num_accepted_draft_tokens)]
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.request_accepted_path = request_accepted_path
# No-op for non 1-model
spec_metadata.populate_sampling_params_for_one_model(
scheduled_requests.all_requests())
spec_sampling_requests)
spec_metadata.prepare_device_penalty_counts(
int(self.model.config.vocab_size))
spec_metadata.prepare()
inputs['spec_metadata'] = spec_metadata

Expand Down Expand Up @@ -3116,6 +3152,13 @@ def previous_seq_slots_device():

if not self.is_warmup:
self.previous_request_ids = all_gen_request_ids
self.previous_device_sampled_request_ids = {
request.py_request_id
for request in (
scheduled_requests.context_requests_last_chunk +
scheduled_requests.generation_requests)
if not request.is_dummy
}
self.has_previous_device_draft = next_draft_tokens_device is not None

return inputs, self.gather_ids_cuda[:len(
Expand Down Expand Up @@ -3871,9 +3914,68 @@ def capture_postprocess_fn(inputs: Dict[str, Any]):
self.forward_pass_callable()

self._execute_logit_post_processors(scheduled_requests, outputs)
self._attach_spec_penalty_outputs(outputs,
inputs.get("spec_metadata"))

return outputs

@staticmethod
def _attach_spec_penalty_outputs(outputs: Dict[str, Any],
spec_metadata: Any) -> None:
# Expose the penalty bookkeeping carried by ``spec_metadata`` through ``outputs``.
#
# Args:
#   outputs: forward-pass result; it is only mutated when it is a dict.
#   spec_metadata: speculative-decoding metadata object, or None.
#
# Returns None; the only effect is the "penalty_*" keys added to ``outputs``.
# Each attribute is probed with getattr and a default before it is used, so
# metadata objects that lack the penalty attributes are tolerated. Values are
# stored by plain assignment, i.e. by reference; nothing is copied here.
if not isinstance(outputs, dict) or spec_metadata is None:
return

# Request ids / sequence slots of the sampled rows are forwarded whenever the
# metadata provides them, independent of which penalty mode is enabled.
sampling_request_ids = getattr(spec_metadata, "sampling_request_ids",
None)
if sampling_request_ids is not None:
outputs["penalty_sampling_request_ids"] = sampling_request_ids
sampling_seq_slots = getattr(spec_metadata, "sampling_seq_slots", None)
if sampling_seq_slots is not None:
outputs["penalty_sampling_seq_slots"] = sampling_seq_slots

# Count-based state: gated on the ``use_device_penalty_counts`` flag, which is
# treated as False when the attribute is absent.
if getattr(spec_metadata, "use_device_penalty_counts", False):
count_seq_slots = getattr(spec_metadata,
"device_penalty_count_seq_slots", None)
if count_seq_slots is not None:
outputs["penalty_count_seq_slots"] = count_seq_slots
# ``count_mode`` defaults to "" so that neither branch below matches when
# the metadata does not declare a mode.
count_mode = getattr(spec_metadata, "device_penalty_count_mode",
"")
# "dense" mode exports a single token-count buffer.
if (count_mode == "dense"
and getattr(spec_metadata, "device_penalty_token_counts",
None) is not None):
outputs[
"penalty_token_counts"] = spec_metadata.device_penalty_token_counts
# "sparse" mode exports ids, counts and lengths together, and only when all
# three buffers are present.
elif (count_mode == "sparse"
and getattr(spec_metadata, "device_penalty_sparse_token_ids",
None) is not None
and getattr(spec_metadata,
"device_penalty_sparse_token_counts",
None) is not None
and getattr(spec_metadata, "device_penalty_sparse_count_lens",
None) is not None):
outputs[
"penalty_sparse_token_ids"] = spec_metadata.device_penalty_sparse_token_ids
outputs[
"penalty_sparse_token_counts"] = spec_metadata.device_penalty_sparse_token_counts
outputs[
"penalty_sparse_count_lens"] = spec_metadata.device_penalty_sparse_count_lens
# Vocab size falls back to 0 when the metadata has no such attribute.
# NOTE(review): confirm whether this is meant to be exported for both count
# modes or only alongside the sparse buffers.
outputs["penalty_count_vocab_size"] = getattr(
spec_metadata, "device_penalty_count_vocab_size", 0)

# History-based state: requires the flag plus both history buffers.
if (getattr(spec_metadata, "use_device_penalty_history", False)
and getattr(spec_metadata, "device_penalty_history_tokens",
None) is not None
and getattr(spec_metadata, "device_penalty_history_lens",
None) is not None):
# NOTE(review): the slots are read from ``device_penalty_seq_slots`` here,
# whereas the count path reads ``device_penalty_count_seq_slots`` - confirm
# the attribute name is the intended one.
history_seq_slots = getattr(spec_metadata,
"device_penalty_seq_slots", None)
if history_seq_slots is not None:
outputs["penalty_history_seq_slots"] = history_seq_slots
outputs[
"penalty_history_tokens"] = spec_metadata.device_penalty_history_tokens
outputs[
"penalty_history_lens"] = spec_metadata.device_penalty_history_lens

def model_forward(self, **kwargs):
attrs = get_model_extra_attrs()
assert attrs is not None, "Model extra attrs is not set"
Expand Down
Loading
Loading