Skip to content

Commit b98b5a7

Browse files
committed
symbolizing composite PrimExpr
1 parent 2265bd1 commit b98b5a7

1 file changed

Lines changed: 62 additions & 0 deletions

File tree

src/relax/transform/canonicalize_bindings.cc

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <tvm/relax/expr_functor.h>
3131
#include <tvm/relax/struct_info.h>
3232
#include <tvm/relax/transform.h>
33+
#include <tvm/tir/analysis.h>
3334
#include <tvm/tir/stmt_functor.h>
3435

3536
namespace tvm {
@@ -134,13 +135,74 @@ class SymbolicVarCanonicalizer : public ExprMutator {
134135
return output;
135136
}
136137

138+
Expr VisitExpr_(const ShapeExprNode* op) override {
139+
// For each dimension, check if it is a composite expression that symbolization
140+
ffi::Array<PrimExpr> new_values;
141+
bool changed = false;
142+
143+
for (const auto& dim : op->values) {
144+
PrimExpr new_dim = VisitPrimExpr(dim);
145+
146+
// Check if this is a composite expression (not a constant or simple variable)
147+
if (IsCompositePrimExpr(new_dim)) {
148+
// Introduce a new symbolic variable for this composite expression
149+
tir::Var symbolic_var = CreateSymbolicVar(new_dim);
150+
new_values.push_back(symbolic_var);
151+
changed = true;
152+
} else {
153+
new_values.push_back(new_dim);
154+
if (!new_dim.same_as(dim)) {
155+
changed = true;
156+
}
157+
}
158+
}
159+
160+
if (!changed) {
161+
return ffi::GetRef<Expr>(op);
162+
}
163+
164+
return ShapeExpr(new_values);
165+
}
166+
137167
private:
138168
struct KnownValue {
139169
PrimExpr expr;
140170
MatchCast source;
141171
};
142172

173+
bool IsCompositePrimExpr(const PrimExpr& expr) {
174+
// Constants and simple variables are not composite
175+
if (expr->IsInstance<tir::IntImmNode>() || expr->IsInstance<tir::FloatImmNode>() ||
176+
expr->IsInstance<tir::VarNode>()) {
177+
return false;
178+
}
179+
180+
// Check if the expression contains variables
181+
auto vars = tir::UndefinedVars(expr);
182+
183+
// If it has variables, it's composite (e.g., x * y, x + 1, etc.)
184+
return vars.size() >= 1;
185+
}
186+
187+
tir::Var CreateSymbolicVar(const PrimExpr& expr) {
188+
tir::Var symbolic_var("composite_" + std::to_string(composite_counter_++), expr->dtype);
189+
190+
// Create PrimValue for the composite expression
191+
PrimValue prim_val(expr);
192+
PrimStructInfo prim_sinfo(symbolic_var);
193+
Var relax_var("comp_val_" + std::to_string(composite_counter_ - 1), prim_sinfo);
194+
195+
// Emit MatchCast to define the symbolic variable
196+
auto match_cast = MatchCast(relax_var, prim_val, prim_sinfo);
197+
builder_->EmitNormalized(match_cast);
198+
199+
known_values_[symbolic_var] = KnownValue{expr, match_cast};
200+
201+
return symbolic_var;
202+
}
203+
143204
std::unordered_map<tir::Var, KnownValue> known_values_;
205+
int composite_counter_ = 0;
144206
};
145207

146208
struct CanonicalizationPlan {

0 commit comments

Comments
 (0)