xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (revision 781eabeb40b8e47e3a46b0b927784e63f0aad9ab)
1 //===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts loops generated by the sparse compiler into a form that
10 // can exploit SIMD instructions of the target architecture. Note that this pass
11 // ensures the sparse compiler can generate efficient SIMD (including ArmSVE
12 // support) with proper separation of concerns as far as sparsification and
13 // vectorization is concerned. However, this pass is not the final abstraction
14 // level we want, and not the general vectorizer we want either. It forms a good
15 // stepping stone for incremental future improvements though.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "CodegenUtils.h"
20 #include "LoopEmitter.h"
21 
22 #include "mlir/Dialect/Affine/IR/AffineOps.h"
23 #include "mlir/Dialect/Arith/IR/Arith.h"
24 #include "mlir/Dialect/Complex/IR/Complex.h"
25 #include "mlir/Dialect/Math/IR/Math.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
29 #include "mlir/Dialect/Vector/IR/VectorOps.h"
30 #include "mlir/IR/Matchers.h"
31 
32 using namespace mlir;
33 using namespace mlir::sparse_tensor;
34 
35 namespace {
36 
/// Target SIMD properties that drive the vectorizer.
struct VL {
  /// Number of packed data elements per vector (e.g. vector<16xf32> has
  /// length 16). When scalable vectors are enabled, the runtime length is
  /// this value multiplied by vector.vscale.
  unsigned vectorLength;
  /// Enables vector-length-agnostic (scalable) vectors, e.g. for ArmSVE.
  bool enableVLAVectorization;
  /// Keeps 32-bit indices in gather/scatter operations for efficiency.
  bool enableSIMDIndex32;
};
46 
47 /// Helper to test for given index value.
48 static bool isIntValue(Value val, int64_t idx) {
49   if (auto ival = getConstantIntValue(val))
50     return *ival == idx;
51   return false;
52 }
53 
54 /// Helper test for invariant value (defined outside given block).
55 static bool isInvariantValue(Value val, Block *block) {
56   return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
57 }
58 
59 /// Helper test for invariant argument (defined outside given block).
60 static bool isInvariantArg(BlockArgument arg, Block *block) {
61   return arg.getOwner() != block;
62 }
63 
64 /// Constructs vector type for element type.
65 static VectorType vectorType(VL vl, Type etp) {
66   unsigned numScalableDims = vl.enableVLAVectorization;
67   return VectorType::get(vl.vectorLength, etp, numScalableDims);
68 }
69 
70 /// Constructs vector type from pointer.
71 static VectorType vectorType(VL vl, Value ptr) {
72   return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
73 }
74 
75 /// Constructs vector iteration mask.
76 static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
77                            Value iv, Value lo, Value hi, Value step) {
78   VectorType mtp = vectorType(vl, rewriter.getI1Type());
79   // Special case if the vector length evenly divides the trip count (for
80   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
81   // so that all subsequent masked memory operations are immediately folded
82   // into unconditional memory operations.
83   IntegerAttr loInt, hiInt, stepInt;
84   if (matchPattern(lo, m_Constant(&loInt)) &&
85       matchPattern(hi, m_Constant(&hiInt)) &&
86       matchPattern(step, m_Constant(&stepInt))) {
87     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
88       Value trueVal = constantI1(rewriter, loc, true);
89       return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
90     }
91   }
92   // Otherwise, generate a vector mask that avoids overrunning the upperbound
93   // during vector execution. Here we rely on subsequent loop optimizations to
94   // avoid executing the mask in all iterations, for example, by splitting the
95   // loop into an unconditional vector loop and a scalar cleanup loop.
96   auto min = AffineMap::get(
97       /*dimCount=*/2, /*symbolCount=*/1,
98       {rewriter.getAffineSymbolExpr(0),
99        rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
100       rewriter.getContext());
101   Value end =
102       rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
103   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
104 }
105 
106 /// Generates a vectorized invariant. Here we rely on subsequent loop
107 /// optimizations to hoist the invariant broadcast out of the vector loop.
108 static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
109                                      Value val) {
110   VectorType vtp = vectorType(vl, val.getType());
111   return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
112 }
113 
114 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
115 /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
116 /// that the sparse compiler can only generate indirect loads in
117 /// the last index, i.e. back().
118 static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
119                            Value ptr, ArrayRef<Value> idxs, Value vmask) {
120   VectorType vtp = vectorType(vl, ptr);
121   Value pass = constantZero(rewriter, loc, vtp);
122   if (idxs.back().getType().isa<VectorType>()) {
123     SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
124     Value indexVec = idxs.back();
125     scalarArgs.back() = constantIndex(rewriter, loc, 0);
126     return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
127                                              indexVec, vmask, pass);
128   }
129   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
130                                                pass);
131 }
132 
133 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
134 /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
135 /// that the sparse compiler can only generate indirect stores in
136 /// the last index, i.e. back().
137 static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
138                            ArrayRef<Value> idxs, Value vmask, Value rhs) {
139   if (idxs.back().getType().isa<VectorType>()) {
140     SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
141     Value indexVec = idxs.back();
142     scalarArgs.back() = constantIndex(rewriter, loc, 0);
143     rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
144                                        rhs);
145     return;
146   }
147   rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
148 }
149 
150 /// Detects a vectorizable reduction operations and returns the
151 /// combining kind of reduction on success in `kind`.
152 static bool isVectorizableReduction(Value red, Value iter,
153                                     vector::CombiningKind &kind) {
154   if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
155     kind = vector::CombiningKind::ADD;
156     return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
157   }
158   if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
159     kind = vector::CombiningKind::ADD;
160     return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
161   }
162   if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
163     kind = vector::CombiningKind::ADD;
164     return subf->getOperand(0) == iter;
165   }
166   if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
167     kind = vector::CombiningKind::ADD;
168     return subi->getOperand(0) == iter;
169   }
170   if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
171     kind = vector::CombiningKind::MUL;
172     return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
173   }
174   if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
175     kind = vector::CombiningKind::MUL;
176     return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
177   }
178   if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
179     kind = vector::CombiningKind::AND;
180     return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
181   }
182   if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
183     kind = vector::CombiningKind::OR;
184     return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
185   }
186   if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
187     kind = vector::CombiningKind::XOR;
188     return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
189   }
190   return false;
191 }
192 
193 /// Generates an initial value for a vector reduction, following the scheme
194 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
195 /// initial scalar value is correctly embedded in the vector reduction value,
196 /// and a straightforward horizontal reduction will complete the operation.
197 /// Value 'r' denotes the initial value of the reduction outside the loop.
198 static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
199                                 Value red, Value iter, Value r,
200                                 VectorType vtp) {
201   vector::CombiningKind kind;
202   if (!isVectorizableReduction(red, iter, kind))
203     llvm_unreachable("unknown reduction");
204   switch (kind) {
205   case vector::CombiningKind::ADD:
206   case vector::CombiningKind::XOR:
207     // Initialize reduction vector to: | 0 | .. | 0 | r |
208     return rewriter.create<vector::InsertElementOp>(
209         loc, r, constantZero(rewriter, loc, vtp),
210         constantIndex(rewriter, loc, 0));
211   case vector::CombiningKind::MUL:
212     // Initialize reduction vector to: | 1 | .. | 1 | r |
213     return rewriter.create<vector::InsertElementOp>(
214         loc, r, constantOne(rewriter, loc, vtp),
215         constantIndex(rewriter, loc, 0));
216   case vector::CombiningKind::AND:
217   case vector::CombiningKind::OR:
218     // Initialize reduction vector to: | r | .. | r | r |
219     return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
220   default:
221     break;
222   }
223   llvm_unreachable("unknown reduction kind");
224 }
225 
226 /// This method is called twice to analyze and rewrite the given subscripts.
227 /// The first call (!codegen) does the analysis. Then, on success, the second
228 /// call (codegen) yields the proper vector form in the output parameter
229 /// vector 'idxs'. This mechanism ensures that analysis and rewriting code
230 /// stay in sync. Note that the analyis part is simple because the sparse
231 /// compiler only generates relatively simple subscript expressions.
232 ///
233 /// See https://llvm.org/docs/GetElementPtr.html for some background on
234 /// the complications described below.
235 ///
236 /// We need to generate a pointer/index load from the sparse storage scheme.
237 /// Narrower data types need to be zero extended before casting the value
238 /// into the index type used for looping and indexing.
239 ///
240 /// For the scalar case, subscripts simply zero extend narrower indices
241 /// into 64-bit values before casting to an index type without a performance
242 /// penalty. Indices that already are 64-bit, in theory, cannot express the
243 /// full range since the LLVM backend defines addressing in terms of an
244 /// unsigned pointer/signed index pair.
245 static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
246                                 VL vl, ValueRange subs, bool codegen,
247                                 Value vmask, SmallVectorImpl<Value> &idxs) {
248   unsigned d = 0;
249   unsigned dim = subs.size();
250   Block *block = &forOp.getRegion().front();
251   for (auto sub : subs) {
252     bool innermost = ++d == dim;
253     // Invariant subscripts in outer dimensions simply pass through.
254     // Note that we rely on LICM to hoist loads where all subscripts
255     // are invariant in the innermost loop.
256     // Example:
257     //   a[inv][i] for inv
258     if (isInvariantValue(sub, block)) {
259       if (innermost)
260         return false;
261       if (codegen)
262         idxs.push_back(sub);
263       continue; // success so far
264     }
265     // Invariant block arguments (including outer loop indices) in outer
266     // dimensions simply pass through. Direct loop indices in the
267     // innermost loop simply pass through as well.
268     // Example:
269     //   a[i][j] for both i and j
270     if (auto arg = sub.dyn_cast<BlockArgument>()) {
271       if (isInvariantArg(arg, block) == innermost)
272         return false;
273       if (codegen)
274         idxs.push_back(sub);
275       continue; // success so far
276     }
277     // Look under the hood of casting.
278     auto cast = sub;
279     while (true) {
280       if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
281         cast = icast->getOperand(0);
282       else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
283         cast = ecast->getOperand(0);
284       else
285         break;
286     }
287     // Since the index vector is used in a subsequent gather/scatter
288     // operations, which effectively defines an unsigned pointer + signed
289     // index, we must zero extend the vector to an index width. For 8-bit
290     // and 16-bit values, an 32-bit index width suffices. For 32-bit values,
291     // zero extending the elements into 64-bit loses some performance since
292     // the 32-bit indexed gather/scatter is more efficient than the 64-bit
293     // index variant (if the negative 32-bit index space is unused, the
294     // enableSIMDIndex32 flag can preserve this performance). For 64-bit
295     // values, there is no good way to state that the indices are unsigned,
296     // which creates the potential of incorrect address calculations in the
297     // unlikely case we need such extremely large offsets.
298     // Example:
299     //    a[ ind[i] ]
300     if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
301       if (!innermost)
302         return false;
303       if (codegen) {
304         SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
305         Location loc = forOp.getLoc();
306         Value vload =
307             genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
308         Type etp = vload.getType().cast<VectorType>().getElementType();
309         if (!etp.isa<IndexType>()) {
310           if (etp.getIntOrFloatBitWidth() < 32)
311             vload = rewriter.create<arith::ExtUIOp>(
312                 loc, vectorType(vl, rewriter.getI32Type()), vload);
313           else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
314             vload = rewriter.create<arith::ExtUIOp>(
315                 loc, vectorType(vl, rewriter.getI64Type()), vload);
316         }
317         idxs.push_back(vload);
318       }
319       continue; // success so far
320     }
321     // Address calculation 'i = add inv, idx' (after LICM).
322     // Example:
323     //    a[base + i]
324     if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
325       Value inv = load.getOperand(0);
326       Value idx = load.getOperand(1);
327       if (isInvariantValue(inv, block)) {
328         if (auto arg = idx.dyn_cast<BlockArgument>()) {
329           if (isInvariantArg(arg, block) || !innermost)
330             return false;
331           if (codegen)
332             idxs.push_back(
333                 rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
334           continue; // success so far
335         }
336       }
337     }
338     return false;
339   }
340   return true;
341 }
342 
343 #define UNAOP(xxx)                                                             \
344   if (isa<xxx>(def)) {                                                         \
345     if (codegen)                                                               \
346       vexp = rewriter.create<xxx>(loc, vx);                                    \
347     return true;                                                               \
348   }
349 
350 #define TYPEDUNAOP(xxx)                                                        \
351   if (auto x = dyn_cast<xxx>(def)) {                                           \
352     if (codegen) {                                                             \
353       VectorType vtp = vectorType(vl, x.getType());                            \
354       vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
355     }                                                                          \
356     return true;                                                               \
357   }
358 
359 #define BINOP(xxx)                                                             \
360   if (isa<xxx>(def)) {                                                         \
361     if (codegen)                                                               \
362       vexp = rewriter.create<xxx>(loc, vx, vy);                                \
363     return true;                                                               \
364   }
365 
366 /// This method is called twice to analyze and rewrite the given expression.
367 /// The first call (!codegen) does the analysis. Then, on success, the second
368 /// call (codegen) yields the proper vector form in the output parameter 'vexp'.
369 /// This mechanism ensures that analysis and rewriting code stay in sync. Note
370 /// that the analyis part is simple because the sparse compiler only generates
371 /// relatively simple expressions inside the for-loops.
372 static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
373                           Value exp, bool codegen, Value vmask, Value &vexp) {
374   Location loc = forOp.getLoc();
375   // Reject unsupported types.
376   if (!VectorType::isValidElementType(exp.getType()))
377     return false;
378   // A block argument is invariant/reduction/index.
379   if (auto arg = exp.dyn_cast<BlockArgument>()) {
380     if (arg == forOp.getInductionVar()) {
381       // We encountered a single, innermost index inside the computation,
382       // such as a[i] = i, which must convert to [i, i+1, ...].
383       if (codegen) {
384         VectorType vtp = vectorType(vl, arg.getType());
385         Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
386         Value incr;
387         if (vl.enableVLAVectorization) {
388           Type stepvty = vectorType(vl, rewriter.getI64Type());
389           Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
390           incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
391         } else {
392           SmallVector<APInt> integers;
393           for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
394             integers.push_back(APInt(/*width=*/64, i));
395           auto values = DenseElementsAttr::get(vtp, integers);
396           incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
397         }
398         vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
399       }
400       return true;
401     } // An invariant or reduction. In both cases, we treat this as an
402     // invariant value, and rely on later replacing and folding to
403     // construct a proper reduction chain for the latter case.
404     if (codegen)
405       vexp = genVectorInvariantValue(rewriter, vl, exp);
406     return true;
407   }
408   // Something defined outside the loop-body is invariant.
409   Operation *def = exp.getDefiningOp();
410   Block *block = &forOp.getRegion().front();
411   if (def->getBlock() != block) {
412     if (codegen)
413       vexp = genVectorInvariantValue(rewriter, vl, exp);
414     return true;
415   }
416   // Proper load operations. These are either values involved in the
417   // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
418   // or index values inside the computation that are now fetched from
419   // the sparse storage index arrays, such as a[i] = i becomes
420   // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
421   // and 'hi = lo + vl - 1'.
422   if (auto load = dyn_cast<memref::LoadOp>(def)) {
423     auto subs = load.getIndices();
424     SmallVector<Value> idxs;
425     if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
426       if (codegen)
427         vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
428       return true;
429     }
430     return false;
431   }
432   // Inside loop-body unary and binary operations. Note that it would be
433   // nicer if we could somehow test and build the operations in a more
434   // concise manner than just listing them all (although this way we know
435   // for certain that they can vectorize).
436   //
437   // TODO: avoid visiting CSEs multiple times
438   //
439   if (def->getNumOperands() == 1) {
440     Value vx;
441     if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
442                       vx)) {
443       UNAOP(math::AbsFOp)
444       UNAOP(math::AbsIOp)
445       UNAOP(math::CeilOp)
446       UNAOP(math::FloorOp)
447       UNAOP(math::SqrtOp)
448       UNAOP(math::ExpM1Op)
449       UNAOP(math::Log1pOp)
450       UNAOP(math::SinOp)
451       UNAOP(math::TanhOp)
452       UNAOP(arith::NegFOp)
453       TYPEDUNAOP(arith::TruncFOp)
454       TYPEDUNAOP(arith::ExtFOp)
455       TYPEDUNAOP(arith::FPToSIOp)
456       TYPEDUNAOP(arith::FPToUIOp)
457       TYPEDUNAOP(arith::SIToFPOp)
458       TYPEDUNAOP(arith::UIToFPOp)
459       TYPEDUNAOP(arith::ExtSIOp)
460       TYPEDUNAOP(arith::ExtUIOp)
461       TYPEDUNAOP(arith::IndexCastOp)
462       TYPEDUNAOP(arith::TruncIOp)
463       TYPEDUNAOP(arith::BitcastOp)
464       // TODO: complex?
465     }
466   } else if (def->getNumOperands() == 2) {
467     Value vx, vy;
468     if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
469                       vx) &&
470         vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
471                       vy)) {
472       // We only accept shift-by-invariant (where the same shift factor applies
473       // to all packed elements). In the vector dialect, this is still
474       // represented with an expanded vector at the right-hand-side, however,
475       // so that we do not have to special case the code generation.
476       if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
477           isa<arith::ShRSIOp>(def)) {
478         Value shiftFactor = def->getOperand(1);
479         if (!isInvariantValue(shiftFactor, block))
480           return false;
481       }
482       // Generate code.
483       BINOP(arith::MulFOp)
484       BINOP(arith::MulIOp)
485       BINOP(arith::DivFOp)
486       BINOP(arith::DivSIOp)
487       BINOP(arith::DivUIOp)
488       BINOP(arith::AddFOp)
489       BINOP(arith::AddIOp)
490       BINOP(arith::SubFOp)
491       BINOP(arith::SubIOp)
492       BINOP(arith::AndIOp)
493       BINOP(arith::OrIOp)
494       BINOP(arith::XOrIOp)
495       BINOP(arith::ShLIOp)
496       BINOP(arith::ShRUIOp)
497       BINOP(arith::ShRSIOp)
498       // TODO: complex?
499     }
500   }
501   return false;
502 }
503 
504 #undef UNAOP
505 #undef TYPEDUNAOP
506 #undef BINOP
507 
508 /// This method is called twice to analyze and rewrite the given for-loop.
509 /// The first call (!codegen) does the analysis. Then, on success, the second
510 /// call (codegen) rewriters the IR into vector form. This mechanism ensures
511 /// that analysis and rewriting code stay in sync.
512 static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
513                           bool codegen) {
514   Location loc = forOp.getLoc();
515   Block &block = forOp.getRegion().front();
516   scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
517   auto &last = *++block.rbegin();
518   scf::ForOp forOpNew;
519 
520   // Perform initial set up during codegen (we know that the first analysis
521   // pass was successful). For reductions, we need to construct a completely
522   // new for-loop, since the incoming and outgoing reduction type
523   // changes into SIMD form. For stores, we can simply adjust the stride
524   // and insert in the existing for-loop. In both cases, we set up a vector
525   // mask for all operations which takes care of confining vectors to
526   // the original iteration space (later cleanup loops or other
527   // optimizations can take care of those).
528   Value vmask;
529   if (codegen) {
530     Value step = constantIndex(rewriter, loc, vl.vectorLength);
531     if (vl.enableVLAVectorization) {
532       Value vscale =
533           rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
534       step = rewriter.create<arith::MulIOp>(loc, vscale, step);
535     }
536     if (!yield.getResults().empty()) {
537       Value init = forOp.getInitArgs()[0];
538       VectorType vtp = vectorType(vl, init.getType());
539       Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
540                                        forOp.getRegionIterArg(0), init, vtp);
541       forOpNew = rewriter.create<scf::ForOp>(
542           loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
543       forOpNew->setAttr(
544           LoopEmitter::getLoopEmitterLoopAttrName(),
545           forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
546       rewriter.setInsertionPointToStart(forOpNew.getBody());
547     } else {
548       forOp.setStep(step);
549       rewriter.setInsertionPoint(yield);
550     }
551     vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
552                           forOp.getLowerBound(), forOp.getUpperBound(), step);
553   }
554 
555   // Sparse for-loops either are terminated by a non-empty yield operation
556   // (reduction loop) or otherwise by a store operation (pararallel loop).
557   if (!yield.getResults().empty()) {
558     // Analyze/vectorize reduction.
559     if (yield->getNumOperands() != 1)
560       return false;
561     Value red = yield->getOperand(0);
562     Value iter = forOp.getRegionIterArg(0);
563     vector::CombiningKind kind;
564     Value vrhs;
565     if (isVectorizableReduction(red, iter, kind) &&
566         vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
567       if (codegen) {
568         Value partial = forOpNew.getResult(0);
569         Value vpass = genVectorInvariantValue(rewriter, vl, iter);
570         Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
571         rewriter.create<scf::YieldOp>(loc, vred);
572         rewriter.setInsertionPointAfter(forOpNew);
573         Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
574         // Now do some relinking (last one is not completely type safe
575         // but all bad ones are removed right away). This also folds away
576         // nop broadcast operations.
577         forOp.getResult(0).replaceAllUsesWith(vres);
578         forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
579         forOp.getRegionIterArg(0).replaceAllUsesWith(
580             forOpNew.getRegionIterArg(0));
581         rewriter.eraseOp(forOp);
582       }
583       return true;
584     }
585   } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
586     // Analyze/vectorize store operation.
587     auto subs = store.getIndices();
588     SmallVector<Value> idxs;
589     Value rhs = store.getValue();
590     Value vrhs;
591     if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
592         vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
593       if (codegen) {
594         genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
595         rewriter.eraseOp(store);
596       }
597       return true;
598     }
599   }
600 
601   assert(!codegen && "cannot call codegen when analysis failed");
602   return false;
603 }
604 
605 /// Basic for-loop vectorizer.
606 struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
607 public:
608   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
609 
610   ForOpRewriter(MLIRContext *context, unsigned vectorLength,
611                 bool enableVLAVectorization, bool enableSIMDIndex32)
612       : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
613                                       enableSIMDIndex32} {}
614 
615   LogicalResult matchAndRewrite(scf::ForOp op,
616                                 PatternRewriter &rewriter) const override {
617     // Check for single block, unit-stride for-loop that is generated by
618     // sparse compiler, which means no data dependence analysis is required,
619     // and its loop-body is very restricted in form.
620     if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
621         !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
622       return failure();
623     // Analyze (!codegen) and rewrite (codegen) loop-body.
624     if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
625         vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
626       return success();
627     return failure();
628   }
629 
630 private:
631   const VL vl;
632 };
633 
634 /// Reduction chain cleanup.
635 ///   v = for { }
636 ///   s = vsum(v)               v = for { }
637 ///   u = expand(s)       ->    for (v) { }
638 ///   for (u) { }
639 template <typename VectorOp>
640 struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
641 public:
642   using OpRewritePattern<VectorOp>::OpRewritePattern;
643 
644   LogicalResult matchAndRewrite(VectorOp op,
645                                 PatternRewriter &rewriter) const override {
646     Value inp = op.getSource();
647     if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
648       if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
649         if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
650           rewriter.replaceOp(op, redOp.getVector());
651           return success();
652         }
653       }
654     }
655     return failure();
656   }
657 };
658 
659 } // namespace
660 
661 //===----------------------------------------------------------------------===//
662 // Public method for populating vectorization rules.
663 //===----------------------------------------------------------------------===//
664 
665 /// Populates the given patterns list with vectorization rules.
666 void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
667                                                unsigned vectorLength,
668                                                bool enableVLAVectorization,
669                                                bool enableSIMDIndex32) {
670   assert(vectorLength > 0);
671   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
672                               enableVLAVectorization, enableSIMDIndex32);
673   patterns.add<ReducChainRewriter<vector::InsertElementOp>,
674                ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
675 }
676