//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for a given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution.
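  // (For example, with a vector length of 16 the mask enables the first
  // min(hi - iv, 16) lanes, keeping the final partial iteration in bounds.)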
  // Here we rely on subsequent loop optimizations to avoid executing the mask
  // in all iterations, for example, by splitting the loop into an unconditional
  // vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of reduction on success in `kind`.
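/// For example, with %x as the loop's iteration argument, a reduction value
/// defined by "arith.addf %x, %a" (or "arith.addf %a, %x") maps to
/// CombiningKind::ADD, whereas "arith.subf %a, %x" is rejected because
/// subtraction only vectorizes with the iteration argument on the left.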
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index pair, we
    // must zero extend the vector to an index width. For 8-bit and 16-bit
    // values, a 32-bit index width suffices. For 32-bit values, zero
    // extending the elements into 64-bit loses some performance since the
    // 32-bit indexed gather/scatter is more efficient than the 64-bit index
    // variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  // A block argument is invariant.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant as well.
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  Location loc = forOp.getLoc();
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
  } else if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
    }
  }
  return false;
}

#undef UNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations, which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
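  // As an illustrative sketch (not verbatim emitted IR, and assuming a vector
  // length of 16), a scalar reduction loop
  //   %s = scf.for %i = %lo to %hi step %c1 iter_args(%x = %r) -> (f64)
  // is reconstructed as a new loop
  //   %v = scf.for %i = %lo to %hi step %c16 iter_args(%vx = %vinit)
  //          -> (vector<16xf64>)
  // followed by a horizontal vector.reduction after the loop, whereas a store
  // loop keeps its scf.for and merely has its step and body updated.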
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(
              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are either terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
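/// As a small sketch, a candidate loop produced by the sparse loop emitter
/// (and tagged with its loop attribute) looks like
///   scf.for %i = %c0 to %n step %c1 {
///     %0 = memref.load %a[%i] : memref<?xf64>
///     memref.store %0, %b[%i] : memref<?xf64>
///   }
/// which the pattern below rewrites into a masked vector loop.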
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)      v = for { }
///   u = expand(s) -> for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(
                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
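
// A minimal usage sketch (the actual pass driver lives elsewhere): these rules
// are typically applied to a function with the greedy rewriter, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));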