
Commit 92a1a27

evilsocket and claude committed
feat(moe): wire DiskProvider into MoE block loaders
When --expert-offload is set and tensor_storage is available:

- Qwen3MoeBlock loads the router gate from VarBuilder (stays in RAM); expert weights are streamed via DiskExpertProvider from safetensors
- Qwen3_5MoeBlock follows the same pattern, and its shared expert also stays in RAM
- Without --expert-offload, behavior is unchanged (ResidentProvider)

The full expert offloading pipeline is now functional end-to-end:
CLI flag → Context → tensor_storage → DiskProvider → pread per expert

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3a6301e commit 92a1a27

3 files changed

Lines changed: 74 additions & 2 deletions
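The central idea in this commit is that the router stays resident in RAM while routed expert weights are read from disk only when a token actually selects them. Below is a minimal, self-contained sketch of that pattern; the trait and type names here are illustrative stand-ins, not the crate's actual ExpertProvider API.

use std::collections::HashMap;

// Assumed provider interface: fetch one expert's weights on demand.
// (Illustrative only; the real ExpertProvider trait is not shown in this commit.)
trait ExpertProvider {
    fn fetch(&self, expert_idx: usize) -> Vec<f32>;
}

// Stand-in for DiskExpertProvider: pretend each fetch is one pread.
struct FakeDiskProvider {
    experts: HashMap<usize, Vec<f32>>,
}

impl ExpertProvider for FakeDiskProvider {
    fn fetch(&self, expert_idx: usize) -> Vec<f32> {
        println!("pread expert {expert_idx}");
        self.experts[&expert_idx].clone()
    }
}

// Router stays in RAM: pick the top-k experts for a token from its gate logits,
// then stream only those experts' weights.
fn route_and_fetch(gate_logits: &[f32], k: usize, provider: &dyn ExpertProvider) -> Vec<(usize, Vec<f32>)> {
    let mut idx: Vec<usize> = (0..gate_logits.len()).collect();
    idx.sort_by(|&a, &b| gate_logits[b].partial_cmp(&gate_logits[a]).unwrap());
    idx.truncate(k);
    idx.into_iter().map(|i| (i, provider.fetch(i))).collect()
}

fn main() {
    let provider = FakeDiskProvider {
        experts: (0..8usize).map(|i| (i, vec![i as f32; 4])).collect(),
    };
    // Only 2 of 8 experts are ever read for this token.
    let picked = route_and_fetch(&[0.1, 2.0, -1.0, 0.7, 3.2, 0.0, -0.5, 1.1], 2, &provider);
    println!("selected experts: {:?}", picked.iter().map(|(i, _)| i).collect::<Vec<_>>());
}

With 8 experts and top-2 routing, only two positioned reads happen per token per layer; the other six experts are never touched.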


cake-core/src/models/qwen3_5_moe/block.rs

Lines changed: 25 additions & 1 deletion
@@ -82,7 +82,31 @@ impl Forwarder for Qwen3_5MoeBlock {
         let rms_1 = load_rms_norm(h, eps, cfg.residual_rms_norm, vb.pp("input_layernorm"))?;
         let rms_2 =
             load_rms_norm(h, eps, cfg.residual_rms_norm, vb.pp("post_attention_layernorm"))?;
-        let moe = Qwen3_5MoeSparseMlp::load(vb.pp("mlp"), cfg, ctx.backend.clone())?;
+        let moe = if let Some(storage) = &ctx.tensor_storage {
+            // Expert offload: stream routed expert weights from disk
+            use candle_nn::linear_no_bias as linear;
+            let layer_prefix = format!("{name}.mlp");
+            let provider: std::sync::Arc<dyn crate::models::common::expert_provider::ExpertProvider> =
+                std::sync::Arc::new(crate::models::common::disk_expert_provider::DiskExpertProvider::new(
+                    storage.clone(), layer_prefix, cfg.num_experts, ctx.device.clone(), ctx.dtype,
+                ));
+            let mlp_vb = vb.pp("mlp");
+            let gate_w = mlp_vb.pp("gate").get((cfg.num_experts, h), "weight")?;
+            let gate = candle_nn::Linear::new(gate_w, None);
+            let si = cfg.shared_expert_intermediate_size.expect("shared_expert_intermediate_size");
+            let se = mlp_vb.pp("shared_expert");
+            let shared_gate_proj = linear(h, si, se.pp("gate_proj"))?;
+            let shared_up_proj = linear(h, si, se.pp("up_proj"))?;
+            let shared_down_proj = linear(si, h, se.pp("down_proj"))?;
+            let seg_w = mlp_vb.pp("shared_expert_gate").get((1, h), "weight")?;
+            let shared_expert_gate = candle_nn::Linear::new(seg_w, None);
+            Qwen3_5MoeSparseMlp::with_provider(
+                gate, provider, shared_gate_proj, shared_up_proj, shared_down_proj,
+                shared_expert_gate, cfg.num_experts, cfg.num_experts_per_tok, ctx.backend.clone(),
+            )
+        } else {
+            Qwen3_5MoeSparseMlp::load(vb.pp("mlp"), cfg, ctx.backend.clone())?
+        };
 
         if layer_type == "full_attention" {
             let attn = Qwen3_5FullAttention::load(vb.pp("self_attn"), cfg, ctx.backend.clone())?;
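A side note on this hunk: the router gate is wrapped by hand with Linear::new(gate_w, None), while the shared-expert projections go through linear_no_bias. For a bias-free projection the two are interchangeable; here is a tiny standalone check using candle's zero-filled VarBuilder purely as a stand-in for real checkpoint weights (illustration only, not how the crate loads checkpoints):

use candle_core::{DType, Device, Tensor};
use candle_nn::{linear_no_bias, Linear, Module, VarBuilder};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev); // stand-in for safetensors weights
    let (h, num_experts) = (64usize, 8usize);

    // Manual path, as in the diff: fetch the raw (num_experts, h) weight and wrap it.
    let gate_w = vb.pp("mlp").pp("gate").get((num_experts, h), "weight")?;
    let manual = Linear::new(gate_w, None);

    // Helper path: linear_no_bias reads the same "weight" tensor with the same shape.
    let helper = linear_no_bias(h, num_experts, vb.pp("mlp").pp("gate"))?;

    let x = Tensor::zeros((1, h), DType::F32, &dev)?;
    assert_eq!(manual.forward(&x)?.dims(), helper.forward(&x)?.dims()); // both (1, num_experts)
    Ok(())
}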

cake-core/src/models/qwen3_5_moe/moe.rs

Lines changed: 27 additions & 0 deletions
@@ -100,6 +100,33 @@ impl Qwen3_5MoeSparseMlp {
         })
     }
 
+    /// Construct with a pre-built expert provider (for disk offloading).
+    /// Shared expert + router are loaded from VarBuilder (stay in RAM).
+    #[allow(clippy::too_many_arguments)]
+    pub fn with_provider(
+        gate: Linear,
+        expert_provider: SharedExpertProvider,
+        shared_gate_proj: Linear,
+        shared_up_proj: Linear,
+        shared_down_proj: Linear,
+        shared_expert_gate: Linear,
+        num_experts: usize,
+        num_experts_per_tok: usize,
+        backend: Arc<dyn ComputeBackend>,
+    ) -> Self {
+        Self {
+            gate,
+            expert_provider,
+            shared_gate_proj,
+            shared_up_proj,
+            shared_down_proj,
+            shared_expert_gate,
+            num_experts,
+            num_experts_per_tok,
+            backend,
+        }
+    }
+
     pub fn forward(&self, x: &Tensor) -> anyhow::Result<Tensor> {
         let (b, s, h) = x.dims3().map_err(|e| anyhow!("moe dims3: {e}"))?;
         let n_tok = b * s;
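The shared_* fields stored here follow the usual Qwen-style shared-expert layout: a SwiGLU MLP whose output is scaled by a sigmoid over shared_expert_gate and added to the routed-expert output. That forward logic is not part of this diff, so the following plain-Rust sketch is only an assumption about how such fields are typically combined:

fn silu(x: f32) -> f32 { x / (1.0 + (-x).exp()) }
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }

// Bias-free matvec; rows of `w` are output features, matching Linear's (out, in) layout.
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter().map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>()).collect()
}

// Shared-expert SwiGLU: down(silu(gate(x)) * up(x)), scaled by sigmoid(shared_expert_gate(x)).
fn shared_expert(
    x: &[f32],
    gate_proj: &[Vec<f32>],
    up_proj: &[Vec<f32>],
    down_proj: &[Vec<f32>],
    shared_expert_gate: &[Vec<f32>], // (1, hidden) projection -> scalar logit
) -> Vec<f32> {
    let g = matvec(gate_proj, x);
    let u = matvec(up_proj, x);
    let inner: Vec<f32> = g.iter().zip(&u).map(|(a, b)| silu(*a) * b).collect();
    let out = matvec(down_proj, &inner);
    let scale = sigmoid(matvec(shared_expert_gate, x)[0]);
    out.into_iter().map(|v| v * scale).collect()
}

fn main() {
    // Toy sizes: hidden = 2, shared-expert intermediate = 3.
    let x = vec![1.0, -0.5];
    let gate_proj = vec![vec![0.1, 0.2], vec![0.3, -0.1], vec![0.0, 0.5]];
    let up_proj = vec![vec![0.2, 0.0], vec![-0.3, 0.4], vec![0.1, 0.1]];
    let down_proj = vec![vec![0.5, -0.2, 0.1], vec![0.0, 0.3, 0.2]];
    let seg = vec![vec![0.05, -0.02]];
    println!("{:?}", shared_expert(&x, &gate_proj, &up_proj, &down_proj, &seg));
}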

cake-core/src/models/qwen3_moe/block.rs

Lines changed: 22 additions & 1 deletion
@@ -41,7 +41,28 @@ impl Forwarder for Qwen3MoeBlock {
         let cfg = ctx.config.as_ref().expect("No config specified");
 
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg, ctx.backend.clone())?;
-        let moe = SparseMoeMlp::load(vb.pp("mlp"), cfg, ctx.backend.clone())?;
+
+        let moe = if let Some(storage) = &ctx.tensor_storage {
+            // Expert offload: stream weights from disk via DiskProvider
+            let layer_prefix = format!("{name}.mlp");
+            let provider: std::sync::Arc<dyn crate::models::common::expert_provider::ExpertProvider> =
+                std::sync::Arc::new(crate::models::common::disk_expert_provider::DiskExpertProvider::new(
+                    storage.clone(),
+                    layer_prefix,
+                    cfg.num_experts,
+                    ctx.device.clone(),
+                    ctx.dtype,
+                ));
+            // Load router gate from VarBuilder (it's small, stays in RAM)
+            let gate_w = vb.pp("mlp").pp("gate").get((cfg.num_experts, cfg.hidden_size), "weight")?;
+            let gate = candle_nn::Linear::new(gate_w, None);
+            SparseMoeMlp::with_provider(
+                gate, provider, cfg.num_experts, cfg.num_experts_per_tok,
+                cfg.norm_topk_prob, ctx.backend.clone(),
+            )
+        } else {
+            SparseMoeMlp::load(vb.pp("mlp"), cfg, ctx.backend.clone())?
+        };
 
         let eps = cfg.rms_norm_eps;
         let h = cfg.hidden_size;
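At the bottom of the pipeline, "pread per expert" means each expert fetch is one positioned read into its own buffer, so concurrent fetches never contend on a shared file cursor. A minimal sketch of that access pattern follows; the real DiskExpertProvider's shard path and byte offsets come from the safetensors metadata, which is not shown in this commit, so the values below are placeholders.

use std::fs::File;
use std::io;
use std::os::unix::fs::FileExt; // read_exact_at is pread(2) on Unix

// Byte range of one expert tensor inside a safetensors shard (placeholder values).
struct ExpertSpan {
    offset: u64,
    len: usize,
}

// Read a single expert's raw bytes with one positioned read.
fn read_expert(file: &File, span: &ExpertSpan) -> io::Result<Vec<u8>> {
    let mut buf = vec![0u8; span.len];
    file.read_exact_at(&mut buf, span.offset)?; // no shared cursor is moved
    Ok(buf)
}

fn main() -> io::Result<()> {
    // Hypothetical shard path, purely for illustration.
    let file = File::open("model-00001-of-00009.safetensors")?;
    let span = ExpertSpan { offset: 4096, len: 1024 }; // placeholder offset/length
    let bytes = read_expert(&file, &span)?;
    println!("read {} bytes for one expert", bytes.len());
    Ok(())
}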
