Skip to content

Commit c5e54e1

Browse files
author
jetstream authors
committed
Merge pull request #260 from AI-Hypercomputer:maxtext_logp
PiperOrigin-RevId: 751216668
2 parents 5c398f6 + 7cec37f commit c5e54e1

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ This project follows
2525

2626
## Contribution process
2727

28+
### Style Formatting
29+
30+
Please run `make format` before submitting your PR.
31+
2832
### Code Reviews
2933

3034
All submissions, including submissions by project members, require review. We

jetstream/engine/engine_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class SlotData:
6060
tokens: Union[jax.Array, np.ndarray]
6161
valid: Union[jax.Array, np.ndarray]
6262
lengths: Union[jax.Array, np.ndarray]
63+
log_prob: Union[jax.Array, np.ndarray] = None
6364

6465

6566
# pylint: disable=g-doc-args
@@ -91,6 +92,11 @@ class ResultTokens(abc.ABC):
9192
samples_per_slot: int = struct.field(
9293
pytree_node=False,
9394
)
95+
# log probabilities of the tokens. Shape: [batch, tokens]
96+
log_prob: Union[jax.Array, np.ndarray] = struct.field(
97+
pytree_node=False,
98+
default=None,
99+
)
94100

95101
def copy_to_host_async(self: "ResultTokens") -> None:
96102
"""Copy to host asynchronously."""
@@ -107,6 +113,7 @@ def convert_to_numpy(self: "ResultTokens") -> "ResultTokens":
107113
self.valid_idx,
108114
self.length_idx,
109115
self.samples_per_slot,
116+
self.log_prob,
110117
)
111118

112119
def get_result_at_slot(self, slot: int) -> SlotData:
@@ -148,6 +155,7 @@ def get_result_at_slots(self, slots: tuple[int]) -> SlotData:
148155
valid=self.data[slots, self.valid_idx[0] : self.valid_idx[1]],
149156
# Only get a 1D representation here
150157
lengths=self.data[slots, self.length_idx[0] : self.length_idx[1]][:, 0],
158+
log_prob=self.log_prob[slots, :] if self.log_prob is not None else None,
151159
)
152160

153161

0 commit comments

Comments
 (0)