//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, nor the general vectorizer we want. It forms a good
// stepping stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
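/// For example (illustrative): an f32 element type with vectorLength 16
/// yields vector<16xf32>, or the scalable vector<[16]xf32> when VLA
/// vectorization is enabled.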
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
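/// For example (illustrative, vl = 16): a unit-stride load lhs = a[lo:hi]
/// becomes a vector.maskedload, whereas an indirect load lhs = a[ind[lo:hi]]
/// becomes a vector.gather; in both cases masked-off lanes read a zero
/// pass-through value.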
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
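/// For example (illustrative): in a loop body that yields x + a[i] with
/// iteration argument x, the arith::AddFOp/arith::AddIOp case matches and
/// ADD is returned in `kind`. Subtraction only vectorizes as x - a[i]
/// (iteration argument as the first operand), since a[i] - x does not
/// form a simple reduction chain.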
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduction is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
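  //
  // For example (illustrative, vl = 16): a scalar reduction loop
  //   for (i = lo; i < hi; i += 1)  x = x + a[i]
  // is rewritten into, roughly,
  //   for (i = lo; i < hi; i += 16)
  //     vx = select(mask(i), vx + a[i:i+16], vx)
  //   x = horizontal_add(vx)
  // where the mask confines the final iteration to the original bounds.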
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
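        // For example (illustrative): outside uses of the old scalar loop
        // result are redirected to the horizontal vector.reduction result,
        // while uses of the old induction variable and iteration argument
        // are redirected into the new SIMD loop.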
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that was generated by
    // the sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)        v = for { }
///   u = expand(s)  ->  for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  vector::populateVectorStepLoweringPatterns(patterns);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
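
// For example (illustrative, not part of this file's API), a client pass
// could apply these rules with the greedy pattern rewrite driver:
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));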