//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for a given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs vector type for the given element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type for the element type of the given memref.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
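  //
  // For example (a sketch only; SSA names are made up and the exact IR depends
  // on the chosen vector length and later folding), with vl = 16 this emits:
  //   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
  //             (%hi, %iv)[%step]
  //   %mask = vector.create_mask %end : vector<16xi1>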
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs,
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
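///
/// For example (a sketch with made-up SSA names), a sum reduction inside the
/// loop body appears as
///   %red = arith.addf %iter, %val : f64
/// where %iter is the region iteration argument that carries the reduction.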
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
/// Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to the index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range, since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant/loop indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (1) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance, since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
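    //
    // For example (a sketch with made-up SSA names), an 8-bit coordinates
    // array loaded for use as a gather index is widened as:
    //   %c = vector.maskedload %coords[%i], %mask, %pass
    //          : memref<?xi8>, vector<16xi1>, vector<16xi8> into vector<16xi8>
    //   %w = arith.extui %c : vector<16xi8> to vector<16xi32>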
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    } else {
      // An invariant or reduction. In both cases, we treat this as an
      // invariant value, and rely on later replacing and folding to
      // construct a proper reduction chain for the latter case.
      if (codegen)
        vexp = genVectorInvariantValue(rewriter, vl, exp);
      return true;
    }
  }
  // Something defined outside the loop-body is invariant.
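  // For example (a sketch with made-up SSA names), a scalar %s defined above
  // the loop is simply broadcast,
  //   %vs = vector.broadcast %s : f64 to vector<16xf64>
  // and later hoisting moves the broadcast out of the vector loop.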
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or index values inside the computation that are now fetched from
  // the sparse storage index arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      // TODO: shift by invariant?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations, which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
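  //
  // For example (a sketch with vl = 16 and made-up SSA names), a scalar
  // reduction loop
  //   %s = scf.for %i = %lo to %hi step %c1 iter_args(%r = %init) -> (f32)
  // is rebuilt as
  //   %v = scf.for %i = %lo to %hi step %c16 iter_args(%vr = %vinit)
  //          -> (vector<16xf32>)
  // followed by a horizontal vector.reduction after the new loop.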
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(
              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (the last one is not completely type-safe,
        // but all bad uses are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
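///
/// For example (a sketch with vl = 16 and made-up SSA names), a sparse
/// parallel loop
///   scf.for %i = %lo to %hi step %c1 {
///     %t = memref.load %b[%i] : memref<?xf32>
///     memref.store %t, %a[%i] : memref<?xf32>
///   }
/// is rewritten into a masked vector loop with step %c16 that uses
/// vector.maskedload / vector.maskedstore (or gather/scatter for indirect
/// subscripts).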
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)          v = for { }
///   u = expand(s)    ->  for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(
                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
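
// Example usage from a greedy rewrite driver (a sketch; the surrounding pass
// setup is an assumption and not part of this file):
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));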