//===- SparseVectorization.cpp - Vectorization of sparsified loops -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is neither the final
// abstraction level we want, nor the general vectorizer we want. It forms a
// good stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for a given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
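  // Note that with scalable vectorization the step is scaled by a runtime
  // vscale and therefore is not a compile-time constant, so this fast path
  // only applies to fixed-length vectorization.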
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and, on success, returns
/// its combining kind in `kind`.
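/// For example, x = x + a[i] (with `iter` = x) is detected as an ADD
/// reduction, whereas x = a[i] - x is rejected, since the iteration value
/// may only appear on the left-hand side of a subtraction.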
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
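///
/// For example, a subscript that is the loop index itself, as in a[i],
/// passes through as a scalar index, whereas an indirection through a
/// sparse index array in the innermost dimension, as in a[ind[i]], yields
/// a vector of indices and hence a subsequent gather/scatter.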
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = sub.dyn_cast<BlockArgument>()) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
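    // Concretely, i8/i16 index loads below are zero extended to i32, and
    // i32 loads to i64, unless the enableSIMDIndex32 flag keeps them at
    // 32 bits.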
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      if (isInvariantValue(inv, block)) {
        if (auto arg = idx.dyn_cast<BlockArgument>()) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
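      // For a fixed vector length of, say, 4 this becomes broadcast(i) +
      // [0, 1, 2, 3]; for scalable vectors the constant offsets are replaced
      // by a runtime step vector (LLVM::StepVectorOp).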
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or index values inside the computation that are now fetched from
  // the sparse storage index arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand side, however,
      // so that we do not have to special case the code generation.
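      // For example, a shift by a loop-invariant amount, as in b[i] << s with
      // s defined outside the loop, is accepted, whereas b[i] << c[i], where
      // the shift amount varies per lane, is rejected.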
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
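    // Only a single reduction value is supported. The reduction is kept in
    // SIMD form inside the loop and folded back into a scalar with a
    // horizontal vector.reduction after the loop.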
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)      v = for { }
///   u = expand(s) -> for (v) { }
///   for (u) { }
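/// That is, when a vectorized reduction result is summed into a scalar and
/// then immediately re-expanded (broadcast or insertelement) to feed another
/// vectorized loop, the intermediate scalar round-trip is removed and the
/// vector value is used directly.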
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}