Commit fd5694d

Consolidate lm-eval scripts: merge AnyModel auto-detection into lm_eval_hf.py (#1206)
## Summary

- Merge `examples/puzzletron/evaluation/lm_eval_anymodel.py` into the existing `examples/llm_eval/lm_eval_hf.py` so there is a single evaluation entry point for both standard HF and AnyModel/Puzzletron checkpoints.
- AnyModel support is auto-detected at load time via `resolve_descriptor_from_pretrained`; the `puzzletron` extra is optional.

## Notes

AnyModel auto-detection uses `resolve_descriptor_from_pretrained`, which currently relies on a hardcoded `_MODEL_TYPE_TO_DESCRIPTOR` dict that must be kept in sync manually with descriptor registrations. This should be addressed in the future.

## Summary by CodeRabbit

- **New Features**
  - Automated detection and correct loading of Puzzletron heterogeneous pruned checkpoints via the main evaluation entrypoint.
- **Documentation**
  - Added a "Heterogeneous Pruned Checkpoints (Puzzletron)" subsection with install notes, example evaluation commands, and smoke-test guidance.
- **Chores**
  - Removed the separate Puzzletron evaluation script and consolidated evaluation into the primary lm-eval workflow.

Signed-off-by: jrausch <jrausch@nvidia.com>
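The Notes above flag that `_MODEL_TYPE_TO_DESCRIPTOR` is a hardcoded dict that must be kept in sync by hand. One common way to remove that manual step is a decorator-based registry that descriptor classes populate at definition time. The sketch below is purely illustrative — the names (`register_descriptor`, `LlamaDescriptor`, `resolve_descriptor`) are hypothetical and not the real Puzzletron API:

```python
# Hypothetical sketch: a self-populating registry instead of a hand-maintained
# dict, so registration and lookup can never drift apart.
_MODEL_TYPE_TO_DESCRIPTOR: dict[str, type] = {}


def register_descriptor(model_type: str):
    """Class decorator mapping an HF config model_type to its descriptor class."""

    def wrap(cls: type) -> type:
        _MODEL_TYPE_TO_DESCRIPTOR[model_type] = cls
        return cls

    return wrap


@register_descriptor("llama")
class LlamaDescriptor:
    """Stand-in descriptor; registration happens as a side effect of definition."""


def resolve_descriptor(model_type: str) -> type:
    try:
        return _MODEL_TYPE_TO_DESCRIPTOR[model_type]
    except KeyError:
        raise ValueError(f"No descriptor registered for model_type={model_type!r}")
```

With this pattern, adding a new descriptor class automatically makes it resolvable; there is no second table to update.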
1 parent 25266b8 commit fd5694d

5 files changed: 65 additions & 119 deletions

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
```diff
@@ -96,7 +96,6 @@ repos:
           examples/llm_eval/modeling.py|
           examples/llm_qat/main.py|
           examples/llm_sparsity/weight_sparsity/finetune.py|
-          examples/puzzletron/evaluation/lm_eval_anymodel.py|
           examples/specdec_bench/specdec_bench/models/specbench_medusa.py|
           examples/speculative_decoding/main.py|
           examples/speculative_decoding/medusa_utils.py|
```

examples/llm_eval/README.md

Lines changed: 16 additions & 0 deletions
````diff
@@ -40,6 +40,22 @@ accelerate launch --multi_gpu --num_processes <num_copies_of_your_model> \
     --batch_size 4
 ```
 
+### Heterogeneous Pruned Checkpoints (Puzzletron)
+
+Heterogeneous pruned checkpoints produced by Puzzletron are automatically detected and loaded with the appropriate model patcher. No additional flags are needed beyond specifying the checkpoint path:
+
+```sh
+python lm_eval_hf.py --model hf \
+    --model_args pretrained=path/to/anymodel/checkpoint,dtype=bfloat16,parallelize=True \
+    --tasks mmlu \
+    --num_fewshot 5 \
+    --batch_size 4
+```
+
+For a quick smoke test, add `--limit 10`.
+
+> **Note:** Requires the `puzzletron` extra to be installed (`pip install -e ".[puzzletron]"`).
+
 ### Quantized (simulated)
 
 - For simulated quantization with any of the default quantization formats:
````

examples/llm_eval/lm_eval_hf.py

Lines changed: 48 additions & 2 deletions
```diff
@@ -36,6 +36,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import warnings
 
 import datasets
@@ -50,9 +51,33 @@
 from modelopt.torch.quantization.utils import is_quantized
 from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified
 
+try:
+    import modelopt.torch.puzzletron.anymodel.models  # noqa: F401
+    from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
+        resolve_descriptor_from_pretrained,
+    )
+    from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+
+    _ANYMODEL_AVAILABLE = True
+except ImportError:
+    _ANYMODEL_AVAILABLE = False
+
+
+def _anymodel_patcher_context(pretrained, trust_remote_code=False):
+    """Return a deci_x_patcher context if *pretrained* is a Puzzletron checkpoint, else a no-op."""
+    if not _ANYMODEL_AVAILABLE or not pretrained:
+        return contextlib.nullcontext()
+    try:
+        descriptor = resolve_descriptor_from_pretrained(
+            pretrained, trust_remote_code=trust_remote_code
+        )
+    except (ValueError, AttributeError):
+        return contextlib.nullcontext()
+    return deci_x_patcher(model_descriptor=descriptor)
+
 
 def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
-    """Overrides the HFLM.create_from_arg_obj"""
+    """Override HFLM.create_from_arg_obj to add quantization, sparsity, and Puzzletron support."""
 
     quant_cfg = arg_dict.pop("quant_cfg", None)
     auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
@@ -72,7 +97,10 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
     # Enable automatic save/load of modelopt state huggingface checkpointing
     mto.enable_huggingface_checkpointing()
 
-    model_obj = cls(**arg_dict, **additional_config)
+    with _anymodel_patcher_context(
+        arg_dict.get("pretrained"), arg_dict.get("trust_remote_code", False)
+    ):
+        model_obj = cls(**arg_dict, **additional_config)
     model_obj.tokenizer.padding_side = "left"
     if is_quantized(model_obj.model):
         # return if model is already quantized
@@ -109,10 +137,28 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
     return model_obj
 
 
+def create_from_arg_string(
+    cls: type[T], arg_string: str, additional_config: dict | None = None
+) -> T:
+    """Override HFLM.create_from_arg_string to support Puzzletron checkpoints."""
+    args = utils.simple_parse_args_string(arg_string)
+    additional_config = {} if additional_config is None else additional_config
+    args2 = {k: v for k, v in additional_config.items() if v is not None}
+
+    mto.enable_huggingface_checkpointing()
+
+    with _anymodel_patcher_context(args.get("pretrained"), args.get("trust_remote_code", False)):
+        model_obj = cls(**args, **args2)
+
+    return model_obj
+
+
 HFLM.create_from_arg_obj = classmethod(create_from_arg_obj)
+HFLM.create_from_arg_string = classmethod(create_from_arg_string)
 
 
 def setup_parser_with_modelopt_args():
+    """Extend the lm-eval argument parser with ModelOpt quantization and sparsity options."""
     parser = setup_parser()
     parser.add_argument(
         "--quant_cfg",
```
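The core of this change is the `_anymodel_patcher_context` helper: attempt descriptor resolution, and fall back to a no-op context manager when the optional dependency is missing or the checkpoint is a plain HF model. The standalone sketch below demonstrates that pattern with stand-ins (`fake_resolve`, `fake_patcher`) in place of the real Puzzletron APIs:

```python
import contextlib


# fake_resolve stands in for resolve_descriptor_from_pretrained: it raises
# ValueError for anything that is not an AnyModel checkpoint.
def fake_resolve(pretrained: str) -> str:
    if pretrained == "path/to/anymodel":
        return "anymodel-descriptor"
    raise ValueError("not an AnyModel checkpoint")


# fake_patcher stands in for deci_x_patcher; the real one patches model
# classes while the checkpoint is being loaded.
@contextlib.contextmanager
def fake_patcher(descriptor: str):
    yield descriptor


def patcher_context(pretrained: str):
    """Return a patcher context for AnyModel checkpoints, else a no-op."""
    if not pretrained:
        return contextlib.nullcontext()
    try:
        descriptor = fake_resolve(pretrained)
    except (ValueError, AttributeError):
        # Plain HF checkpoint: load normally, no patching.
        return contextlib.nullcontext()
    return fake_patcher(descriptor)
```

Returning `contextlib.nullcontext()` on failure is what makes the consolidated script safe for ordinary HF models: the `with` statement in `create_from_arg_obj` runs either way.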

examples/puzzletron/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -235,7 +235,7 @@ The plot shows how token accuracy changes with different compression rates. High
 Evaluate AnyModel checkpoints using [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) directly.
 
 ```bash
-python examples/puzzletron/evaluation/lm_eval_anymodel.py \
+python examples/llm_eval/lm_eval_hf.py \
     --model hf \
     --model_args pretrained=path/to/checkpoint,dtype=bfloat16,parallelize=True \
     --tasks mmlu \
````
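The `--model_args` value in the command above is a comma-separated `key=value` string that lm-eval parses into keyword arguments (via `utils.simple_parse_args_string` in the script). The following is an illustrative re-implementation of that idea — not lm-eval's actual parser — showing how a string like `pretrained=...,dtype=bfloat16,parallelize=True` becomes a dict:

```python
def parse_model_args(arg_string: str) -> dict:
    """Illustrative parser: split 'k1=v1,k2=v2' into a dict, coercing booleans."""
    out = {}
    for pair in filter(None, arg_string.split(",")):
        key, _, value = pair.partition("=")
        # lm-eval coerces literal True/False; everything else stays a string here.
        if value in ("True", "False"):
            out[key.strip()] = value == "True"
        else:
            out[key.strip()] = value
    return out
```

This is why `parallelize=True` reaches the HF loader as a boolean while `dtype=bfloat16` arrives as a string.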

examples/puzzletron/evaluation/lm_eval_anymodel.py

Lines changed: 0 additions & 115 deletions
This file was deleted.
