xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (revision a6e72f93923378bffe13367f6dedd526ad39b184)
199b3849dSAart Bik //===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
299b3849dSAart Bik //
399b3849dSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499b3849dSAart Bik // See https://llvm.org/LICENSE.txt for license information.
599b3849dSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699b3849dSAart Bik //
799b3849dSAart Bik //===----------------------------------------------------------------------===//
899b3849dSAart Bik //
9c43e6274STim Harvey // A pass that converts loops generated by the sparsifier into a form that
1099b3849dSAart Bik // can exploit SIMD instructions of the target architecture. Note that this pass
11c43e6274STim Harvey // ensures the sparsifier can generate efficient SIMD (including ArmSVE
1299b3849dSAart Bik // support) with proper separation of concerns as far as sparsification and
1399b3849dSAart Bik // vectorization is concerned. However, this pass is not the final abstraction
1499b3849dSAart Bik // level we want, and not the general vectorizer we want either. It forms a good
1599b3849dSAart Bik // stepping stone for incremental future improvements though.
1699b3849dSAart Bik //
1799b3849dSAart Bik //===----------------------------------------------------------------------===//
1899b3849dSAart Bik 
19365777ecSAart Bik #include "Utils/CodegenUtils.h"
20365777ecSAart Bik #include "Utils/LoopEmitter.h"
2199b3849dSAart Bik 
2299b3849dSAart Bik #include "mlir/Dialect/Affine/IR/AffineOps.h"
2399b3849dSAart Bik #include "mlir/Dialect/Arith/IR/Arith.h"
2499b3849dSAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
2599b3849dSAart Bik #include "mlir/Dialect/Math/IR/Math.h"
2699b3849dSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h"
2799b3849dSAart Bik #include "mlir/Dialect/SCF/IR/SCF.h"
2899b3849dSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2999b3849dSAart Bik #include "mlir/Dialect/Vector/IR/VectorOps.h"
30*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
3199b3849dSAart Bik #include "mlir/IR/Matchers.h"
3299b3849dSAart Bik 
3399b3849dSAart Bik using namespace mlir;
3499b3849dSAart Bik using namespace mlir::sparse_tensor;
3599b3849dSAart Bik 
3699b3849dSAart Bik namespace {
3799b3849dSAart Bik 
/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ARMSve)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;       // number of lanes in generated vector ops
  bool enableVLAVectorization; // emit vector-length-agnostic (scalable) types
  bool enableSIMDIndex32;      // allow i32 (instead of i64) gather/scatter indices
};
4799b3849dSAart Bik 
48431f6a54SAart Bik /// Helper test for invariant value (defined outside given block).
49431f6a54SAart Bik static bool isInvariantValue(Value val, Block *block) {
50431f6a54SAart Bik   return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
51431f6a54SAart Bik }
52431f6a54SAart Bik 
53431f6a54SAart Bik /// Helper test for invariant argument (defined outside given block).
54431f6a54SAart Bik static bool isInvariantArg(BlockArgument arg, Block *block) {
55431f6a54SAart Bik   return arg.getOwner() != block;
56431f6a54SAart Bik }
57431f6a54SAart Bik 
5899b3849dSAart Bik /// Constructs vector type for element type.
5999b3849dSAart Bik static VectorType vectorType(VL vl, Type etp) {
60f22af204SAndrzej Warzynski   return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
6199b3849dSAart Bik }
6299b3849dSAart Bik 
6384cd51bbSwren romano /// Constructs vector type from a memref value.
6484cd51bbSwren romano static VectorType vectorType(VL vl, Value mem) {
6584cd51bbSwren romano   return vectorType(vl, getMemRefType(mem).getElementType());
6699b3849dSAart Bik }
6799b3849dSAart Bik 
6899b3849dSAart Bik /// Constructs vector iteration mask.
6999b3849dSAart Bik static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
7099b3849dSAart Bik                            Value iv, Value lo, Value hi, Value step) {
7199b3849dSAart Bik   VectorType mtp = vectorType(vl, rewriter.getI1Type());
7299b3849dSAart Bik   // Special case if the vector length evenly divides the trip count (for
7399b3849dSAart Bik   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
7499b3849dSAart Bik   // so that all subsequent masked memory operations are immediately folded
7599b3849dSAart Bik   // into unconditional memory operations.
7699b3849dSAart Bik   IntegerAttr loInt, hiInt, stepInt;
7799b3849dSAart Bik   if (matchPattern(lo, m_Constant(&loInt)) &&
7899b3849dSAart Bik       matchPattern(hi, m_Constant(&hiInt)) &&
7999b3849dSAart Bik       matchPattern(step, m_Constant(&stepInt))) {
8099b3849dSAart Bik     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
8199b3849dSAart Bik       Value trueVal = constantI1(rewriter, loc, true);
8299b3849dSAart Bik       return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
8399b3849dSAart Bik     }
8499b3849dSAart Bik   }
8599b3849dSAart Bik   // Otherwise, generate a vector mask that avoids overrunning the upperbound
8699b3849dSAart Bik   // during vector execution. Here we rely on subsequent loop optimizations to
8799b3849dSAart Bik   // avoid executing the mask in all iterations, for example, by splitting the
8899b3849dSAart Bik   // loop into an unconditional vector loop and a scalar cleanup loop.
8999b3849dSAart Bik   auto min = AffineMap::get(
9099b3849dSAart Bik       /*dimCount=*/2, /*symbolCount=*/1,
9199b3849dSAart Bik       {rewriter.getAffineSymbolExpr(0),
9299b3849dSAart Bik        rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
9399b3849dSAart Bik       rewriter.getContext());
944c48f016SMatthias Springer   Value end = rewriter.createOrFold<affine::AffineMinOp>(
954c48f016SMatthias Springer       loc, min, ValueRange{hi, iv, step});
9699b3849dSAart Bik   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
9799b3849dSAart Bik }
9899b3849dSAart Bik 
9999b3849dSAart Bik /// Generates a vectorized invariant. Here we rely on subsequent loop
10099b3849dSAart Bik /// optimizations to hoist the invariant broadcast out of the vector loop.
10199b3849dSAart Bik static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
10299b3849dSAart Bik                                      Value val) {
10399b3849dSAart Bik   VectorType vtp = vectorType(vl, val.getType());
10499b3849dSAart Bik   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
10599b3849dSAart Bik }
10699b3849dSAart Bik 
10799b3849dSAart Bik /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
108cb82d375SAart Bik /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
109c43e6274STim Harvey /// that the sparsifier can only generate indirect loads in
110cb82d375SAart Bik /// the last index, i.e. back().
11199b3849dSAart Bik static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
11284cd51bbSwren romano                            Value mem, ArrayRef<Value> idxs, Value vmask) {
11384cd51bbSwren romano   VectorType vtp = vectorType(vl, mem);
11499b3849dSAart Bik   Value pass = constantZero(rewriter, loc, vtp);
115c1fa60b4STres Popp   if (llvm::isa<VectorType>(idxs.back().getType())) {
1165262865aSKazu Hirata     SmallVector<Value> scalarArgs(idxs);
11799b3849dSAart Bik     Value indexVec = idxs.back();
11899b3849dSAart Bik     scalarArgs.back() = constantIndex(rewriter, loc, 0);
11984cd51bbSwren romano     return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
12099b3849dSAart Bik                                              indexVec, vmask, pass);
12199b3849dSAart Bik   }
12284cd51bbSwren romano   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
12399b3849dSAart Bik                                                pass);
12499b3849dSAart Bik }
12599b3849dSAart Bik 
12699b3849dSAart Bik /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
127cb82d375SAart Bik /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
128c43e6274STim Harvey /// that the sparsifier can only generate indirect stores in
129cb82d375SAart Bik /// the last index, i.e. back().
13084cd51bbSwren romano static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
13199b3849dSAart Bik                            ArrayRef<Value> idxs, Value vmask, Value rhs) {
132c1fa60b4STres Popp   if (llvm::isa<VectorType>(idxs.back().getType())) {
1335262865aSKazu Hirata     SmallVector<Value> scalarArgs(idxs);
13499b3849dSAart Bik     Value indexVec = idxs.back();
13599b3849dSAart Bik     scalarArgs.back() = constantIndex(rewriter, loc, 0);
13684cd51bbSwren romano     rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
13799b3849dSAart Bik                                        rhs);
13899b3849dSAart Bik     return;
13999b3849dSAart Bik   }
14084cd51bbSwren romano   rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
14199b3849dSAart Bik }
14299b3849dSAart Bik 
143cb82d375SAart Bik /// Detects a vectorizable reduction operations and returns the
144cb82d375SAart Bik /// combining kind of reduction on success in `kind`.
145cb82d375SAart Bik static bool isVectorizableReduction(Value red, Value iter,
146cb82d375SAart Bik                                     vector::CombiningKind &kind) {
147cb82d375SAart Bik   if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
148cb82d375SAart Bik     kind = vector::CombiningKind::ADD;
149cb82d375SAart Bik     return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
150cb82d375SAart Bik   }
151cb82d375SAart Bik   if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
152cb82d375SAart Bik     kind = vector::CombiningKind::ADD;
153cb82d375SAart Bik     return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
154cb82d375SAart Bik   }
155cb82d375SAart Bik   if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
156cb82d375SAart Bik     kind = vector::CombiningKind::ADD;
157cb82d375SAart Bik     return subf->getOperand(0) == iter;
158cb82d375SAart Bik   }
159cb82d375SAart Bik   if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
160cb82d375SAart Bik     kind = vector::CombiningKind::ADD;
161cb82d375SAart Bik     return subi->getOperand(0) == iter;
162cb82d375SAart Bik   }
163cb82d375SAart Bik   if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
164cb82d375SAart Bik     kind = vector::CombiningKind::MUL;
165cb82d375SAart Bik     return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
166cb82d375SAart Bik   }
167cb82d375SAart Bik   if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
168cb82d375SAart Bik     kind = vector::CombiningKind::MUL;
169cb82d375SAart Bik     return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
170cb82d375SAart Bik   }
171cb82d375SAart Bik   if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
172cb82d375SAart Bik     kind = vector::CombiningKind::AND;
173cb82d375SAart Bik     return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
174cb82d375SAart Bik   }
175cb82d375SAart Bik   if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
176cb82d375SAart Bik     kind = vector::CombiningKind::OR;
177cb82d375SAart Bik     return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
178cb82d375SAart Bik   }
179cb82d375SAart Bik   if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
180cb82d375SAart Bik     kind = vector::CombiningKind::XOR;
181cb82d375SAart Bik     return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
182cb82d375SAart Bik   }
183cb82d375SAart Bik   return false;
18499b3849dSAart Bik }
18599b3849dSAart Bik 
18699b3849dSAart Bik /// Generates an initial value for a vector reduction, following the scheme
18799b3849dSAart Bik /// given in Chapter 5 of "The Software Vectorization Handbook", where the
18899b3849dSAart Bik /// initial scalar value is correctly embedded in the vector reduction value,
18999b3849dSAart Bik /// and a straightforward horizontal reduction will complete the operation.
190cb82d375SAart Bik /// Value 'r' denotes the initial value of the reduction outside the loop.
19199b3849dSAart Bik static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
192cb82d375SAart Bik                                 Value red, Value iter, Value r,
193cb82d375SAart Bik                                 VectorType vtp) {
194cb82d375SAart Bik   vector::CombiningKind kind;
195cb82d375SAart Bik   if (!isVectorizableReduction(red, iter, kind))
196cb82d375SAart Bik     llvm_unreachable("unknown reduction");
19799b3849dSAart Bik   switch (kind) {
19899b3849dSAart Bik   case vector::CombiningKind::ADD:
19999b3849dSAart Bik   case vector::CombiningKind::XOR:
20099b3849dSAart Bik     // Initialize reduction vector to: | 0 | .. | 0 | r |
20199b3849dSAart Bik     return rewriter.create<vector::InsertElementOp>(
20299b3849dSAart Bik         loc, r, constantZero(rewriter, loc, vtp),
20399b3849dSAart Bik         constantIndex(rewriter, loc, 0));
20499b3849dSAart Bik   case vector::CombiningKind::MUL:
20599b3849dSAart Bik     // Initialize reduction vector to: | 1 | .. | 1 | r |
20699b3849dSAart Bik     return rewriter.create<vector::InsertElementOp>(
20799b3849dSAart Bik         loc, r, constantOne(rewriter, loc, vtp),
20899b3849dSAart Bik         constantIndex(rewriter, loc, 0));
20999b3849dSAart Bik   case vector::CombiningKind::AND:
21099b3849dSAart Bik   case vector::CombiningKind::OR:
21199b3849dSAart Bik     // Initialize reduction vector to: | r | .. | r | r |
21299b3849dSAart Bik     return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
21399b3849dSAart Bik   default:
21499b3849dSAart Bik     break;
21599b3849dSAart Bik   }
21699b3849dSAart Bik   llvm_unreachable("unknown reduction kind");
21799b3849dSAart Bik }
21899b3849dSAart Bik 
/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme.  Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      // Reject an invariant argument in the innermost position, as well as
      // a loop-local argument in an outer position (equality test covers
      // both mismatches at once).
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operations, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, an 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //    a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //    a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}
34199b3849dSAart Bik 
// Case helper for a unary operation: if `def` is an op of kind `xxx`,
// re-create it over the vector operand `vx` (codegen only) and succeed.
#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

