//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, nor the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
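  // As an illustrative sketch, assuming a fixed vector length of 16, this
  // special case emits roughly:
  //   %true = arith.constant true
  //   %mask = vector.broadcast %true : i1 to vector<16xi1>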
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
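/// For example, in a reduction loop whose body ends in (sketch):
///   %sum = arith.addf %iter, %val : f64
///   scf.yield %sum : f64
/// the yielded value %sum is recognized as an ADD reduction on the
/// iteration argument %iter.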
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
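/// As an illustrative sketch, for an access a[i][ind[j]] with 'j' the
/// innermost loop index, codegen yields 'idxs' = { %i, %vind }, where %vind
/// is a vectorized load of the coordinates that is later consumed by a
/// gather/scatter operation.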
/// Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
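    // The following sketch illustrates the widening for an 8-bit coordinate
    // array with a fixed vector length of 16 (illustrative only):
    //   %c  = vector.maskedload %ind[%i], %mask, %pass
    //           : memref<?xi8>, vector<16xi1>, vector<16xi8> into vector<16xi8>
    //   %ci = arith.extui %c : vector<16xi8> to vector<16xi32>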
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
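      // A minimal sketch of the fixed-length (non-scalable) case below,
      // assuming a vector length of 4:
      //   %vi   = vector.broadcast %i : index to vector<4xindex>
      //   %incr = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
      //   %vexp = arith.addi %vi, %incr : vector<4xindex>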
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements).
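      // For example, x << 2 (or x << inv for a loop-invariant 'inv') is
      // accepted, whereas an elementwise x << y[i] is rejected.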
      // In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
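    // As an illustrative sketch, a scalar f32 sum reduction with a fixed
    // vector length of 16 ends up roughly as:
    //   %vinit = vector.insertelement %init, %zeroes[%c0 : index] : vector<16xf32>
    //   %v = scf.for %i = %lo to %hi step %c16 iter_args(%vred = %vinit) ... {
    //     ...
    //     %sel = arith.select %mask, %vadd, %vred : vector<16xi1>, vector<16xf32>
    //     scf.yield %sel : vector<16xf32>
    //   }
    //   %res = vector.reduction <add>, %v : vector<16xf32> into f32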
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)          v = for { }
///   u = expand(s)   ->   for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
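
// A minimal usage sketch (illustrative only; the actual driver is the sparse
// vectorization pass, and the concrete option values below are assumptions):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));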