Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bitsandbytes/triton/quantize_columnwise_and_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _quantize_columnwise_and_transpose(
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127.0 * (x / max_val))
output = tl.extra.cuda.libdevice.rint(127.0 * (x / max_val))

new_start = pid * M
new_offsets = new_start + p2_arange
Expand Down
4 changes: 2 additions & 2 deletions bitsandbytes/triton/quantize_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _quantize_global(
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
output = tl.extra.cuda.libdevice.rint(127.0 * (x * absmax_inv))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue here: should be tl.extra.cuda.libdevice.llrint (not rint) to match the original semantics.

tl.store(output_ptr + offsets, output, mask=mask)

def quantize_global(x: torch.Tensor):
Expand Down Expand Up @@ -95,7 +95,7 @@ def _quantize_global_transpose(
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
mask = (rm < M)[:, None] & (rn < N)[None, :]

output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
output = tl.extra.cuda.libdevice.rint(127.0 * (a * absmax_inv))

tl.store(B, output, mask=mask)

Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/triton/quantize_rowwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _quantize_rowwise(

abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127.0 * (x / max_val))
output = tl.extra.cuda.libdevice.rint(127.0 * (x / max_val))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tl.extra.cuda.libdevice.rint is not the correct 1:1 replacement for tl.libdevice.llrint. llrint rounds to nearest integer and returns an integer type; rint rounds to nearest integer but returns a float. Since tl.extra.cuda.libdevice.llrint exists in modern Triton, this should use tl.extra.cuda.libdevice.llrint instead to preserve the original semantics.

tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)

Expand Down