Skip to content

Pull Request: Adding HiRA integration into PEFT library#2668

Open
hqsiswiliam wants to merge 115 commits intohuggingface:mainfrom
hqsiswiliam:main
Open

Pull Request: Adding HiRA integration into PEFT library#2668
hqsiswiliam wants to merge 115 commits intohuggingface:mainfrom
hqsiswiliam:main

Conversation

@hqsiswiliam
Copy link
Copy Markdown

Feature request

This request proposes integrating HiRA (Hadamard High-Rank Adaptation) as described in the ICLR 2025 oral paper (https://openreview.net/pdf?id=TwJrTz9cRS) (https://iclr.cc/virtual/2025/oral/31839) and implemented in the hqsiswiliam/hira repository into the core PEFT library. This will enable users to apply HiRA through the familiar get_peft_model API and benefit from its high-rank updates without adding any inference overhead.

Motivation

General Motivation

PEFT methods like LoRA achieve parameter-efficient fine-tuning by injecting low-rank updates into pre-trained weights. While effective, purely low-rank adaptation can struggle to capture complex patterns in large language models.

1. Expressiveness grows with the rank

Empirically, increasing the LoRA rank in LLM training yields better downstream performance:

LoRA performance vs. rank
Higher LoRA rank correlates with improved task accuracy.

2. HiRA: Hadamard high-rank updates without extra parameters

HiRA sidesteps the expressiveness constraint by computing a Hadamard-enhanced update:

$$ \Delta W = W_0 \odot (A B) $$

HiRA update formula
HiRA uses the Hadamard product to inject high-rank structure into the frozen weight matrix $W_0$ via low-rank matrix $A$ and $B$.

3. Singular-value patterns

After training, HiRA exhibits a rich singular-value pattern, akin to full-rank fine-tuning (FFT), indicating its ability to model complex transformations without the expensive computational overhead:

Singular-value pattern comparison
HiRA’s singular-value distribution closely mirrors that of FFT.

4. Performance gains

Across commonsense reasoning benchmarks, HiRA outperforms LoRA and other PEFT baselines:

Commonsense reasoning performance
HiRA delivers notable accuracy improvements over baseline adapters.

5. No extra parameter or compute cost

Despite its high-rank behaviour, HiRA introduces no additional trainable parameters compared to LoRA:

Resource comparison: LoRA vs. HiRA
HiRA matches LoRA’s GRAM usage and training hours.

6. Complementary with LoRA (HiLoRA)

Combining HiRA and LoRA into a hybrid “HiLoRA” setup yields even stronger results than either method alone:

HiLoRA concept
HiLoRA performance gains
HiLoRA leverages both low-rank and Hadamard high-rank updates for better expressiveness.


By integrating HiRA into PEFT, users gain richer adaptation capability without sacrificing the parameter efficiency and usability that PEFT provides.

Your contribution

We would be pleased to submit a pull request to integrate HiRA class implementation into the PEFT framework. We welcome any suggestions for alternative integration approaches and appreciate any guidance on best practices.

Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR to add HiRA to PEFT. The method looks promising and the provided code is already quite mature.

When I started reading the paper, I was at first reminded of FedPara, aka LoHa, which is already integrated into PEFT, as that method also relies on the Hadamard product. However, IIUC, the two methods are still distinct: HiRA basically corresponds to LoRA, but instead of adding dW, we multiply it. In that way, it is much closer to LoRA than to LoHa. Still, I wanted to flag this, as I'm not sure you are aware (your paper doesn't seem to be reference FedPara).

At the moment, I haven't done a full in-depth review, but I think that makes more sense once we have completed the next step.

I noticed that you have formatted some unrelated files in method_comparison, could you please undo those changes? Usually, when you run make style, that directory should not be included.

I think a good next step is to add HiRA to the testing matrix we have in PEFT. For now, let's add some entries similar to the ones you can find here:

