1 //===- SparseVectorization.cpp - Vectorization of sparsified loops --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A pass that converts loops generated by the sparse compiler into a form that 10 // can exploit SIMD instructions of the target architecture. Note that this pass 11 // ensures the sparse compiler can generate efficient SIMD (including ArmSVE 12 // support) with proper separation of concerns as far as sparsification and 13 // vectorization is concerned. However, this pass is not the final abstraction 14 // level we want, and not the general vectorizer we want either. It forms a good 15 // stepping stone for incremental future improvements though. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "CodegenUtils.h" 20 21 #include "mlir/Dialect/Affine/IR/AffineOps.h" 22 #include "mlir/Dialect/Arith/IR/Arith.h" 23 #include "mlir/Dialect/Complex/IR/Complex.h" 24 #include "mlir/Dialect/Math/IR/Math.h" 25 #include "mlir/Dialect/MemRef/IR/MemRef.h" 26 #include "mlir/Dialect/SCF/IR/SCF.h" 27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 28 #include "mlir/Dialect/Vector/IR/VectorOps.h" 29 #include "mlir/IR/Matchers.h" 30 31 using namespace mlir; 32 using namespace mlir::sparse_tensor; 33 34 namespace { 35 36 /// Target SIMD properties: 37 /// vectorLength: # packed data elements (viz. vector<16xf32> has length 16) 38 /// enableVLAVectorization: enables scalable vectors (viz. ARMSve) 39 /// enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency 40 struct VL { 41 unsigned vectorLength; 42 bool enableVLAVectorization; 43 bool enableSIMDIndex32; 44 }; 45 46 /// Helper to test for given index value. 47 static bool isIntValue(Value val, int64_t idx) { 48 if (auto ival = getConstantIntValue(val)) 49 return *ival == idx; 50 return false; 51 } 52 53 /// Constructs vector type for element type. 54 static VectorType vectorType(VL vl, Type etp) { 55 unsigned numScalableDims = vl.enableVLAVectorization; 56 return VectorType::get(vl.vectorLength, etp, numScalableDims); 57 } 58 59 /// Constructs vector type from pointer. 60 static VectorType vectorType(VL vl, Value ptr) { 61 return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType()); 62 } 63 64 /// Constructs vector iteration mask. 65 static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl, 66 Value iv, Value lo, Value hi, Value step) { 67 VectorType mtp = vectorType(vl, rewriter.getI1Type()); 68 // Special case if the vector length evenly divides the trip count (for 69 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 70 // so that all subsequent masked memory operations are immediately folded 71 // into unconditional memory operations. 72 IntegerAttr loInt, hiInt, stepInt; 73 if (matchPattern(lo, m_Constant(&loInt)) && 74 matchPattern(hi, m_Constant(&hiInt)) && 75 matchPattern(step, m_Constant(&stepInt))) { 76 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) { 77 Value trueVal = constantI1(rewriter, loc, true); 78 return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal); 79 } 80 } 81 // Otherwise, generate a vector mask that avoids overrunning the upperbound 82 // during vector execution. Here we rely on subsequent loop optimizations to 83 // avoid executing the mask in all iterations, for example, by splitting the 84 // loop into an unconditional vector loop and a scalar cleanup loop. 85 auto min = AffineMap::get( 86 /*dimCount=*/2, /*symbolCount=*/1, 87 {rewriter.getAffineSymbolExpr(0), 88 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, 89 rewriter.getContext()); 90 Value end = 91 rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step}); 92 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 93 } 94 95 /// Generates a vectorized invariant. Here we rely on subsequent loop 96 /// optimizations to hoist the invariant broadcast out of the vector loop. 97 static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl, 98 Value val) { 99 VectorType vtp = vectorType(vl, val.getType()); 100 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 101 } 102 103 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi], 104 /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note 105 /// that the sparse compiler can only generate indirect loads in 106 /// the last index, i.e. back(). 107 static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl, 108 Value ptr, ArrayRef<Value> idxs, Value vmask) { 109 VectorType vtp = vectorType(vl, ptr); 110 Value pass = constantZero(rewriter, loc, vtp); 111 if (idxs.back().getType().isa<VectorType>()) { 112 SmallVector<Value> scalarArgs(idxs.begin(), idxs.end()); 113 Value indexVec = idxs.back(); 114 scalarArgs.back() = constantIndex(rewriter, loc, 0); 115 return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs, 116 indexVec, vmask, pass); 117 } 118 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask, 119 pass); 120 } 121 122 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs 123 /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note 124 /// that the sparse compiler can only generate indirect stores in 125 /// the last index, i.e. back(). 126 static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr, 127 ArrayRef<Value> idxs, Value vmask, Value rhs) { 128 if (idxs.back().getType().isa<VectorType>()) { 129 SmallVector<Value> scalarArgs(idxs.begin(), idxs.end()); 130 Value indexVec = idxs.back(); 131 scalarArgs.back() = constantIndex(rewriter, loc, 0); 132 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask, 133 rhs); 134 return; 135 } 136 rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs); 137 } 138 139 /// Detects a vectorizable reduction operations and returns the 140 /// combining kind of reduction on success in `kind`. 141 static bool isVectorizableReduction(Value red, Value iter, 142 vector::CombiningKind &kind) { 143 if (auto addf = red.getDefiningOp<arith::AddFOp>()) { 144 kind = vector::CombiningKind::ADD; 145 return addf->getOperand(0) == iter || addf->getOperand(1) == iter; 146 } 147 if (auto addi = red.getDefiningOp<arith::AddIOp>()) { 148 kind = vector::CombiningKind::ADD; 149 return addi->getOperand(0) == iter || addi->getOperand(1) == iter; 150 } 151 if (auto subf = red.getDefiningOp<arith::SubFOp>()) { 152 kind = vector::CombiningKind::ADD; 153 return subf->getOperand(0) == iter; 154 } 155 if (auto subi = red.getDefiningOp<arith::SubIOp>()) { 156 kind = vector::CombiningKind::ADD; 157 return subi->getOperand(0) == iter; 158 } 159 if (auto mulf = red.getDefiningOp<arith::MulFOp>()) { 160 kind = vector::CombiningKind::MUL; 161 return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter; 162 } 163 if (auto muli = red.getDefiningOp<arith::MulIOp>()) { 164 kind = vector::CombiningKind::MUL; 165 return muli->getOperand(0) == iter || muli->getOperand(1) == iter; 166 } 167 if (auto andi = red.getDefiningOp<arith::AndIOp>()) { 168 kind = vector::CombiningKind::AND; 169 return andi->getOperand(0) == iter || andi->getOperand(1) == iter; 170 } 171 if (auto ori = red.getDefiningOp<arith::OrIOp>()) { 172 kind = vector::CombiningKind::OR; 173 return ori->getOperand(0) == iter || ori->getOperand(1) == iter; 174 } 175 if (auto xori = red.getDefiningOp<arith::XOrIOp>()) { 176 kind = vector::CombiningKind::XOR; 177 return xori->getOperand(0) == iter || xori->getOperand(1) == iter; 178 } 179 return false; 180 } 181 182 /// Generates an initial value for a vector reduction, following the scheme 183 /// given in Chapter 5 of "The Software Vectorization Handbook", where the 184 /// initial scalar value is correctly embedded in the vector reduction value, 185 /// and a straightforward horizontal reduction will complete the operation. 186 /// Value 'r' denotes the initial value of the reduction outside the loop. 187 static Value genVectorReducInit(PatternRewriter &rewriter, Location loc, 188 Value red, Value iter, Value r, 189 VectorType vtp) { 190 vector::CombiningKind kind; 191 if (!isVectorizableReduction(red, iter, kind)) 192 llvm_unreachable("unknown reduction"); 193 switch (kind) { 194 case vector::CombiningKind::ADD: 195 case vector::CombiningKind::XOR: 196 // Initialize reduction vector to: | 0 | .. | 0 | r | 197 return rewriter.create<vector::InsertElementOp>( 198 loc, r, constantZero(rewriter, loc, vtp), 199 constantIndex(rewriter, loc, 0)); 200 case vector::CombiningKind::MUL: 201 // Initialize reduction vector to: | 1 | .. | 1 | r | 202 return rewriter.create<vector::InsertElementOp>( 203 loc, r, constantOne(rewriter, loc, vtp), 204 constantIndex(rewriter, loc, 0)); 205 case vector::CombiningKind::AND: 206 case vector::CombiningKind::OR: 207 // Initialize reduction vector to: | r | .. | r | r | 208 return rewriter.create<vector::BroadcastOp>(loc, vtp, r); 209 default: 210 break; 211 } 212 llvm_unreachable("unknown reduction kind"); 213 } 214 215 /// This method is called twice to analyze and rewrite the given subscripts. 216 /// The first call (!codegen) does the analysis. Then, on success, the second 217 /// call (codegen) yields the proper vector form in the output parameter 218 /// vector 'idxs'. This mechanism ensures that analysis and rewriting code 219 /// stay in sync. Note that the analyis part is simple because the sparse 220 /// compiler only generates relatively simple subscript expressions. 221 /// 222 /// See https://llvm.org/docs/GetElementPtr.html for some background on 223 /// the complications described below. 224 /// 225 /// We need to generate a pointer/index load from the sparse storage scheme. 226 /// Narrower data types need to be zero extended before casting the value 227 /// into the index type used for looping and indexing. 228 /// 229 /// For the scalar case, subscripts simply zero extend narrower indices 230 /// into 64-bit values before casting to an index type without a performance 231 /// penalty. Indices that already are 64-bit, in theory, cannot express the 232 /// full range since the LLVM backend defines addressing in terms of an 233 /// unsigned pointer/signed index pair. 234 static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, 235 VL vl, ValueRange subs, bool codegen, 236 Value vmask, SmallVectorImpl<Value> &idxs) { 237 for (auto sub : subs) { 238 // Invariant/loop indices simply pass through. 239 if (sub.dyn_cast<BlockArgument>() || 240 sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) { 241 if (codegen) 242 idxs.push_back(sub); 243 continue; // success so far 244 } 245 // Look under the hood of casting. 246 auto cast = sub; 247 while (true) { 248 if (auto icast = cast.getDefiningOp<arith::IndexCastOp>()) 249 cast = icast->getOperand(0); 250 else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>()) 251 cast = ecast->getOperand(0); 252 else 253 break; 254 } 255 // Since the index vector is used in a subsequent gather/scatter 256 // operations, which effectively defines an unsigned pointer + signed 257 // index, we must zero extend the vector to an index width. For 8-bit 258 // and 16-bit values, an 32-bit index width suffices. For 32-bit values, 259 // zero extending the elements into 64-bit loses some performance since 260 // the 32-bit indexed gather/scatter is more efficient than the 64-bit 261 // index variant (if the negative 32-bit index space is unused, the 262 // enableSIMDIndex32 flag can preserve this performance). For 64-bit 263 // values, there is no good way to state that the indices are unsigned, 264 // which creates the potential of incorrect address calculations in the 265 // unlikely case we need such extremely large offsets. 266 if (auto load = cast.getDefiningOp<memref::LoadOp>()) { 267 if (codegen) { 268 SmallVector<Value> idxs2(load.getIndices()); // no need to analyze 269 Location loc = forOp.getLoc(); 270 Value vload = 271 genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask); 272 Type etp = vload.getType().cast<VectorType>().getElementType(); 273 if (!etp.isa<IndexType>()) { 274 if (etp.getIntOrFloatBitWidth() < 32) 275 vload = rewriter.create<arith::ExtUIOp>( 276 loc, vectorType(vl, rewriter.getI32Type()), vload); 277 else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32) 278 vload = rewriter.create<arith::ExtUIOp>( 279 loc, vectorType(vl, rewriter.getI64Type()), vload); 280 } 281 idxs.push_back(vload); 282 } 283 continue; // success so far 284 } 285 // Address calculation 'i = add inv, idx' (after LICM). 286 if (auto load = cast.getDefiningOp<arith::AddIOp>()) { 287 Value inv = load.getOperand(0); 288 Value idx = load.getOperand(1); 289 if (!inv.dyn_cast<BlockArgument>() && 290 inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() && 291 idx.dyn_cast<BlockArgument>()) { 292 if (codegen) 293 idxs.push_back( 294 rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx)); 295 continue; // success so far 296 } 297 } 298 return false; 299 } 300 return true; 301 } 302 303 #define UNAOP(xxx) \ 304 if (isa<xxx>(def)) { \ 305 if (codegen) \ 306 vexp = rewriter.create<xxx>(loc, vx); \ 307 return true; \ 308 } 309 310 #define TYPEDUNAOP(xxx) \ 311 if (auto x = dyn_cast<xxx>(def)) { \ 312 if (codegen) { \ 313 VectorType vtp = vectorType(vl, x.getType()); \ 314 vexp = rewriter.create<xxx>(loc, vtp, vx); \ 315 } \ 316 return true; \ 317 } 318 319 #define BINOP(xxx) \ 320 if (isa<xxx>(def)) { \ 321 if (codegen) \ 322 vexp = rewriter.create<xxx>(loc, vx, vy); \ 323 return true; \ 324 } 325 326 /// This method is called twice to analyze and rewrite the given expression. 327 /// The first call (!codegen) does the analysis. Then, on success, the second 328 /// call (codegen) yields the proper vector form in the output parameter 'vexp'. 329 /// This mechanism ensures that analysis and rewriting code stay in sync. Note 330 /// that the analyis part is simple because the sparse compiler only generates 331 /// relatively simple expressions inside the for-loops. 332 static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, 333 Value exp, bool codegen, Value vmask, Value &vexp) { 334 Location loc = forOp.getLoc(); 335 // Reject unsupported types. 336 if (!VectorType::isValidElementType(exp.getType())) 337 return false; 338 // A block argument is invariant/reduction/index. 339 if (auto arg = exp.dyn_cast<BlockArgument>()) { 340 if (arg == forOp.getInductionVar()) { 341 // We encountered a single, innermost index inside the computation, 342 // such as a[i] = i, which must convert to [i, i+1, ...]. 343 if (codegen) { 344 VectorType vtp = vectorType(vl, arg.getType()); 345 Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg); 346 Value incr; 347 if (vl.enableVLAVectorization) { 348 Type stepvty = vectorType(vl, rewriter.getI64Type()); 349 Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty); 350 incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv); 351 } else { 352 SmallVector<APInt> integers; 353 for (unsigned i = 0, l = vl.vectorLength; i < l; i++) 354 integers.push_back(APInt(/*width=*/64, i)); 355 auto values = DenseElementsAttr::get(vtp, integers); 356 incr = rewriter.create<arith::ConstantOp>(loc, vtp, values); 357 } 358 vexp = rewriter.create<arith::AddIOp>(loc, veci, incr); 359 } 360 return true; 361 } else { 362 // An invariant or reduction. In both cases, we treat this as an 363 // invariant value, and rely on later replacing and folding to 364 // construct a proper reduction chain for the latter case. 365 if (codegen) 366 vexp = genVectorInvariantValue(rewriter, vl, exp); 367 return true; 368 } 369 } 370 // Something defined outside the loop-body is invariant. 371 Operation *def = exp.getDefiningOp(); 372 if (def->getBlock() != &forOp.getRegion().front()) { 373 if (codegen) 374 vexp = genVectorInvariantValue(rewriter, vl, exp); 375 return true; 376 } 377 // Proper load operations. These are either values involved in the 378 // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi], 379 // or index values inside the computation that are now fetched from 380 // the sparse storage index arrays, such as a[i] = i becomes 381 // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index 382 // and 'hi = lo + vl - 1'. 383 if (auto load = dyn_cast<memref::LoadOp>(def)) { 384 auto subs = load.getIndices(); 385 SmallVector<Value> idxs; 386 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) { 387 if (codegen) 388 vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask); 389 return true; 390 } 391 return false; 392 } 393 // Inside loop-body unary and binary operations. Note that it would be 394 // nicer if we could somehow test and build the operations in a more 395 // concise manner than just listing them all (although this way we know 396 // for certain that they can vectorize). 397 // 398 // TODO: avoid visiting CSEs multiple times 399 // 400 if (def->getNumOperands() == 1) { 401 Value vx; 402 if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask, 403 vx)) { 404 UNAOP(math::AbsFOp) 405 UNAOP(math::AbsIOp) 406 UNAOP(math::CeilOp) 407 UNAOP(math::FloorOp) 408 UNAOP(math::SqrtOp) 409 UNAOP(math::ExpM1Op) 410 UNAOP(math::Log1pOp) 411 UNAOP(math::SinOp) 412 UNAOP(math::TanhOp) 413 UNAOP(arith::NegFOp) 414 TYPEDUNAOP(arith::TruncFOp) 415 TYPEDUNAOP(arith::ExtFOp) 416 TYPEDUNAOP(arith::FPToSIOp) 417 TYPEDUNAOP(arith::FPToUIOp) 418 TYPEDUNAOP(arith::SIToFPOp) 419 TYPEDUNAOP(arith::UIToFPOp) 420 TYPEDUNAOP(arith::ExtSIOp) 421 TYPEDUNAOP(arith::ExtUIOp) 422 TYPEDUNAOP(arith::IndexCastOp) 423 TYPEDUNAOP(arith::TruncIOp) 424 TYPEDUNAOP(arith::BitcastOp) 425 // TODO: complex? 426 } 427 } else if (def->getNumOperands() == 2) { 428 Value vx, vy; 429 if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask, 430 vx) && 431 vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask, 432 vy)) { 433 BINOP(arith::MulFOp) 434 BINOP(arith::MulIOp) 435 BINOP(arith::DivFOp) 436 BINOP(arith::DivSIOp) 437 BINOP(arith::DivUIOp) 438 BINOP(arith::AddFOp) 439 BINOP(arith::AddIOp) 440 BINOP(arith::SubFOp) 441 BINOP(arith::SubIOp) 442 BINOP(arith::AndIOp) 443 BINOP(arith::OrIOp) 444 BINOP(arith::XOrIOp) 445 // TODO: complex? 446 // TODO: shift by invariant? 447 } 448 } 449 return false; 450 } 451 452 #undef UNAOP 453 #undef TYPEDUNAOP 454 #undef BINOP 455 456 /// This method is called twice to analyze and rewrite the given for-loop. 457 /// The first call (!codegen) does the analysis. Then, on success, the second 458 /// call (codegen) rewriters the IR into vector form. This mechanism ensures 459 /// that analysis and rewriting code stay in sync. 460 static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, 461 bool codegen) { 462 Location loc = forOp.getLoc(); 463 Block &block = forOp.getRegion().front(); 464 scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator()); 465 auto &last = *++block.rbegin(); 466 scf::ForOp forOpNew; 467 468 // Perform initial set up during codegen (we know that the first analysis 469 // pass was successful). For reductions, we need to construct a completely 470 // new for-loop, since the incoming and outgoing reduction type 471 // changes into SIMD form. For stores, we can simply adjust the stride 472 // and insert in the existing for-loop. In both cases, we set up a vector 473 // mask for all operations which takes care of confining vectors to 474 // the original iteration space (later cleanup loops or other 475 // optimizations can take care of those). 476 Value vmask; 477 if (codegen) { 478 Value step = constantIndex(rewriter, loc, vl.vectorLength); 479 if (vl.enableVLAVectorization) { 480 Value vscale = 481 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); 482 step = rewriter.create<arith::MulIOp>(loc, vscale, step); 483 } 484 if (!yield.getResults().empty()) { 485 Value init = forOp.getInitArgs()[0]; 486 VectorType vtp = vectorType(vl, init.getType()); 487 Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0), 488 forOp.getRegionIterArg(0), init, vtp); 489 forOpNew = rewriter.create<scf::ForOp>( 490 loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit); 491 forOpNew->setAttr( 492 SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(), 493 forOp->getAttr( 494 SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())); 495 rewriter.setInsertionPointToStart(forOpNew.getBody()); 496 } else { 497 forOp.setStep(step); 498 rewriter.setInsertionPoint(yield); 499 } 500 vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(), 501 forOp.getLowerBound(), forOp.getUpperBound(), step); 502 } 503 504 // Sparse for-loops either are terminated by a non-empty yield operation 505 // (reduction loop) or otherwise by a store operation (pararallel loop). 506 if (!yield.getResults().empty()) { 507 // Analyze/vectorize reduction. 508 if (yield->getNumOperands() != 1) 509 return false; 510 Value red = yield->getOperand(0); 511 Value iter = forOp.getRegionIterArg(0); 512 vector::CombiningKind kind; 513 Value vrhs; 514 if (isVectorizableReduction(red, iter, kind) && 515 vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) { 516 if (codegen) { 517 Value partial = forOpNew.getResult(0); 518 Value vpass = genVectorInvariantValue(rewriter, vl, iter); 519 Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass); 520 rewriter.create<scf::YieldOp>(loc, vred); 521 rewriter.setInsertionPointAfter(forOpNew); 522 Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial); 523 // Now do some relinking (last one is not completely type safe 524 // but all bad ones are removed right away). This also folds away 525 // nop broadcast operations. 526 forOp.getResult(0).replaceAllUsesWith(vres); 527 forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar()); 528 forOp.getRegionIterArg(0).replaceAllUsesWith( 529 forOpNew.getRegionIterArg(0)); 530 rewriter.eraseOp(forOp); 531 } 532 return true; 533 } 534 } else if (auto store = dyn_cast<memref::StoreOp>(last)) { 535 // Analyze/vectorize store operation. 536 auto subs = store.getIndices(); 537 SmallVector<Value> idxs; 538 Value rhs = store.getValue(); 539 Value vrhs; 540 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) && 541 vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) { 542 if (codegen) { 543 genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs); 544 rewriter.eraseOp(store); 545 } 546 return true; 547 } 548 } 549 550 assert(!codegen && "cannot call codegen when analysis failed"); 551 return false; 552 } 553 554 /// Basic for-loop vectorizer. 555 struct ForOpRewriter : public OpRewritePattern<scf::ForOp> { 556 public: 557 using OpRewritePattern<scf::ForOp>::OpRewritePattern; 558 559 ForOpRewriter(MLIRContext *context, unsigned vectorLength, 560 bool enableVLAVectorization, bool enableSIMDIndex32) 561 : OpRewritePattern(context), 562 vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {} 563 564 LogicalResult matchAndRewrite(scf::ForOp op, 565 PatternRewriter &rewriter) const override { 566 // Check for single block, unit-stride for-loop that is generated by 567 // sparse compiler, which means no data dependence analysis is required, 568 // and its loop-body is very restricted in form. 569 if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) || 570 !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) 571 return failure(); 572 // Analyze (!codegen) and rewrite (codegen) loop-body. 573 if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) && 574 vectorizeStmt(rewriter, op, vl, /*codegen=*/true)) 575 return success(); 576 return failure(); 577 } 578 579 private: 580 const VL vl; 581 }; 582 583 /// Reduction chain cleanup. 584 /// v = for { } 585 /// s = vsum(v) v = for { } 586 /// u = expand(s) -> for (v) { } 587 /// for (u) { } 588 template <typename VectorOp> 589 struct ReducChainRewriter : public OpRewritePattern<VectorOp> { 590 public: 591 using OpRewritePattern<VectorOp>::OpRewritePattern; 592 593 LogicalResult matchAndRewrite(VectorOp op, 594 PatternRewriter &rewriter) const override { 595 Value inp = op.getSource(); 596 if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) { 597 if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) { 598 if (forOp->hasAttr( 599 SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) { 600 rewriter.replaceOp(op, redOp.getVector()); 601 return success(); 602 } 603 } 604 } 605 return failure(); 606 } 607 }; 608 609 } // namespace 610 611 //===----------------------------------------------------------------------===// 612 // Public method for populating vectorization rules. 613 //===----------------------------------------------------------------------===// 614 615 /// Populates the given patterns list with vectorization rules. 616 void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns, 617 unsigned vectorLength, 618 bool enableVLAVectorization, 619 bool enableSIMDIndex32) { 620 assert(vectorLength > 0); 621 patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength, 622 enableVLAVectorization, enableSIMDIndex32); 623 patterns.add<ReducChainRewriter<vector::InsertElementOp>, 624 ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext()); 625 } 626