Skip to content

Commit ccfe7ee

Browse files
committed
Disable FMul associativity and factoring rules that cause egraph explosion
FMul associativity and FP factoring rules interact to cause exponential egraph growth on shaders with many shared multiply chains (e.g. sky-shader with 1,264 FMul nodes). FMul associativity alone creates 305K+ matches via Catalan-number re-parenthesization across shared e-classes; combined with factoring (304K matches), tuples explode from 2,875 to 916K in one iteration, causing OOM/SIGKILL. Constant chain merging (FMul(FMul(x, const_a), const_b) → FMul(x, a*b)) is handled by dedicated rules and does not need general associativity.
1 parent a897143 commit ccfe7ee

6 files changed

Lines changed: 42 additions & 38 deletions

File tree

rust/spirv-tools-opt/src/direct/emit.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,9 +1301,7 @@ fn emit_pattern(
13011301
// Refine it to avoid signed/unsigned mismatches:
13021302
// 1. Signed int ops (ConvertSToF, SLessThan, etc.) → signed int type
13031303
// 2. Other cross-class ops → infer from Sym operand's actual type
1304-
let operand_type = if operand_class == TypeClass::Int
1305-
&& result_class != operand_class
1306-
{
1304+
let operand_type = if operand_class == TypeClass::Int && result_class != operand_class {
13071305
match opcode {
13081306
Op::ConvertSToF
13091307
| Op::ConvertFToS
@@ -1315,9 +1313,9 @@ fn emit_pattern(
13151313
| Op::SRem
13161314
| Op::SMod
13171315
| Op::SConvert => ctx.signed_int32_type.unwrap_or(operand_type),
1318-
_ => infer_operand_type_from_args(
1319-
args, arity, operand_class, operand_type, ctx,
1320-
),
1316+
_ => {
1317+
infer_operand_type_from_args(args, arity, operand_class, operand_type, ctx)
1318+
}
13211319
}
13221320
} else if result_class != operand_class {
13231321
// Non-int cross-class: infer from Sym args for consistency

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,8 +1037,7 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
10371037
// SPIR-V types (e.g. two constants with the same bit
10381038
// pattern but different type IDs, or Vec4/Vec2 both as
10391039
// Expr). CopyObject requires operand type == result type.
1040-
let type_matches = ctx.id_to_type.get(&id)
1041-
== ctx.id_to_type.get(&alias_id);
1040+
let type_matches = ctx.id_to_type.get(&id) == ctx.id_to_type.get(&alias_id);
10421041
if type_matches {
10431042
id_aliases.insert(id, alias_id);
10441043
used_ids.insert(alias_id);
@@ -1091,8 +1090,8 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
10911090
// (e.g. Vec4 and Vec2 both as Expr). Emitting a
10921091
// CopyObject with mismatched types would cause
10931092
// validation errors.
1094-
let type_matches = ctx.id_to_type.get(&id)
1095-
== ctx.id_to_type.get(&final_id);
1093+
let type_matches =
1094+
ctx.id_to_type.get(&id) == ctx.id_to_type.get(&final_id);
10961095
if type_matches {
10971096
id_aliases.insert(id, final_id);
10981097
used_ids.insert(final_id);
@@ -2357,9 +2356,7 @@ fn find_spirv_type(module: &Module, opcode: Op, width: Option<u32>) -> Option<Wo
23572356
.find(|inst| {
23582357
inst.class.opcode == opcode
23592358
&& match width {
2360-
Some(w) => {
2361-
inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(w))
2362-
}
2359+
Some(w) => inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(w)),
23632360
None => true,
23642361
}
23652362
})

rust/spirv-tools-opt/src/egglog_opt/tests.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6587,10 +6587,7 @@ fn test_boolconst_float_comparison_reflexive() {
65876587

65886588
// FOrdLt(x,x) IS correctly folded to BoolConst(0)
65896589
let check2 = egraph.parse_and_run_program(None, "(check (= lt_root false_val))");
6590-
assert!(
6591-
check2.is_ok(),
6592-
"FOrdLt(x, x) should fold to BoolConst(0)"
6593-
);
6590+
assert!(check2.is_ok(), "FOrdLt(x, x) should fold to BoolConst(0)");
65946591
}
65956592

65966593
#[test]

rust/spirv-tools-opt/src/rules/arithmetic.egg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@
209209
; We only want to FACTOR (collect terms), not DISTRIBUTE (expand terms).
210210

211211
; (x * a) + (x * b) => x * (a + b) - factoring out common term
212+
; NOTE: These are safe because there is no general integer Mul associativity
213+
; rule (only constant-folding variants). Without associativity creating new
214+
; Mul groupings, the feedback loop that causes FP factoring explosion cannot
215+
; form. If a general (Mul (Mul a b) c) -> (Mul a (Mul b c)) rule is ever
216+
; added, these must be disabled too (see floating_point.egg factoring).
212217
(rule ((= e (Add (Mul x a) (Mul x b))))
213218
((union e (Mul x (Add a b)))))
214219
(rule ((= e (Add (Mul a x) (Mul b x))))

rust/spirv-tools-opt/src/rules/floating_point.egg

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,30 @@
201201
; 2. Any FAdd(x, Y) matches FAdd(FAdd(x, 0), Y) and cascades endlessly
202202
;
203203
; Safe variant: Only apply when 'b' is an FAdd/FMul (not an FConst that could be 0 or 1).
204+
; NOTE: FAdd associativity is safe with current workloads (only 11 matches at iter 3
205+
; on sky-shader), but could exhibit Catalan blowup if a shader has deep FAdd chains
206+
; with shared sub-expressions. Monitor if new shaders cause FAdd explosion.
204207
(rule ((= e (FAdd (FAdd a b) c)) (= b (FAdd _ _)))
205208
((union e (FAdd a (FAdd b c)))))
206-
(rule ((= e (FMul (FMul a b) c)) (= b (FMul _ _)))
207-
((union e (FMul a (FMul b c)))))
209+
; DISABLED: FMul associativity causes exponential blowup on its own.
210+
; With ~1,264 FMul nodes sharing e-classes (common in sky/lighting shaders),
211+
; re-parenthesization creates 305K+ matches and 455K FMul tuples in one
212+
; iteration via Catalan-number growth across shared sub-expressions.
213+
; Constant chain merging (FMul(FMul(x, const_a), const_b) → FMul(x, const_a*b))
214+
; is handled by dedicated rules below and does NOT need this.
215+
; (rule ((= e (FMul (FMul a b) c)) (= b (FMul _ _)))
216+
; ((union e (FMul a (FMul b c)))))
208217

209218
; Factoring: a*c + b*c = (a+b)*c
210-
; One-directional to prevent expansion that causes e-graph explosion.
211-
; We recognize expanded patterns and factor them, but don't distribute.
212-
(rule ((= e (FAdd (FMul a c) (FMul b c))))
213-
((union e (FMul (FAdd a b) c))))
214-
(rule ((= e (FAdd (FMul c a) (FMul c b))))
215-
((union e (FMul c (FAdd a b)))))
219+
; DISABLED: Creates a feedback loop with FMul associativity above.
220+
; Factoring creates new FMul+FAdd nodes, associativity re-parenthesizes the
221+
; new FMul nodes, exposing more factoring opportunities → exponential blowup.
222+
; In the sky-shader, this explodes from 2,875 to 916,652 tuples in one iteration
223+
; (305K associativity matches + 304K factoring matches).
224+
; (rule ((= e (FAdd (FMul a c) (FMul b c))))
225+
; ((union e (FMul (FAdd a b) c))))
226+
; (rule ((= e (FAdd (FMul c a) (FMul c b))))
227+
; ((union e (FMul c (FAdd a b)))))
216228

217229
; x + x = 2*x - ONE-DIRECTIONAL to prevent explosion
218230
; (FMul x 2 -> FAdd x x would expand forever)

rust/spirv-tools-opt/tests/opt_block_cli.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5443,20 +5443,18 @@ fn build_loop_with_selection_in_continue_block() -> Vec<u32> {
54435443
std::iter::empty(),
54445444
)
54455445
.unwrap();
5446-
let val = b.load(int, None, counter, None, std::iter::empty()).unwrap();
5447-
let cond = b
5448-
.s_less_than(bool_ty, None, val, c10)
5449-
.expect("less than");
5446+
let val = b
5447+
.load(int, None, counter, None, std::iter::empty())
5448+
.unwrap();
5449+
let cond = b.s_less_than(bool_ty, None, val, c10).expect("less than");
54505450
b.branch_conditional(cond, continue_label, merge_label, std::iter::empty())
54515451
.unwrap();
54525452

54535453
// Continue block: contains a selection construct
54545454
b.begin_block(Some(continue_label)).unwrap();
54555455
let c2 = b.constant_bit32(int, 2);
54565456
let is_even_rem = b.s_mod(int, None, val, c2).unwrap();
5457-
let is_even = b
5458-
.i_equal(bool_ty, None, is_even_rem, c0)
5459-
.expect("is_even");
5457+
let is_even = b.i_equal(bool_ty, None, is_even_rem, c0).expect("is_even");
54605458
b.selection_merge(sel_merge_label, SelectionControl::NONE)
54615459
.unwrap();
54625460
b.branch_conditional(is_even, sel_then_label, sel_else_label, std::iter::empty())
@@ -5536,8 +5534,7 @@ fn build_store_with_typed_value() -> Vec<u32> {
55365534
let _sint = b.type_int(32, 1);
55375535
let void = b.type_void();
55385536
let func_ty = b.type_function(void, vec![uint]);
5539-
let ptr_uint =
5540-
b.type_pointer(None, rspirv::spirv::StorageClass::Function, uint);
5537+
let ptr_uint = b.type_pointer(None, rspirv::spirv::StorageClass::Function, uint);
55415538

55425539
let func = b
55435540
.begin_function(void, None, FunctionControl::NONE, func_ty)
@@ -5610,8 +5607,7 @@ fn build_float_to_signed_int_module() -> Vec<u32> {
56105607
let float = b.type_float(32, None);
56115608
let void = b.type_void();
56125609
let func_ty = b.type_function(void, vec![sint]);
5613-
let ptr_sint =
5614-
b.type_pointer(None, rspirv::spirv::StorageClass::Function, sint);
5610+
let ptr_sint = b.type_pointer(None, rspirv::spirv::StorageClass::Function, sint);
56155611

56165612
let func = b
56175613
.begin_function(void, None, FunctionControl::NONE, func_ty)
@@ -5684,8 +5680,7 @@ fn build_matrix_multiply_chain() -> Vec<u32> {
56845680
let mat4 = b.type_matrix(vec4, 4);
56855681
let void = b.type_void();
56865682
let func_ty = b.type_function(void, vec![mat4, mat4, mat4, vec4]);
5687-
let ptr_vec4 =
5688-
b.type_pointer(None, rspirv::spirv::StorageClass::Function, vec4);
5683+
let ptr_vec4 = b.type_pointer(None, rspirv::spirv::StorageClass::Function, vec4);
56895684

56905685
let func = b
56915686
.begin_function(void, None, FunctionControl::NONE, func_ty)

0 commit comments

Comments
 (0)