("Vanilla MLP 1 LoRA", "MLP", LoraConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0"]}),
("Vanilla MLP 3 LoRA", "MLP", LoraConfig, {"target_modules": ["lin1"]}),

Since you also support embedding and conv layers, please make sure to include examples with those layers as well (basically, copy the relevant examples from LoRA and adjust them).

Then, please run pytest tests/test_custom_models.py -k "hira and not shira" -v and see if those tests pass. Once we get there, we can discuss the best next steps.

Comment thread src/peft/tuners/hira/__init__.py Outdated
Comment thread src/peft/tuners/hira/config.py Outdated
Comment thread src/peft/tuners/hira/config.py Outdated
Comment thread src/peft/utils/constants.py Outdated
Comment thread tests/test_hira.py Outdated
@github-actions
Copy link
Copy Markdown

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Copy link
Copy Markdown
Member

@hqsiswiliam Do you still plan on working on this PR?

@hqsiswiliam
Copy link
Copy Markdown
Author

@hqsiswiliam Do you still plan on working on this PR?

Hi, BenjaminBossan. Thanks for checking in! I’ll continue working on this PR over the next few days.

@hqsiswiliam
Copy link
Copy Markdown
Author

Hi @BenjaminBossan, I’ve merged the latest main into this branch. Please let me know if any additional changes are needed.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating the PR. Generally, not much is missing, but I found a few smaller issues, please check. A few HiRA tests are still failing, could you please check and resolve these errors?

On top of this, let's also add tests to test_config.py, test_decoder_models.py, test_feature_extraction_models.py, and test_seq_classifier.py. Moreover, I would strongly suggest to add an experiment to the MetaMathQA benchmark. This has the advantage that we can run an experiment to check that training works as expected and is more or less aligned with the expectations from the paper.

Comment thread src/peft/tuners/hira/model.py Outdated
Comment thread src/peft/tuners/hira/model.py Outdated
Comment thread src/peft/tuners/hira/layer.py Outdated
Comment thread src/peft/tuners/hira/layer.py
Comment thread src/peft/tuners/__init__.py Outdated
@hqsiswiliam
Copy link
Copy Markdown
Author

Thank you for your review. I have updated the code accordingly. I will also look into integrating an experiment on the MetaMathQA benchmark following the provided guidelines.

Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the latest updates.

Thank you for your review. I have updated the code accordingly. I will also look into integrating an experiment on the MetaMathQA benchmark following the provided guidelines.

Sounds good. Please ping me once that is done, + the missing unit tests I mentioned in my last comment.

Comment thread src/peft/tuners/__init__.py
@github-actions
Copy link
Copy Markdown

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions Bot closed this Apr 26, 2026
@hqsiswiliam
Copy link
Copy Markdown
Author

Hi @BenjaminBossan, sorry for the delay that led to this being auto-closed.

I've completed the remaining items from your Mar 21 review:

  • Added HiRA entries to test_config.py, test_decoder_models.py, test_feature_extraction_models.py, and test_seq_classifier.py
  • Re-ran all HiRA tests locally
  • Merged the latest main

Could you please reopen this PR for another review?

Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the last updates, there is not much that's still required.

I will also look into integrating an experiment on the MetaMathQA benchmark following the provided guidelines.

This is still missing, did you have time to check? Don't hesitate to ask me if anything is unclear.

Comment thread src/peft/tuners/hira/model.py Outdated
Comment thread src/peft/tuners/hira/layer.py Outdated
@hqsiswiliam
Copy link
Copy Markdown
Author

Thanks for the review! I've made the following updates:

  • Both comments addressed.
  • Added a MetaMathQA experiment at method_comparison/MetaMathQA/experiments/hira/llama-3.2-3B-rank32-lr4e-3/.
  • Merged with the latest upstream main.

Copy link
Copy Markdown
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your updates, the layer initialization now conforms with the rest of PEFT and the experiment looks promising. There are a few issues there which I flagged, otherwise the PR is ready.

Comment thread method_comparison/MetaMathQA/utils.py Outdated
optimizer_type: The name of a torch optimizer (e.g. AdamW) or a PEFT method ("lora+", "lora-fa")
optimizer_kwargs: The optimizer keyword arguments (lr etc.)
lr_scheduler: The learning rate scheduler (currently only None or 'cosine' are supported)
warmup_step_ratio: Fraction of total steps used for LR warmup (only relevant when lr_scheduler='cosine'), defaults to WARMUP_STEP_RATIO (0.1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary? I'd rather keep the benchmarking code fix for this PR. We can discuss adding the warmup ratio in a separate PR.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Results look nice, thanks for running. But let's remove the result file, we will run the experiment on our own machine to ensure that the different results are comparable.

"weight_decay": 0.0
},
"warmup_step_ratio": 0.05,
"use_gc": true
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not use gradient checkpointing, as all the other experiments also run without it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants