diff --git a/src/maxtext/input_pipeline/grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py
index cf2dd649bd..ca5a9a4a8c 100644
--- a/src/maxtext/input_pipeline/grain_data_processing.py
+++ b/src/maxtext/input_pipeline/grain_data_processing.py
@@ -238,7 +238,7 @@ def pretrain_preprocessing_pipeline(
   # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
   # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
   # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
-  batch_size = batch_size // config.expansion_factor_real_data
+  batch_size = int(batch_size // config.expansion_factor_real_data)
 
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
diff --git a/src/maxtext/input_pipeline/multihost_dataloading.py b/src/maxtext/input_pipeline/multihost_dataloading.py
index 369ebcff0a..4b31628120 100644
--- a/src/maxtext/input_pipeline/multihost_dataloading.py
+++ b/src/maxtext/input_pipeline/multihost_dataloading.py
@@ -125,7 +125,7 @@ def _get_next_batch_sharded(self) -> jax.Array:
     # expansion_loading_factor_for_grain times to get the
     # right batch_size for the host that is loading real data.
     local_data_list = [local_data]
-    for _ in range(1, self.expansion_loading_factor_for_grain):
+    for _ in range(1, int(self.expansion_loading_factor_for_grain)):
       next_batch = next(self.local_iterator)
       local_data_list.append(next_batch)
     local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list)
diff --git a/src/maxtext/input_pipeline/synthetic_data_processing.py b/src/maxtext/input_pipeline/synthetic_data_processing.py
index 80332ace9f..4660df7a41 100644
--- a/src/maxtext/input_pipeline/synthetic_data_processing.py
+++ b/src/maxtext/input_pipeline/synthetic_data_processing.py
@@ -19,8 +19,6 @@
 
 import numpy as np
 
-import tensorflow as tf
-
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
@@ -100,26 +98,19 @@ def reset(self):
   @staticmethod
   def get_place_holder_synthetic_data(config: pyconfig.HyperParameters):
     """fill negative value in synthetic data"""
-    output = {}
-    output["inputs"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["inputs_position"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["inputs_segmentation"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["targets"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["targets_position"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["targets_segmentation"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    dataset = tf.data.Dataset.zip((output))  # pytype: disable=wrong-arg-types
-    dataset = dataset.repeat()
-    dataset = dataset.batch(config.global_batch_size_to_load // jax.process_count())
-    return dataset
+    batch_size = config.global_batch_size_to_load // jax.process_count()
+    neg_ones = np.full((batch_size, config.max_target_length), -1, dtype=np.int32)
+    batch = {
+        "inputs": neg_ones,
+        "inputs_position": neg_ones,
+        "inputs_segmentation": neg_ones,
+        "targets": neg_ones,
+        "targets_position": neg_ones,
+        "targets_segmentation": neg_ones,
+    }
+
+    def infinite_iterator():
+      while True:
+        yield batch
+
+    return infinite_iterator()