//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}
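// Illustrative note for `isTiled` above (a sketch, not exercised directly by
// this file): for an indexing map `(d0, d1, d2) -> (d0, d2)`, tile sizes
// [0, 8, 16] make the map "tiled" because its result `d2` has a non-zero tile
// size, whereas tile sizes [0, 8, 0] leave every result untiled.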

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
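// Rough sketch of what the matcher above accepts (illustrative IR, types
// chosen arbitrarily): a generic body of the form
//   ^bb0(%a: i32, %b: i32):
//     %0 = arith.addi %a, %b : i32
//     linalg.yield %0 : i32
// is classified as BinaryOpKind::IAdd; anything else yields std::nullopt.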

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// Utilities for inferring various semantics properties of Linalg ops.
//===----------------------------------------------------------------------===//

DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
  DenseSet<int64_t> res;
  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = e.dyn_cast<AffineDimExpr>()) {
      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}
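// For example (a sketch for the function above): given a matmul-shaped op with
// iterator types [parallel, parallel, reduction] and LHS indexing map
// (d0, d1, d2) -> (d0, d2), querying the LHS operand with
// utils::IteratorType::parallel returns {0}, while querying it with the
// reduction iterator type returns {2}.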

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
  if (failed(res))
    return false;
  int64_t numLoops = linalgOp.getNumLoops();
  for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
        s.contains(numLoops - 1))
      continue;
    return false;
  }
  return true;
}

FailureOr<EmbeddedMatmulDimsCandidates>
mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  DenseSet<int64_t> a = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), par);
  DenseSet<int64_t> b = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), par);
  DenseSet<int64_t> c = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInitOperand(0), par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  DenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  DenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);

  // Note: if we ever need them, A & B & C would be "batch" dimensions.

  // Dimensions indexed as reductions by both A and B are the reduction dims.
  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), red);
  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), red);
  llvm::set_intersect(ra, rb);

  if (ac.empty() || bc.empty() || ra.empty())
    return failure();

  // Pick the first one in each set.
  // TODO: Better heuristic (e.g. pick dims based on a packing-based metric).
  return EmbeddedMatmulDimsCandidates{ac, bc, ra};
}
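// Minimal usage sketch for inferMatmulDims above (assuming `op` is a LinalgOp
// that structurally performs a matmul):
//   FailureOr<EmbeddedMatmulDimsCandidates> dims = inferMatmulDims(op);
//   if (succeeded(dims)) {
//     // dims->mPos, dims->nPos and dims->kPos hold the candidate m/n/k loop
//     // dimensions, as consumed by containsMostMinorMatmul above.
//   }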

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    if (!op.getMatchingIndexingMap(opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}
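// Informal sketch of the pattern handled above: `source` is an extract_slice
// of a LinalgOp result whose init chains back (through DPS init operands) to a
// tensor.pad that high-pads an identically sized slice with the same padding
// value. When all of these conditions line up, the already padded value
// backing the slice is reused instead of emitting a fresh tensor.pad.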

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);
  Region &body = transposeOp.getRegion();
  body.push_back(new Block());
  body.front().addArguments({elementType, elementType}, {loc, loc});

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&body.front());
  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
  return transposeOp;
}
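// Minimal usage sketch for makeTransposeOp above (names are illustrative):
// transposing a tensor `in` into a pre-allocated `out` with permutation
// [1, 0]:
//   GenericOp transpose = makeTransposeOp(b, loc, in, out, {1, 0});
// The resulting linalg.generic has all-parallel iterators and a body that
// simply yields its first block argument.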

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build an affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}
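// In other words, for cyclic distribution the per-processor bounds become
//   lb'   = lb + procId * step
//   step' = nprocs * step
// while `ub` is left unchanged, so processor `procId` visits iterations
// lb + procId*step, lb + (procId + nprocs)*step, and so on.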

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}
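// Illustrative shape of the output of generateParallelLoopNest above: iterator
// types [parallel, parallel, reduction, parallel] with no distribution info
// produce, conceptually,
//   scf.parallel (%i, %j) = ... {
//     scf.for %k = ... {
//       scf.parallel (%l) = ... { <body> }
//     }
//   }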

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // The resulting size needs to be turned back into a half-open interval.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We first compute the max
      // index and then add one. This matters because for convolution ops the
      // input window dimension has an affine map of the form `(d0 * s0 + d1)`,
      // where `d0`/`d1` is an output/filter window dimension and `s0` is the
      // stride. Directly using the dimension size of the output/filter window
      // dimensions would yield an incorrect result.
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}
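// Illustrative example for computeSliceParameters above (a sketch): tiling a
// dimension of static size 128 with tile size 32 produces offset `map(lbs)`,
// stride 1, and size 32 (no partial-tile guard, since 32 divides 128 evenly);
// with tile size 30, the size instead becomes affine.min(30, dim - offset) to
// stay in bounds at the boundary tile.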

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}
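// For example, with tile sizes [0, 8, 0, 4] and ivs (%i, %j), the offsets
// computed above are [0, %i, 0, %j]: untiled dimensions get a constant zero
// offset and tiled dimensions consume the induction variables in order.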

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}
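// For example, with tile sizes [0, 8] and size bounds [%d0, %d1], the result
// is [%d0 - 1, 7]: each entry is the tile size (or the full size bound for
// untiled dims) converted to a closed-interval bound by subtracting one.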

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
        return operands[opOperand->getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand->getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}
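// Conceptual sketch of the rewrite above: an `%i = linalg.index 0` in the body
// becomes, for a non-null offset %off on dimension 0,
//   %i  = linalg.index 0
//   %i2 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %off)
// with every other use of %i redirected to %i2 (the apply may be folded when
// the offset is a constant).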

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}
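// For example, mixed sizes [1, %d, 1, 4, 1] produce the reassociation
// [[0, 1], [2, 3, 4]]: each unit dimension is grouped with the next non-unit
// dimension, and trailing unit dimensions are folded into the last group.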

/// Return the identity numeric value associated with the given op.
std::optional<TypedAttr> getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = dyn_cast<FloatType>(resultType)) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(resultType,
                            llvm::APFloat::getInf(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
    return std::nullopt;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return std::nullopt;
}

} // namespace linalg
} // namespace mlir