@@ -60,6 +60,7 @@ class SlotData:
6060 tokens : Union [jax .Array , np .ndarray ]
6161 valid : Union [jax .Array , np .ndarray ]
6262 lengths : Union [jax .Array , np .ndarray ]
63+ log_prob : Union [jax .Array , np .ndarray ] = None
6364
6465
6566# pylint: disable=g-doc-args
@@ -91,6 +92,11 @@ class ResultTokens(abc.ABC):
9192 samples_per_slot : int = struct .field (
9293 pytree_node = False ,
9394 )
95+ # Log probabilities of the tokens, or None when not provided. Shape: [batch, tokens]
96+ log_prob : Union [jax .Array , np .ndarray ] = struct .field (
97+ pytree_node = False ,
98+ default = None ,
99+ )
94100
95101 def copy_to_host_async (self : "ResultTokens" ) -> None :
96102 """Copy to host asynchronously."""
@@ -107,6 +113,7 @@ def convert_to_numpy(self: "ResultTokens") -> "ResultTokens":
107113 self .valid_idx ,
108114 self .length_idx ,
109115 self .samples_per_slot ,
116+ self .log_prob ,
110117 )
111118
112119 def get_result_at_slot (self , slot : int ) -> SlotData :
@@ -148,6 +155,7 @@ def get_result_at_slots(self, slots: tuple[int]) -> SlotData:
148155 valid = self .data [slots , self .valid_idx [0 ] : self .valid_idx [1 ]],
149156 # Only get a 1D representation here
150157 lengths = self .data [slots , self .length_idx [0 ] : self .length_idx [1 ]][:, 0 ],
158+ log_prob = self .log_prob [slots , :] if self .log_prob is not None else None ,
151159 )
152160
153161
0 commit comments