|
30 | 30 | #include <tvm/relax/expr_functor.h> |
31 | 31 | #include <tvm/relax/struct_info.h> |
32 | 32 | #include <tvm/relax/transform.h> |
| 33 | +#include <tvm/tir/analysis.h> |
33 | 34 | #include <tvm/tir/stmt_functor.h> |
34 | 35 |
|
35 | 36 | namespace tvm { |
@@ -134,13 +135,74 @@ class SymbolicVarCanonicalizer : public ExprMutator { |
134 | 135 | return output; |
135 | 136 | } |
136 | 137 |
|
| 138 | + Expr VisitExpr_(const ShapeExprNode* op) override { |
| 139 | + // For each dimension, check if it is a composite expression that symbolization |
| 140 | + ffi::Array<PrimExpr> new_values; |
| 141 | + bool changed = false; |
| 142 | + |
| 143 | + for (const auto& dim : op->values) { |
| 144 | + PrimExpr new_dim = VisitPrimExpr(dim); |
| 145 | + |
| 146 | + // Check if this is a composite expression (not a constant or simple variable) |
| 147 | + if (IsCompositePrimExpr(new_dim)) { |
| 148 | + // Introduce a new symbolic variable for this composite expression |
| 149 | + tir::Var symbolic_var = CreateSymbolicVar(new_dim); |
| 150 | + new_values.push_back(symbolic_var); |
| 151 | + changed = true; |
| 152 | + } else { |
| 153 | + new_values.push_back(new_dim); |
| 154 | + if (!new_dim.same_as(dim)) { |
| 155 | + changed = true; |
| 156 | + } |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + if (!changed) { |
| 161 | + return ffi::GetRef<Expr>(op); |
| 162 | + } |
| 163 | + |
| 164 | + return ShapeExpr(new_values); |
| 165 | + } |
| 166 | + |
137 | 167 | private: |
138 | 168 | struct KnownValue { |
139 | 169 | PrimExpr expr; |
140 | 170 | MatchCast source; |
141 | 171 | }; |
142 | 172 |
|
| 173 | + bool IsCompositePrimExpr(const PrimExpr& expr) { |
| 174 | + // Constants and simple variables are not composite |
| 175 | + if (expr->IsInstance<tir::IntImmNode>() || expr->IsInstance<tir::FloatImmNode>() || |
| 176 | + expr->IsInstance<tir::VarNode>()) { |
| 177 | + return false; |
| 178 | + } |
| 179 | + |
| 180 | + // Check if the expression contains variables |
| 181 | + auto vars = tir::UndefinedVars(expr); |
| 182 | + |
| 183 | + // If it has variables, it's composite (e.g., x * y, x + 1, etc.) |
| 184 | + return vars.size() >= 1; |
| 185 | + } |
| 186 | + |
| 187 | + tir::Var CreateSymbolicVar(const PrimExpr& expr) { |
| 188 | + tir::Var symbolic_var("composite_" + std::to_string(composite_counter_++), expr->dtype); |
| 189 | + |
| 190 | + // Create PrimValue for the composite expression |
| 191 | + PrimValue prim_val(expr); |
| 192 | + PrimStructInfo prim_sinfo(symbolic_var); |
| 193 | + Var relax_var("comp_val_" + std::to_string(composite_counter_ - 1), prim_sinfo); |
| 194 | + |
| 195 | + // Emit MatchCast to define the symbolic variable |
| 196 | + auto match_cast = MatchCast(relax_var, prim_val, prim_sinfo); |
| 197 | + builder_->EmitNormalized(match_cast); |
| 198 | + |
| 199 | + known_values_[symbolic_var] = KnownValue{expr, match_cast}; |
| 200 | + |
| 201 | + return symbolic_var; |
| 202 | + } |
| 203 | + |
143 | 204 | std::unordered_map<tir::Var, KnownValue> known_values_; |
| 205 | + int composite_counter_ = 0; |
144 | 206 | }; |
145 | 207 |
|
146 | 208 | struct CanonicalizationPlan { |
|
0 commit comments