// Variant of UNAOP for type-changing unary operations (casts): the vector
// result type is derived from the scalar result type of the original op.
#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

// Case helper for a binary operation: if `def` is an op of kind `xxx`,
// re-create it over the vector operands `vx` and `vy` (codegen only).
#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
36499b3849dSAart Bik 
/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    // Recursively vectorize the single operand, then match the op kind.
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    // Recursively vectorize both operands, then match the op kind.
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}
49299b3849dSAart Bik 
49399b3849dSAart Bik #undef UNAOP
4942fda6207SAart Bik #undef TYPEDUNAOP
49599b3849dSAart Bik #undef BINOP
49699b3849dSAart Bik 
/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync. Returns true if the loop
/// was (or, during analysis, can be) vectorized.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For loops with single yield statement (as below) could be generated
  // when custom reduce is used with unary operation.
  // for (...)
  //   yield c_0
  // Reject such loops; there is nothing to vectorize.
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  // The operation immediately preceding the terminator; for parallel loops
  // this is expected to be the store (see below).
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      // Vector-length agnostic mode: scale the fixed step by vector.vscale.
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      // Reduction loop: build a fresh for-loop whose iter-arg carries the
      // reduction in vectorized (SIMD) form.
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      // Propagate the sparsifier loop marker onto the replacement loop.
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      // Parallel loop: reuse the existing loop, simply widening its stride.
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        // In masked-off lanes, pass the incoming iter-arg value through
        // unchanged so those lanes do not pollute the reduction.
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        // Horizontal reduction of the SIMD partial into the scalar result.
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        // Replace the scalar store with its (masked) vector counterpart.
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  // Analysis determined the loop is not vectorizable; the codegen pass is
  // only ever entered after a successful analysis pass.
  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}
