2626#include " types.h"
2727
2828#include < executorch/extension/llm/runner/llm_runner_helper.h>
29+ #include < executorch/extension/llm/runner/util.h>
2930#include < executorch/extension/llm/runner/wav_loader.h>
3031#include < executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h>
3132#include < executorch/extension/module/module.h>
3233#include < executorch/extension/tensor/tensor_ptr_maker.h>
3334#include < executorch/runtime/core/evalue.h>
35+ #include < executorch/runtime/core/exec_aten/util/scalar_type_util.h>
3436#include < executorch/runtime/platform/log.h>
3537#ifdef ET_BUILD_METAL
3638#include < executorch/backends/apple/metal/runtime/stats.h>
@@ -108,6 +110,39 @@ TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
108110 " '. Expected: token, word, segment, all." );
109111}
110112
113+ // Helper to get expected scalar type for a method input
114+ ::executorch::runtime::Result<::executorch::aten::ScalarType>
115+ get_input_scalar_type (
116+ Module& model,
117+ const char * method_name,
118+ size_t input_index) {
119+ auto method_meta_result = model.method_meta (method_name);
120+ if (!method_meta_result.ok ()) {
121+ ET_LOG (Error, " Failed to get method metadata for %s" , method_name);
122+ return method_meta_result.error ();
123+ }
124+ auto method_meta = method_meta_result.get ();
125+ if (method_meta.num_inputs () <= input_index) {
126+ ET_LOG (
127+ Error,
128+ " Method %s has %zu inputs, but requested index %zu" ,
129+ method_name,
130+ method_meta.num_inputs (),
131+ input_index);
132+ return ::executorch::runtime::Error::InvalidArgument;
133+ }
134+ auto input_meta_result = method_meta.input_tensor_meta (input_index);
135+ if (input_meta_result.error () != ::executorch::runtime::Error::Ok) {
136+ ET_LOG (
137+ Error,
138+ " Failed to get input tensor metadata for %s[%zu]" ,
139+ method_name,
140+ input_index);
141+ return input_meta_result.error ();
142+ }
143+ return input_meta_result.get ().scalar_type ();
144+ }
145+
111146std::vector<Token> greedy_decode_executorch (
112147 Module& model,
113148 const ::executorch::aten::Tensor& f_proj,
@@ -118,27 +153,49 @@ std::vector<Token> greedy_decode_executorch(
118153 int64_t max_symbols_per_step = 10 ) {
119154 std::vector<Token> hypothesis;
120155
121- // Shape: [1, time_steps, joint_hidden]
122- auto f_proj_sizes = f_proj.sizes ();
123- int64_t time_steps = f_proj_sizes[1 ];
124- int64_t proj_dim = f_proj_sizes[2 ];
156+ // Shape: [1, T, joint_hidden]
157+ size_t proj_dim = static_cast <size_t >(f_proj.sizes ()[2 ]);
125158
126- // Initialize LSTM state
127- std::vector<float > h_data (num_rnn_layers * 1 * pred_hidden, 0 .0f );
128- std::vector<float > c_data (num_rnn_layers * 1 * pred_hidden, 0 .0f );
159+ // Get expected dtype for decoder_step h and c inputs (indices 1 and 2)
160+ auto h_dtype_result = get_input_scalar_type (model, " decoder_step" , 1 );
161+ if (!h_dtype_result.ok ()) {
162+ return hypothesis;
163+ }
164+ auto c_dtype_result = get_input_scalar_type (model, " decoder_step" , 2 );
165+ if (!c_dtype_result.ok ()) {
166+ return hypothesis;
167+ }
168+ auto h_dtype = h_dtype_result.get ();
169+ auto c_dtype = c_dtype_result.get ();
170+
171+ ET_LOG (
172+ Info,
173+ " Decoder h dtype: %s, c dtype: %s" ,
174+ ::executorch::runtime::toString (h_dtype),
175+ ::executorch::runtime::toString(c_dtype));
176+
177+ // Calculate buffer sizes based on dtype
178+ size_t h_elem_size = ::executorch::runtime::elementSize (h_dtype);
179+ size_t c_elem_size = ::executorch::runtime::elementSize (c_dtype);
180+ size_t num_elements =
181+ static_cast <size_t >(num_rnn_layers) * static_cast <size_t >(pred_hidden);
182+
183+ // Initialize LSTM state with zeros (using byte buffers for dtype flexibility)
184+ std::vector<uint8_t > h_data (num_elements * h_elem_size, 0 );
185+ std::vector<uint8_t > c_data (num_elements * c_elem_size, 0 );
129186
130187 auto h = from_blob (
131188 h_data.data (),
132189 {static_cast <::executorch::aten::SizesType>(num_rnn_layers),
133190 1 ,
134191 static_cast <::executorch::aten::SizesType>(pred_hidden)},
135- ::executorch::aten::ScalarType::Float );
192+ h_dtype );
136193 auto c = from_blob (
137194 c_data.data (),
138195 {static_cast <::executorch::aten::SizesType>(num_rnn_layers),
139196 1 ,
140197 static_cast <::executorch::aten::SizesType>(pred_hidden)},
141- ::executorch::aten::ScalarType::Float );
198+ c_dtype );
142199
143200 // Prime the decoder with SOS (= blank_id) to match NeMo TDT label-looping:
144201 // - SOS is defined as blank:
@@ -159,41 +216,61 @@ std::vector<Token> greedy_decode_executorch(
159216 auto g_proj_init = init_outputs[0 ].toTensor ();
160217 auto new_h_init = init_outputs[1 ].toTensor ();
161218 auto new_c_init = init_outputs[2 ].toTensor ();
162- std::memcpy (
163- h_data.data (),
164- new_h_init.const_data_ptr <float >(),
165- h_data.size () * sizeof (float ));
166- std::memcpy (
167- c_data.data (),
168- new_c_init.const_data_ptr <float >(),
169- c_data.size () * sizeof (float ));
219+ std::memcpy (h_data.data (), new_h_init.const_data_ptr (), h_data.size ());
220+ std::memcpy (c_data.data (), new_c_init.const_data_ptr (), c_data.size ());
170221
171- // Copy g_proj data for reuse
172- std::vector<float > g_proj_data (
173- g_proj_init.const_data_ptr <float >(),
174- g_proj_init.const_data_ptr <float >() + g_proj_init.numel ());
222+ // Get expected dtype for joint inputs (f and g at indices 0 and 1)
223+ auto f_dtype_result = get_input_scalar_type (model, " joint" , 0 );
224+ if (!f_dtype_result.ok ()) {
225+ return hypothesis;
226+ }
227+ auto g_dtype_result = get_input_scalar_type (model, " joint" , 1 );
228+ if (!g_dtype_result.ok ()) {
229+ return hypothesis;
230+ }
231+ auto f_dtype = f_dtype_result.get ();
232+ auto g_dtype = g_dtype_result.get ();
233+
234+ ET_LOG (
235+ Info,
236+ " Joint f dtype: %s, g dtype: %s" ,
237+ ::executorch::runtime::toString (f_dtype),
238+ ::executorch::runtime::toString(g_dtype));
239+
240+ size_t f_elem_size = ::executorch::runtime::elementSize (f_dtype);
241+ size_t g_elem_size = ::executorch::runtime::elementSize (g_dtype);
242+
243+ // Copy g_proj data for reuse (using byte buffer for dtype flexibility)
244+ size_t g_proj_num_bytes =
245+ static_cast <size_t >(g_proj_init.numel ()) * g_elem_size;
246+ std::vector<uint8_t > g_proj_data (g_proj_num_bytes);
247+ std::memcpy (
248+ g_proj_data.data (), g_proj_init.const_data_ptr (), g_proj_num_bytes);
175249
176250 int64_t t = 0 ;
177251 int64_t symbols_on_frame = 0 ;
252+ const uint8_t * f_proj_ptr =
253+ static_cast <const uint8_t *>(f_proj.const_data_ptr ());
254+ size_t f_t_num_bytes = proj_dim * f_elem_size;
178255
179256 // Scan over encoder output
180257 while (t < encoder_len) {
181258 // Get encoder frame at time t: f_proj[:, t:t+1, :]
182- const float * f_proj_ptr = f_proj.const_data_ptr <float >();
259+ std::vector<uint8_t > f_t_data (f_t_num_bytes);
260+ std::memcpy (
261+ f_t_data.data (),
262+ f_proj_ptr + static_cast <size_t >(t) * f_t_num_bytes,
263+ f_t_num_bytes);
183264
184- std::vector<float > f_t_data (1 * 1 * proj_dim);
185- for (int64_t d = 0 ; d < proj_dim; d++) {
186- f_t_data[d] = f_proj_ptr[t * proj_dim + d];
187- }
188265 auto f_t = from_blob (
189266 f_t_data.data (),
190267 {1 , 1 , static_cast <::executorch::aten::SizesType>(proj_dim)},
191- ::executorch::aten::ScalarType::Float );
268+ f_dtype );
192269
193270 auto g_proj = from_blob (
194271 g_proj_data.data (),
195272 {1 , 1 , static_cast <::executorch::aten::SizesType>(proj_dim)},
196- ::executorch::aten::ScalarType::Float );
273+ g_dtype );
197274
198275 auto joint_result = model.execute (
199276 " joint" , std::vector<::executorch::runtime::EValue>{f_t , g_proj});
@@ -230,18 +307,10 @@ std::vector<Token> greedy_decode_executorch(
230307 auto new_c = outputs[2 ].toTensor ();
231308
232309 // Update h, c, and g_proj
310+ std::memcpy (h_data.data (), new_h.const_data_ptr (), h_data.size ());
311+ std::memcpy (c_data.data (), new_c.const_data_ptr (), c_data.size ());
233312 std::memcpy (
234- h_data.data (),
235- new_h.const_data_ptr <float >(),
236- h_data.size () * sizeof (float ));
237- std::memcpy (
238- c_data.data (),
239- new_c.const_data_ptr <float >(),
240- c_data.size () * sizeof (float ));
241- std::memcpy (
242- g_proj_data.data (),
243- new_g_proj.const_data_ptr <float >(),
244- g_proj_data.size () * sizeof (float ));
313+ g_proj_data.data (), new_g_proj.const_data_ptr (), g_proj_data.size ());
245314
246315 t += dur;
247316
0 commit comments