
Integrate Engram into custom model #3183

Open
RissyRan wants to merge 1 commit into main from new_engram_integration

Conversation

@RissyRan (Collaborator) commented on Feb 18, 2026

Description

Integrate Engram feature into a custom model

  • Add configs to base.yml and type.py, and integrate with the deepseek-custom model.
  • This PR supports the unscan version of Engram; the scan version will follow in the next PR.
  • Tried to initialize the hash map at the model level for one-time initialization, but hit various JAX initialization errors (see more in b/478294699 and this PR). Also added a comment there.
  • Currently, to make it work, you will see a mix of jnp and np, since some operations run on CPU to avoid ConcretizationTypeError and other issues (see the sketch below).
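A minimal sketch (all names hypothetical, not this PR's actual code) of the jnp/np split described above: the hash setup needs concrete Python values, so it runs eagerly on CPU with NumPy outside any jit trace, and only the precomputed constants are used inside the traced jnp code.

import numpy as np
import jax
import jax.numpy as jnp


def build_hash_multipliers(num_heads: int, seed: int = 0) -> np.ndarray:
  """Host-side setup: concrete NumPy values, safe for Python-level control flow."""
  rng = np.random.default_rng(seed)
  return rng.integers(1, 1_000_003, size=(num_heads,), dtype=np.int32)


# Built once, eagerly, on CPU -- never inside jit, so there is no traced value
# around to trigger ConcretizationTypeError.
MULTIPLIERS = build_hash_multipliers(num_heads=3)


@jax.jit
def hash_tokens(token_ids):
  """Traced side: pure jnp ops over the precomputed host constants."""
  mults = jnp.asarray(MULTIPLIERS)  # folded into the trace as a constant
  return (token_ids[..., None] * mults) % 65537


print(hash_tokens(jnp.arange(4)))  # shape (4, 3) of hashed ids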

Tests

  • Expect github runners to pass
  • Unit tests still passing for Engram: link
  • End-to-end unscan training for custom model: link
  • Sanity check for DS v2 (expect no impact)
Before change:

I0218 19:13:51.139775 139970093350464 max_utils.py:697] 	Using (GB) 43.98 / 95.74 (45.936912%) on TPU_1(process=0,(1,0,0,0))
I0218 19:15:18.498541 139970093350464 metric_logger.py:181] completed step: 19, seconds: 4.366, TFLOP/s/device: 123.145, Tokens/s/device: 7505.830, total_weights: 131072, loss: 8.135

After change:

I0218 19:05:35.490278 140385131265600 max_utils.py:697] 	Using (GB) 43.99 / 95.74 (45.947357%) on TPU_0(process=0,(0,0,0,0))
I0218 19:07:02.849351 140385131265600 metric_logger.py:181] completed step: 19, seconds: 4.366, TFLOP/s/device: 123.144, Tokens/s/device: 7505.806, total_weights: 131072, loss: 8.135

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented on Feb 18, 2026

Codecov Report

❌ Patch coverage is 94.00000% with 3 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/MaxText/layers/deepseek.py | 86.95% | 2 Missing and 1 partial ⚠️


@RissyRan force-pushed the new_engram_integration branch from aa6e11c to 5c4f07b on February 18, 2026 at 20:24
@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@shuningjin (Collaborator) left a comment


Thank you! I have some minor comments. As a follow-up, we could try placing the ngram hash mapping globally for efficiency: either (1) at the model level, or (2) at the data pipeline level, treating it as a special tokenizer.

engram_vocab_bases=config.engram_vocab_sizes,
max_ngram_size=config.engram_max_ngram_size,
engram_num_heads=config.engram_num_heads,
layer_ids=[layer_idx],
Collaborator

This should be initialized with the global layer_ids, as in the unit test. Otherwise the hashing behavior is different.

layer_ids=self.config.engram_layer_ids,

Comment on lines +1106 to +1107
# List of vocab sizes for each n-gram order.
engram_vocab_sizes: []
Collaborator

(1) This config can be confusing: this is not the actual vocab size, but the starting point of the prime-size search.

For accuracy, I suggest renaming it to engram_vocab_bases, with the comment: # List of minimum head vocab sizes for each n-gram order in 2...engram_max_ngram_size.

For example, if max_ngram = 2, num_heads = 3, and engram_vocab_bases = [4] (see the sketch after this comment):

  • ngram2, head0: find_next_unseen_prime > 4 - 1 -- size = 5
  • ngram2, head1: find_next_unseen_prime > 5 -- size = 7
  • ngram2, head2: find_next_unseen_prime > 7 -- size = 11

(2) To reflect the hierarchy (engram -> 2...max ngram -> num_heads per n-gram), the config order could be:

engram_max_ngram_size 
engram_num_heads, engram_head_dim
engram_vocab_sizes
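A small sketch reproducing the prime search in (1); the helper implementation here is assumed, not taken from the Engram code.

def is_prime(n: int) -> bool:
  return n > 1 and all(n % d for d in range(2, int(n**0.5) + 1))


def find_next_unseen_prime(n: int, seen: set) -> int:
  """Smallest prime strictly greater than n that no earlier head has used."""
  candidate = n + 1
  while not is_prime(candidate) or candidate in seen:
    candidate += 1
  return candidate


def head_vocab_sizes(base: int, num_heads: int) -> list:
  sizes, seen, current = [], set(), base - 1
  for _ in range(num_heads):
    current = find_next_unseen_prime(current, seen)
    seen.add(current)
    sizes.append(current)
  return sizes


print(head_vocab_sizes(base=4, num_heads=3))  # [5, 7, 11]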

and self.gradient_accumulation_steps > 1
):
raise ValueError("FP8 quantization is not compatible with gradient accumulation.")
if self.engram_layers:
Collaborator

Thanks for adding the checks! It might be good to also check that len(engram_vocab_sizes) == engram_max_ngram_size - 1.
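A hypothetical sketch of that extra validation (field names taken from this PR's config; where it lives in the config code is up to the author):

def check_engram_vocab_sizes(engram_vocab_sizes, engram_max_ngram_size):
  # One entry is expected per n-gram order in 2..engram_max_ngram_size.
  expected = engram_max_ngram_size - 1
  if len(engram_vocab_sizes) != expected:
    raise ValueError(
        f"engram_vocab_sizes must have {expected} entries "
        f"(n-gram orders 2..{engram_max_ngram_size}), got {len(engram_vocab_sizes)}."
    )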

@@ -181,10 +182,14 @@ def __init__(
self.compressed_tokenizer = CompressedTokenizer(tokenizer)
self.tokenizer_vocab_size = len(self.compressed_tokenizer)
if pad_id is not None:
Collaborator

When pad_id is not None, we init self.pad_id, which is required by _get_ngram_hashes for padding.

When it is None, the expected behavior is unclear from the reference:
https://github.com/deepseek-ai/Engram/blob/fb7f84a21f91223715394a33a1dc24bbfb7f788e/engram_demo_v1.py#L211
Maybe raise an error, and add a TODO in a comment? (A sketch of the guard follows.)
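A minimal sketch of that guard; the constructor shape is assumed, not this PR's actual code.

class NgramHashMappingSketch:
  """Hypothetical stand-in illustrating the suggested pad_id check."""

  def __init__(self, pad_id=None):
    if pad_id is None:
      # TODO: decide on the intended default once the reference behavior in
      # engram_demo_v1.py is confirmed.
      raise ValueError("pad_id is required: _get_ngram_hashes uses it to mask padding.")
    self.pad_id = pad_id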

Comment on lines +110 to +111
# TODO(ranran): Refactor NgramHashMapping to initialize once globally or at the model level.
# Moving this to decoders.py currently causes JAX initialization errors.
Collaborator

Alternatively, can we initialize it in the data pipeline, treating it as a special tokenizer that operates with np on CPU, similar to the existing hf_tokenizer? (A rough sketch follows.)
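A rough sketch (all names hypothetical) of what that placement could look like, with hashing done host-side in the input pipeline so the model only sees precomputed integer features:

import numpy as np

def add_engram_hashes(example: dict, hash_fn) -> dict:
  """hash_fn is assumed to be a NumPy/CPU callable like NgramHashMapping."""
  example["engram_hashes"] = np.asarray(hash_fn(example["inputs"]))
  return example

# e.g. in the data-loading code, applied per example before device transfer:
#   dataset = dataset.map(lambda ex: add_engram_hashes(ex, hash_mapping))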

@RissyRan (Collaborator, Author) commented on Feb 19, 2026

Yeah, I tried both in this draft PR:

  1. Put it in the data pipeline: https://screenshot.googleplex.com/5dhowKp7EjzMuud - not working.
  2. Put it at the model level: https://screenshot.googleplex.com/5FW7nxAXu9fFDCA - not working.

I hit different issues here and there. If bandwidth allows, we could dive deeper into this.