60199b3849dSAart Bik 
60299b3849dSAart Bik /// Basic for-loop vectorizer.
60399b3849dSAart Bik struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
60499b3849dSAart Bik public:
60599b3849dSAart Bik   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
60699b3849dSAart Bik 
60799b3849dSAart Bik   ForOpRewriter(MLIRContext *context, unsigned vectorLength,
60899b3849dSAart Bik                 bool enableVLAVectorization, bool enableSIMDIndex32)
609781eabebSPeiming Liu       : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
610781eabebSPeiming Liu                                       enableSIMDIndex32} {}
61199b3849dSAart Bik 
61299b3849dSAart Bik   LogicalResult matchAndRewrite(scf::ForOp op,
61399b3849dSAart Bik                                 PatternRewriter &rewriter) const override {
61499b3849dSAart Bik     // Check for single block, unit-stride for-loop that is generated by
615c43e6274STim Harvey     // sparsifier, which means no data dependence analysis is required,
61699b3849dSAart Bik     // and its loop-body is very restricted in form.
61784cd51bbSwren romano     if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
618781eabebSPeiming Liu         !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
61999b3849dSAart Bik       return failure();
62099b3849dSAart Bik     // Analyze (!codegen) and rewrite (codegen) loop-body.
62199b3849dSAart Bik     if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
62299b3849dSAart Bik         vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
62399b3849dSAart Bik       return success();
62499b3849dSAart Bik     return failure();
62599b3849dSAart Bik   }
62699b3849dSAart Bik 
62799b3849dSAart Bik private:
62899b3849dSAart Bik   const VL vl;
62999b3849dSAart Bik };
63099b3849dSAart Bik 
631cb82d375SAart Bik /// Reduction chain cleanup.
632cb82d375SAart Bik ///   v = for { }
633cb82d375SAart Bik ///   s = vsum(v)               v = for { }
634cb82d375SAart Bik ///   u = expand(s)       ->    for (v) { }
635cb82d375SAart Bik ///   for (u) { }
636cb82d375SAart Bik template <typename VectorOp>
637cb82d375SAart Bik struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
638cb82d375SAart Bik public:
639cb82d375SAart Bik   using OpRewritePattern<VectorOp>::OpRewritePattern;
640cb82d375SAart Bik 
641cb82d375SAart Bik   LogicalResult matchAndRewrite(VectorOp op,
642cb82d375SAart Bik                                 PatternRewriter &rewriter) const override {
643cb82d375SAart Bik     Value inp = op.getSource();
644cb82d375SAart Bik     if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
645cb82d375SAart Bik       if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
646781eabebSPeiming Liu         if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
647cb82d375SAart Bik           rewriter.replaceOp(op, redOp.getVector());
648cb82d375SAart Bik           return success();
649cb82d375SAart Bik         }
650cb82d375SAart Bik       }
651cb82d375SAart Bik     }
652cb82d375SAart Bik     return failure();
653cb82d375SAart Bik   }
654cb82d375SAart Bik };
655cb82d375SAart Bik 
65699b3849dSAart Bik } // namespace
65799b3849dSAart Bik 
65899b3849dSAart Bik //===----------------------------------------------------------------------===//
65999b3849dSAart Bik // Public method for populating vectorization rules.
66099b3849dSAart Bik //===----------------------------------------------------------------------===//
66199b3849dSAart Bik 
66299b3849dSAart Bik /// Populates the given patterns list with vectorization rules.
66399b3849dSAart Bik void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
66499b3849dSAart Bik                                                unsigned vectorLength,
66599b3849dSAart Bik                                                bool enableVLAVectorization,
66699b3849dSAart Bik                                                bool enableSIMDIndex32) {
66716aa4e4bSAart Bik   assert(vectorLength > 0);
668*a6e72f93SManupa Karunaratne   vector::populateVectorStepLoweringPatterns(patterns);
66999b3849dSAart Bik   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
67099b3849dSAart Bik                               enableVLAVectorization, enableSIMDIndex32);
671cb82d375SAart Bik   patterns.add<ReducChainRewriter<vector::InsertElementOp>,
672cb82d375SAart Bik                ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
67399b3849dSAart Bik }
674