diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index bc019444a7d..1ec85936f7a 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -245,24 +245,27 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
             for param in self.model_.parameters():
                 if isinstance(param, TorchAOBaseTensor):
                     param.requires_grad = False
+            if missing:
+                missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
+                if missing_weights:
+                    raise ValueError(
+                        f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
+                    )
+            if unexpected:
+                if self.verbose:
+                    print(f"Unexpected keys: {unexpected}")
         else:
-            print("Checkpoint not provided, defaulting weights to zeros.")
+            print("Checkpoint not provided, using default initialization.")
+            # Because we loaded onto meta device, it is annoying to now load onto cpu
+            # with the standard random initialization.
             self.model_.to_empty(device="cpu")
-            # Need to provide concrete values for meta-initialized tensors for quantization.
-            # otherwise it is just filled with nan's.
-            for p in self.model_.parameters():
-                p.data.fill_(0)
-            for b in self.model_.buffers():
-                b.data.fill_(0)
-        if missing:
-            missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
-            if missing_weights:
-                raise ValueError(
-                    f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
-                )
-        if unexpected:
-            if self.verbose:
-                print(f"Unexpected keys: {unexpected}")
+
+            def weight_reset(m):
+                reset_parameters = getattr(m, "reset_parameters", None)
+                if callable(reset_parameters):
+                    m.reset_parameters()
+
+            self.model_.apply(weight_reset)
 
         # Prune the input layer if input_prune_map is provided
         if input_prune_map is not None:
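
For reference, a minimal standalone sketch of the initialization pattern this diff switches to: materialize a meta-device model with to_empty() and then re-run each submodule's own reset_parameters() via Module.apply(), instead of zero-filling. The toy nn.Sequential model below is only illustrative, not the actual Llama Transformer constructed in model.py.

import torch
import torch.nn as nn

# Build a small model on the meta device: parameters get shapes but no storage.
# (Toy stand-in for the real model; illustrative only.)
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Allocate real (uninitialized) CPU storage for every parameter and buffer.
model.to_empty(device="cpu")

# Re-run each submodule's own initializer where one is defined.
# nn.Linear has reset_parameters(); containers like nn.Sequential and
# parameterless modules like nn.ReLU do not, so the callable check skips them.
def weight_reset(m):
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

model.apply(weight_reset)

# Parameters now carry each module's standard initialization (e.g. Kaiming-uniform
# for nn.Linear) rather than zeros or uninitialized memory.
print(model[0].weight.abs().sum() > 0)  # tensor(True)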