//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is neither the final
// abstraction level we want, nor the general vectorizer we want. It forms a
// good stepping stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
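///
/// For example (illustrative values only), {16, false, true} requests
/// fixed-length 16-lane vectors with 32-bit gather/scatter indices, while
/// {4, true, false} requests scalable vectors (e.g. vector<[4]xf32>).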
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for a given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs a vector type for the given element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs a vector type from the element type of the given memref.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs a vector iteration mask.
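///
/// A rough sketch of the two forms generated below, assuming a fixed vector
/// length of 16 (the actual IR depends on the loop bounds and step):
///
///   // Trip count evenly divisible by the vector length:
///   %true = arith.constant true
///   %mask = vector.broadcast %true : i1 to vector<16xi1>
///
///   // General case (mask off lanes beyond the upper bound):
///   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
///           (%hi, %iv)[%step]
///   %mask = vector.create_mask %end : vector<16xi1>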
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
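///
/// As a rough sketch (assuming a 16-lane vector of f32 over a 1-D memref),
/// the two forms generated below look something like:
///
///   // Unit-stride load a[lo:hi]:
///   %v = vector.maskedload %a[%iv], %mask, %pass
///          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
///
///   // Indirect load a[ind[lo:hi]]:
///   %v = vector.gather %a[%c0][%ind], %mask, %pass
///          : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
///            into vector<16xf32>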
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
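///
/// Analogous to the load case above, this produces either a
/// vector.maskedstore (unit-stride) or a vector.scatter (indirect),
/// again under the loop's vector mask.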
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
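///
/// For example (sketch, 16-lane sum reduction with scalar init %r), the
/// reduction vector starts as | 0 | .. | 0 | r |, i.e. roughly:
///
///   %vinit = vector.insertelement %r, %vzero[%c0 : index] : vector<16xf32>
///
/// so that the final horizontal vector.reduction <add> over the loop result
/// yields the same value as the original scalar loop.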
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// In the scalar case, subscripts simply zero extend narrower indices into
/// 64-bit values before casting to an index type, without a performance
/// penalty. Indices that are already 64-bit cannot, in theory, express the
/// full range, since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
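///
/// Sketch of the vectorized case (assuming a 16-lane vector and an i32 index
/// array loaded from the sparse storage scheme):
///
///   %vi   = vector.maskedload %ind[%iv], %mask, %passi
///             : memref<?xi32>, vector<16xi1>, vector<16xi32>
///               into vector<16xi32>
///   %vi64 = arith.extui %vi : vector<16xi32> to vector<16xi64>
///
/// where the final zero extension is skipped when enableSIMDIndex32 is set,
/// so that the cheaper 32-bit indexed gather/scatter can be used.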
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant/loop indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (1) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way to
    // state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    return false;
  }
  return true;
}

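// Each macro below both recognizes a supported scalar operation (the
// analysis check) and, under codegen, builds its vectorized counterpart,
// keeping the analysis and rewriting paths of vectorizeExpr in sync.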
#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
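///
/// For example (sketch, 16 lanes), the body expression of a[i] = b[i] + x,
/// with x loop-invariant, becomes roughly:
///
///   %vb = vector.maskedload %b[%iv], %mask, %pass ...
///   %vx = vector.broadcast %x : f32 to vector<16xf32>
///   %vr = arith.addf %vb, %vx : vector<16xf32>
///
/// and a use of the induction variable itself becomes the vector
/// [i, i+1, ..., i+15].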
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    } else {
      // An invariant or reduction. In both cases, we treat this as an
      // invariant value, and rely on later replacing and folding to
      // construct a proper reduction chain for the latter case.
      if (codegen)
        vexp = genVectorInvariantValue(rewriter, vl, exp);
      return true;
    }
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or index values inside the computation that are now fetched from
  // the sparse storage index arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      // TODO: shift by invariant?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
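///
/// As a rough before/after sketch of the reduction case (16 lanes, f32 sum;
/// names and types are illustrative only):
///
///   // Before:
///   %s = scf.for %i = %lo to %hi step %c1 iter_args(%acc = %init) -> (f32) {
///     %t = ... : f32
///     %a = arith.addf %acc, %t : f32
///     scf.yield %a : f32
///   }
///
///   // After:
///   %v = scf.for %i = %lo to %hi step %c16
///          iter_args(%vacc = %vinit) -> (vector<16xf32>) {
///     %vt = ... : vector<16xf32>
///     %va = arith.addf %vacc, %vt : vector<16xf32>
///     %vs = arith.select %mask, %va, %vacc : vector<16xi1>, vector<16xf32>
///     scf.yield %vs : vector<16xf32>
///   }
///   %s = vector.reduction <add>, %v : vector<16xf32> into f32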
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations, which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(
              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by the
    // sparse compiler, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)               v = for { }
///   u = expand(s)       ->    for (v) { }
///   for (u) { }
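///
/// In IR terms (sketch): 'vsum' is a vector.reduction over the result of a
/// vectorized for-loop, and 'expand' is the vector.insertelement or
/// vector.broadcast that re-vectorizes the scalar again; the rewrite below
/// simply forwards the original vector so the scalar round-trip folds away.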
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(
                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
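///
/// A typical driver would apply these rules greedily, for example
/// (hypothetical sketch, not part of this file):
///
///   RewritePatternSet patterns(ctx);
///   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
///                                       /*enableVLAVectorization=*/false,
///                                       /*enableSIMDIndex32=*/false);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));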
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}