examples/models/llama/model.py (19 additions, 16 deletions)
```diff
@@ -245,24 +245,27 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
             for param in self.model_.parameters():
                 if isinstance(param, TorchAOBaseTensor):
                     param.requires_grad = False
+            if missing:
+                missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
+                if missing_weights:
+                    raise ValueError(
+                        f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
+                    )
+            if unexpected:
+                if self.verbose:
+                    print(f"Unexpected keys: {unexpected}")
         else:
-            print("Checkpoint not provided, defaulting weights to zeros.")
+            print("Checkpoint not provided, using default initialization.")
+            # Because we loaded onto meta device, it is annoying to now load onto cpu
+            # with the standard random initialization.
             self.model_.to_empty(device="cpu")
-            # Need to provide concrete values for meta-initialized tensors for quantization.
-            # otherwise it is just filled with nan's.
-            for p in self.model_.parameters():
-                p.data.fill_(0)
-            for b in self.model_.buffers():
-                b.data.fill_(0)
-        if missing:
-            missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
-            if missing_weights:
-                raise ValueError(
-                    f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
-                )
-        if unexpected:
-            if self.verbose:
-                print(f"Unexpected keys: {unexpected}")
+
+            def weight_reset(m):
+                reset_parameters = getattr(m, "reset_parameters", None)
+                if callable(reset_parameters):
+                    m.reset_parameters()
+
+            self.model_.apply(weight_reset)
 
         # Prune the input layer if input_prune_map is provided
         if input_prune_map is not None:
```
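For context, the new else-branch uses a standard PyTorch pattern: modules built under `torch.device("meta")` carry shapes and dtypes but no data, `Module.to_empty()` then materializes real but uninitialized storage, and `reset_parameters()` re-runs each submodule's default initialization. A minimal standalone sketch of that pattern, using a toy `nn.Sequential` for illustration rather than the actual Llama Transformer:

```python
import torch
import torch.nn as nn

# Build a module on the meta device: tensors carry shape/dtype but no data.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))

# Materialize real (but uninitialized) CPU storage for every param/buffer.
model.to_empty(device="cpu")

# Re-run each submodule's default initialization where one is defined,
# instead of the old p.data.fill_(0) zero-fill.
def weight_reset(m):
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

model.apply(weight_reset)

print(model[0].weight.std())  # nonzero: standard Kaiming-uniform init, not zeros
```

Unlike the old zero-fill, this leaves, for example, `nn.Linear` weights with their usual Kaiming-uniform values, so a model exported without a checkpoint behaves like a freshly initialized one.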
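The relocated `missing`/`unexpected` checks only make sense on the path that actually called `load_state_dict`, since both lists come from its return value, which is why the diff moves them up next to the load. A small sketch of that behavior, assuming a deliberately mismatched toy state dict (the key names here are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))

# A deliberately mismatched state dict: layer 1's real keys are absent
# and an extra, unknown key is present.
state = {
    "0.weight": torch.zeros(16, 8),
    "0.bias": torch.zeros(16),
    "1.weight_renamed": torch.zeros(4, 16),
}

# strict=False reports the mismatches instead of raising immediately.
missing, unexpected = model.load_state_dict(state, strict=False)
print(missing)     # ['1.weight', '1.bias']
print(unexpected)  # ['1.weight_renamed']

# The PR's guard only fails hard on missing ".weight" entries:
missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
print(missing_weights)  # ['1.weight'] -> model.py raises ValueError on this
```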