xref: /llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (revision 0d4efa27252cbbea4b5672d4d8ffc15a3ba51d83)
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 
15 #include "mlir/AsmParser/AsmParser.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Arith/Utils/Utils.h"
19 #include "mlir/Dialect/Complex/IR/Complex.h"
20 #include "mlir/Dialect/Math/IR/Math.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Utils/IndexingUtils.h"
26 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
27 #include "mlir/Dialect/Utils/StaticValueUtils.h"
28 #include "mlir/IR/AffineExprVisitor.h"
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Attributes.h"
31 #include "mlir/IR/BuiltinAttributes.h"
32 #include "mlir/IR/BuiltinTypeInterfaces.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/IR/OperationSupport.h"
36 #include "mlir/IR/PatternMatch.h"
37 #include "mlir/Interfaces/InferTypeOpInterface.h"
38 #include "mlir/Interfaces/SideEffectInterfaces.h"
39 
40 #include "llvm/ADT/DenseMap.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SetOperations.h"
43 #include "llvm/ADT/SmallSet.h"
44 #include "llvm/ADT/SmallVector.h"
45 #include "llvm/ADT/StringSet.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include "llvm/Support/MathExtras.h"
50 #include "llvm/Support/raw_ostream.h"
51 #include <cassert>
52 #include <optional>
53 
54 using namespace mlir;
55 using namespace mlir::linalg;
56 
57 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
58 static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
59                                 int64_t dim) {
60   auto type = cast<ShapedType>(v.getType());
61   if (!type.isDynamicDim(dim))
62     return builder.getIndexAttr(type.getDimSize(dim));
63 
64   return getAsOpFoldResult(
65       TypeSwitch<Type, Value>(v.getType())
66           .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
67             return builder.create<tensor::DimOp>(loc, v, dim);
68           })
69           .Case<MemRefType>([&](MemRefType t) -> Value {
70             return builder.create<memref::DimOp>(loc, v, dim);
71           }));
72 }
73 
74 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
75 /// `source`.
76 static Operation *getSlice(OpBuilder &b, Location loc, Value source,
77                            ArrayRef<OpFoldResult> offsets,
78                            ArrayRef<OpFoldResult> sizes,
79                            ArrayRef<OpFoldResult> strides) {
80   return TypeSwitch<Type, Operation *>(source.getType())
81       .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
82         return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
83                                                 strides);
84       })
85       .Case<MemRefType>([&](MemRefType type) -> Operation * {
86         return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
87                                            strides);
88       })
89       .Default([&](Type t) -> Operation * { return nullptr; });
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // Helper functions
94 //===----------------------------------------------------------------------===//
95 
96 Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
97                                 int64_t dim) {
98   if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
99     return b.createOrFold<memref::DimOp>(loc, source, dim);
100   if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
101     return b.createOrFold<tensor::DimOp>(loc, source, dim);
102   llvm_unreachable("Expected MemRefType or TensorType");
103 }
104 
105 OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
106                                        int64_t dim) {
107   auto shapedType = llvm::cast<ShapedType>(source.getType());
108   if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
109     return createOrFoldDimOp(b, loc, source, dim);
110   return b.getIndexAttr(shapedType.getDimSize(dim));
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Support for named Linalg ops defined in ods-gen.
115 //===----------------------------------------------------------------------===//
116 
117 using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
118                                                 ArrayRef<NamedAttribute>)>;
119 
120 /// Fills the region of a structured operation using the provided
121 /// `regionBuilder`. The method is used by both named structured ops created by
122 /// ods-gen and by manually defined C++ ops. It is called by both builders and
123 /// parsers and creates a block with arguments corresponding to the elemental
124 /// types of `inputTypes` and `outputTypes`.
125 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
126                                    TypeRange inputTypes, TypeRange outputTypes,
127                                    ArrayRef<NamedAttribute> attrs,
128                                    RegionBuilderFn regionBuilder) {
129   SmallVector<Type, 8> argTypes;
130   SmallVector<Location, 8> argLocs;
131   for (auto containers : {inputTypes, outputTypes}) {
132     for (auto t : containers) {
133       argTypes.push_back(
134           isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
135 
136       // TODO: Pass in a proper location here.
137       argLocs.push_back(opBuilder.getUnknownLoc());
138     }
139   }
140 
141   // RAII.
142   OpBuilder::InsertionGuard guard(opBuilder);
143   Block *body =
144       opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
145 
146   opBuilder.setInsertionPointToStart(body);
147   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
148   regionBuilder(b, *body, attrs);
149 
150   // indexing_maps is an auto-generated method.
151 
152   // iterator_types is an auto-generated method.
153 }
154 
155 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
156 /// The result types are derived automatically if `resultTensorTypes` is none.
157 /// The body of the operation is filled using `regionBuilder`. All ods-gen
158 /// created structured operations use the method to implement their builders.
159 static void buildStructuredOp(OpBuilder &b, OperationState &state,
160                               std::optional<TypeRange> resultTensorTypes,
161                               ValueRange inputs, ValueRange outputs,
162                               ArrayRef<NamedAttribute> attributes,
163                               RegionBuilderFn regionBuilder) {
164   // Derive the result types if needed.
165   SmallVector<Type> derivedResultTypes =
166       resultTensorTypes.value_or(TypeRange());
167   if (!resultTensorTypes)
168     copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
169             llvm::IsaPred<RankedTensorType>);
170 
171   state.addOperands(inputs);
172   state.addOperands(outputs);
173   state.addTypes(derivedResultTypes);
174 
175   state.addAttributes(attributes);
176   state.addAttribute(
177       "operandSegmentSizes",
178       b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
179                               static_cast<int32_t>(outputs.size())}));
180 
181   // Create and fill the region of the structured operation.
182   Region &region = *state.addRegion();
183   fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
184                          state.attributes.getAttrs(), regionBuilder);
185 }
186 
187 static void buildMatmulOp(OpBuilder &b, OperationState &state,
188                           std::optional<TypeRange> resultTensorTypes,
189                           ValueRange inputs, ValueRange outputs,
190                           ArrayRef<NamedAttribute> attributes,
191                           RegionBuilderFn regionBuilder,
192                           ArrayRef<AffineMap> indexingMaps) {
193   // Initialize indexingMaps attribute, for MatmulOp.
194   SmallVector<Attribute, 3> indexingMapsAttrVal;
195   indexingMapsAttrVal = llvm::map_to_vector(
196       MatmulOp::getDefaultIndexingMaps(b.getContext()),
197       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
198   state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
199   return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
200                            attributes, regionBuilder);
201 }
202 
203 /// Common parsing used for both named structured ops created by ods-gen and by
204 /// manually defined C++ ops. Does not handle regions.
205 static ParseResult
206 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
207                              SmallVectorImpl<Type> &inputTypes,
208                              SmallVectorImpl<Type> &outputTypes,
209                              bool addOperandSegmentSizes = true) {
210   SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
211   SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
212       outputsOperands;
213 
214   if (succeeded(parser.parseOptionalLess())) {
215     if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
216       return failure();
217   }
218   attrsLoc = parser.getCurrentLocation();
219   if (parser.parseOptionalAttrDict(result.attributes))
220     return failure();
221 
222   if (succeeded(parser.parseOptionalKeyword("ins"))) {
223     if (parser.parseLParen())
224       return failure();
225 
226     inputsOperandsLoc = parser.getCurrentLocation();
227     if (parser.parseOperandList(inputsOperands) ||
228         parser.parseColonTypeList(inputTypes) || parser.parseRParen())
229       return failure();
230   }
231 
232   if (succeeded(parser.parseOptionalKeyword("outs"))) {
233     outputsOperandsLoc = parser.getCurrentLocation();
234     if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
235         parser.parseColonTypeList(outputTypes) || parser.parseRParen())
236       return failure();
237   }
238 
239   if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
240                              result.operands) ||
241       parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
242                              result.operands))
243     return failure();
244 
245   if (addOperandSegmentSizes) {
246     // This is a bit complex because we're trying to be backward compatible with
247     // operation syntax that mix the inherent attributes and the discardable
248     // ones in the same dictionary. If the properties are used, we append the
249     // operandSegmentSizes there directly. Otherwise we append it to the
250     // discardable attributes dictionary where it is handled by the generic
251     // Operation::create(...) method.
252     if (result.propertiesAttr) {
253       NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
254       attrs.append("operandSegmentSizes",
255                    parser.getBuilder().getDenseI32ArrayAttr(
256                        {static_cast<int32_t>(inputsOperands.size()),
257                         static_cast<int32_t>(outputsOperands.size())}));
258       result.propertiesAttr = attrs.getDictionary(parser.getContext());
259     } else {
260       result.addAttribute("operandSegmentSizes",
261                           parser.getBuilder().getDenseI32ArrayAttr(
262                               {static_cast<int32_t>(inputsOperands.size()),
263                                static_cast<int32_t>(outputsOperands.size())}));
264     }
265   }
266   if (!result.propertiesAttr) {
267     std::optional<RegisteredOperationName> info =
268         result.name.getRegisteredInfo();
269     if (info) {
270       if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
271             return parser.emitError(attrsLoc)
272                    << "'" << result.name.getStringRef() << "' op ";
273           })))
274         return failure();
275     }
276   }
277   return success();
278 }
279 
280 static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
281                                          ValueRange outputs) {
282   if (!inputs.empty())
283     p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
284   if (!outputs.empty())
285     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // Specific parsing and printing for named structured ops created by ods-gen.
290 //===----------------------------------------------------------------------===//
291 
292 static ParseResult parseNamedStructuredOpRegion(
293     OpAsmParser &parser, Region &region, unsigned numRegionArgs,
294     TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
295     RegionBuilderFn regionBuilder) {
296   if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
297     return parser.emitError(
298         parser.getCurrentLocation(),
299         llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
300                       "region expects {0} args, got {1}",
301                       numRegionArgs, inputTypes.size() + outputTypes.size()));
302   }
303 
304   OpBuilder opBuilder(parser.getContext());
305   fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
306                          regionBuilder);
307   return success();
308 }
309 
310 static ParseResult
311 parseNamedStructuredOpResults(OpAsmParser &parser,
312                               SmallVectorImpl<Type> &resultTypes) {
313   if (parser.parseOptionalArrowTypeList(resultTypes))
314     return failure();
315   return success();
316 }
317 
318 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
319                                           OperationState &result,
320                                           unsigned numRegionArgs,
321                                           RegionBuilderFn regionBuilder) {
322   // TODO: Enable when ods-gen supports captures.
323   SmallVector<Type, 1> inputTypes, outputTypes;
324   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
325     return failure();
326 
327   // Parse optional attributes.
328   if (parser.parseOptionalAttrDict(result.attributes))
329     return failure();
330 
331   // TODO: consider merging results parsing into region parsing.
332   // Need to wait for declarative assembly resolution to decide.
333   SmallVector<Type, 1> outputTensorsTypes;
334   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
335     return failure();
336   result.addTypes(outputTensorsTypes);
337 
338   std::unique_ptr<Region> region = std::make_unique<Region>();
339   if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
340                                    outputTypes, result.attributes.getAttrs(),
341                                    regionBuilder))
342     return failure();
343   result.addRegion(std::move(region));
344 
345   return success();
346 }
347 
348 static void printNamedStructuredOpResults(OpAsmPrinter &p,
349                                           TypeRange resultTypes) {
350   if (resultTypes.empty())
351     return;
352   p.printOptionalArrowTypeList(resultTypes);
353 }
354 
355 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
356                                    ValueRange inputs, ValueRange outputs,
357                                    ArrayRef<StringRef> elidedAttrs = {}) {
358   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
359 
360   // Printing is shared with generic ops, except for the region and
361   // attributes.
362   printCommonStructuredOpParts(p, inputs, outputs);
363 
364   // Results printing.
365   printNamedStructuredOpResults(p, op->getResultTypes());
366 
367   // Region is elided.
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // Region builder helper.
372 // TODO: Move this to a utility library.
373 // The public methods on this class are referenced directly from generated code.
374 // Helper build the unary, binary, and type conversion functions defined by the
375 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
376 // class.
377 //
378 // Implementations of the math functions must be polymorphic over numeric types,
379 // internally performing necessary casts. If the function application makes no
380 // sense, then the only recourse is to assert and return nullptr. This can be
381 // extended later if it becomes possible to fail construction of the region. The
382 // invariant should be enforced at a higher level.
383 //
384 // TODO: These helpers are currently type polymorphic over the class of integer
385 // and floating point types, but they will not internally cast within bit
386 // widths of a class (mixed precision such as i8->i32) or across classes
387 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
388 // to be handled with care and work is being considered to extend the op
389 // language to make such cases explicit. In the mean-time, violating this will
390 // fail verification, which is deemed acceptable.
391 //===----------------------------------------------------------------------===//
392 
393 namespace {
394 
395 class RegionBuilderHelper {
396 public:
397   RegionBuilderHelper(OpBuilder &builder, Block &block)
398       : builder(builder), block(block) {}
399 
400   // Build the unary functions defined by OpDSL.
401   Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
402     if (!isFloatingPoint(arg))
403       llvm_unreachable("unsupported non numeric type");
404     OpBuilder::InsertionGuard g(builder);
405     builder.setInsertionPointToEnd(&block);
406     switch (unaryFn) {
407     case UnaryFn::exp:
408       return builder.create<math::ExpOp>(arg.getLoc(), arg);
409     case UnaryFn::log:
410       return builder.create<math::LogOp>(arg.getLoc(), arg);
411     case UnaryFn::abs:
412       return builder.create<math::AbsFOp>(arg.getLoc(), arg);
413     case UnaryFn::ceil:
414       return builder.create<math::CeilOp>(arg.getLoc(), arg);
415     case UnaryFn::floor:
416       return builder.create<math::FloorOp>(arg.getLoc(), arg);
417     case UnaryFn::negf:
418       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
419     case UnaryFn::reciprocal: {
420       Attribute oneAttr = builder.getOneAttr(arg.getType());
421       auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
422                                                    ::cast<TypedAttr>(oneAttr));
423       return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
424     }
425     case UnaryFn::round:
426       return builder.create<math::RoundOp>(arg.getLoc(), arg);
427     case UnaryFn::sqrt:
428       return builder.create<math::SqrtOp>(arg.getLoc(), arg);
429     case UnaryFn::rsqrt:
430       return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
431     case UnaryFn::square:
432       return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
433     case UnaryFn::tanh:
434       return builder.create<math::TanhOp>(arg.getLoc(), arg);
435     case UnaryFn::erf:
436       return builder.create<math::ErfOp>(arg.getLoc(), arg);
437     }
438     llvm_unreachable("unsupported unary function");
439   }
440 
441   // Build the binary functions defined by OpDSL.
442   Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
443     bool allComplex = isComplex(arg0) && isComplex(arg1);
444     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
445     bool allInteger = isInteger(arg0) && isInteger(arg1);
446     bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
447                    arg1.getType().getIntOrFloatBitWidth() == 1;
448     if (!allComplex && !allFloatingPoint && !allInteger)
449       llvm_unreachable("unsupported non numeric type");
450     OpBuilder::InsertionGuard g(builder);
451     builder.setInsertionPointToEnd(&block);
452     switch (binaryFn) {
453     case BinaryFn::add:
454       if (allComplex)
455         return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
456       if (allFloatingPoint)
457         return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
458       if (allBool)
459         return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
460       return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
461     case BinaryFn::sub:
462       if (allComplex)
463         return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
464       if (allFloatingPoint)
465         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
466       if (allBool)
467         llvm_unreachable("unsupported operation: sub with bools");
468       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
469     case BinaryFn::mul:
470       if (allComplex)
471         return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
472       if (allFloatingPoint)
473         return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
474       if (allBool)
475         return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
476       return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
477     case BinaryFn::div:
478       if (allComplex)
479         return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
480       if (allFloatingPoint)
481         return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
482       if (allBool)
483         llvm_unreachable("unsupported operation: div with bools");
484       return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
485     case BinaryFn::div_unsigned:
486       if (!allInteger || allBool)
487         llvm_unreachable("unsupported operation: unsigned div not on uint");
488       return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
489     case BinaryFn::max_signed:
490       assert(!allComplex);
491       if (allFloatingPoint)
492         return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
493       return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
494     case BinaryFn::min_signed:
495       assert(!allComplex);
496       if (allFloatingPoint)
497         return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
498       return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
499     case BinaryFn::max_unsigned:
500       assert(!allComplex);
501       if (allFloatingPoint)
502         return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
503       return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
504     case BinaryFn::min_unsigned:
505       assert(!allComplex);
506       if (allFloatingPoint)
507         return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
508       return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
509     case BinaryFn::powf:
510       assert(allFloatingPoint);
511       return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
512     }
513     llvm_unreachable("unsupported binary function");
514   }
515 
516   // Build the ternary functions defined by OpDSL.
517   Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
518                        Value arg2) {
519     bool headBool =
520         isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
521     bool tailFloatingPoint =
522         isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
523     bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
524     OpBuilder::InsertionGuard g(builder);
525     builder.setInsertionPointToEnd(&block);
526     switch (ternaryFn) {
527     case TernaryFn::select:
528       if (!headBool && !(tailFloatingPoint || tailInteger))
529         llvm_unreachable("unsupported non numeric type");
530       return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
531     }
532     llvm_unreachable("unsupported ternary function");
533   }
534 
535   // Build the type functions defined by OpDSL.
536   Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
537     switch (typeFn) {
538     case TypeFn::cast_signed:
539       return cast(toType, operand, false);
540     case TypeFn::cast_unsigned:
541       return cast(toType, operand, true);
542     }
543     llvm_unreachable("unsupported type conversion function");
544   }
545 
546   void yieldOutputs(ValueRange values) {
547     OpBuilder::InsertionGuard g(builder);
548     builder.setInsertionPointToEnd(&block);
549     Location loc = builder.getUnknownLoc();
550     builder.create<YieldOp>(loc, values);
551   }
552 
553   Value constant(const std::string &value) {
554     OpBuilder::InsertionGuard g(builder);
555     builder.setInsertionPointToEnd(&block);
556     Location loc = builder.getUnknownLoc();
557     Attribute valueAttr = parseAttribute(value, builder.getContext());
558     return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
559   }
560 
561   Value index(int64_t dim) {
562     OpBuilder::InsertionGuard g(builder);
563     builder.setInsertionPointToEnd(&block);
564     return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
565   }
566 
567   Type getIntegerType(unsigned width) {
568     return IntegerType::get(builder.getContext(), width);
569   }
570 
571   Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
572   Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
573 
574 private:
575   // Generates operations to cast the given operand to a specified type.
576   // If the cast cannot be performed, a warning will be issued and the
577   // operand returned as-is (which will presumably yield a verification
578   // issue downstream).
579   Value cast(Type toType, Value operand, bool isUnsignedCast) {
580     OpBuilder::InsertionGuard g(builder);
581     builder.setInsertionPointToEnd(&block);
582     auto loc = operand.getLoc();
583     return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
584   }
585 
586   bool isComplex(Value value) {
587     return llvm::isa<ComplexType>(value.getType());
588   }
589   bool isFloatingPoint(Value value) {
590     return llvm::isa<FloatType>(value.getType());
591   }
592   bool isInteger(Value value) {
593     return llvm::isa<IntegerType>(value.getType());
594   }
595 
596   OpBuilder &builder;
597   Block &block;
598 };
599 
600 } // namespace
601 
602 //===----------------------------------------------------------------------===//
603 // CopyOp
604 //===----------------------------------------------------------------------===//
605 
606 namespace {
607 
608 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
609   using OpRewritePattern<CopyOp>::OpRewritePattern;
610   LogicalResult matchAndRewrite(CopyOp copyOp,
611                                 PatternRewriter &rewriter) const override {
612     if (copyOp.getInputs() != copyOp.getOutputs())
613       return rewriter.notifyMatchFailure(copyOp, "not a self copy");
614     if (copyOp.hasPureBufferSemantics())
615       rewriter.eraseOp(copyOp);
616     else
617       rewriter.replaceOp(copyOp, copyOp.getInputs());
618 
619     return success();
620   }
621 };
622 
623 } // namespace
624 
625 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
626                                          MLIRContext *context) {
627   results.add<EraseSelfCopy>(context);
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // FillOp
632 //===----------------------------------------------------------------------===//
633 
634 namespace {
635 
636 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
637 ///
638 /// For such op chains, we can create new linalg.fill ops with the result
639 /// type of the tensor.expand/collapse_shape op.
640 template <typename TensorReshapeOp>
641 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
642   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
643   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
644                                 PatternRewriter &rewriter) const override {
645     auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
646     if (!oldFill)
647       return failure();
648 
649     Location loc = oldFill.getLoc();
650     TensorReshapeOp newInit;
651     if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
652 
653       newInit = rewriter.create<TensorReshapeOp>(
654           loc, reshapeOp.getResultType(), oldFill.output(),
655           reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
656           reshapeOp.getStaticOutputShape());
657     } else {
658       newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
659                                                  oldFill.output(),
660                                                  reshapeOp.getReassociation());
661     }
662     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
663                                         ValueRange{newInit});
664     return success();
665   }
666 };
667 
668 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
669 /// filling value are the same.
670 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
671   using OpRewritePattern::OpRewritePattern;
672 
673   LogicalResult matchAndRewrite(tensor::PadOp padOp,
674                                 PatternRewriter &rewriter) const override {
675     auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
676     if (!fillOp)
677       return failure();
678 
679     // We can only fold if the padding value is the same as the original
680     // filling value.
681     Value padValue = padOp.getConstantPaddingValue();
682     if (!padValue || fillOp.value() != padValue)
683       return failure();
684 
685     ReifiedRankedShapedTypeDims reifiedShape;
686     if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
687       return rewriter.notifyMatchFailure(
688           padOp, "failed to reify tensor.pad op result shape");
689 
690     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
691         padOp.getLoc(), reifiedShape.front(),
692         padOp.getResultType().getElementType());
693     Value replacement =
694         rewriter
695             .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
696                             ValueRange{emptyTensor})
697             .getResult(0);
698     if (replacement.getType() != padOp.getResultType()) {
699       replacement = rewriter.create<tensor::CastOp>(
700           fillOp.getLoc(), padOp.getResultType(), replacement);
701     }
702     rewriter.replaceOp(padOp, replacement);
703     return success();
704   }
705 };
706 
707 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
708 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
709 /// filling value are the same.
710 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
711   using OpRewritePattern::OpRewritePattern;
712 
713   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
714                                 PatternRewriter &rewriter) const override {
715     auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
716     if (!srcPadOp)
717       return failure();
718 
719     if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
720       return failure();
721 
722     // Walk back the tensor.insert_slice chain and find the first destination
723     // value at the start of the chain.
724     Value firstDest = insertOp.getDest();
725     while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
726       if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
727         return failure();
728 
729       // Make sure the range of values accessed are disjoint. Without this, we
730       // cannot fold tensor.pad away.
731       bool disjoint = false;
732       for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
733         // If the dimension has dynamic offset/size, we cannot guarantee
734         // disjoint. So just skip it.
735         if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
736             insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
737             prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
738           continue;
739 
740         // Get the range start and end, inclusively for both.
741         int64_t prevStart = prevOp.getStaticOffset(i);
742         int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
743                                           prevOp.getStaticStride(i);
744         int64_t nextStart = insertOp.getStaticOffset(i);
745         int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
746                                           insertOp.getStaticStride(i);
747         if (prevEnd < nextStart || nextEnd < prevStart) {
748           disjoint = true;
749           break;
750         }
751       }
752 
753       if (!disjoint)
754         break;
755       firstDest = prevOp.getDest();
756     }
757 
758     // Check whether the first destination is a fill op. For overlapped cases,
759     // this also cannot be true.
760     auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
761     if (!dstFillOp)
762       return failure();
763 
764     // We can only fold if the padding value is the same as the original
765     // filling value.
766     Value padValue = srcPadOp.getConstantPaddingValue();
767     if (!padValue || dstFillOp.value() != padValue)
768       return failure();
769 
770     SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
771     SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
772 
773     Location loc = insertOp.getLoc();
774     MLIRContext *context = getContext();
775 
776     AffineExpr sym0, sym1;
777     bindSymbols(context, sym0, sym1);
778     auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
779 
780     // Calculate the new offsets for the insert. It should be the old offsets
781     // plus low padding sizes.
782     SmallVector<OpFoldResult, 4> newOffsets;
783     for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
784       newOffsets.push_back(affine::makeComposedFoldedAffineApply(
785           rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
786     }
787 
788     RankedTensorType srcPadType = srcPadOp.getSourceType();
789     SmallVector<OpFoldResult, 4> newSizes;
790     for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
791       if (srcPadType.isDynamicDim(i)) {
792         newSizes.push_back(
793             rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
794                 .getResult());
795       } else {
796         newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
797       }
798     }
799 
800     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
801         insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
802         newSizes, insertOp.getMixedStrides());
803     return success();
804   }
805 };
806 
807 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
808 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
809 public:
810   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
811 
812   LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
813                                 PatternRewriter &rewriter) const override {
814     // See if tensor input of tensor.extract op is the result of a linalg.fill
815     // op.
816     auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
817     if (!fillOp)
818       return failure();
819 
820     // Get scalar input operand of linalg.fill op.
821     Value extractedScalar = fillOp.getInputs()[0];
822 
823     // Replace tensor.extract op with scalar value used to fill the tensor.
824     rewriter.replaceOp(extractOp, extractedScalar);
825     return success();
826   }
827 };
828 
829 /// Folds pack(fill) into a single fill op if
830 ///   1. The pack op does not have padding value, or
831 ///   2. The filled value and padding value are the same.
832 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
833                                                 tensor::PackOp packOp) {
834   auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
835   if (!fillOp)
836     return failure();
837 
838   if (auto paddingValue = packOp.getPaddingValue())
839     if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
840       return failure();
841 
842   Value packOpDest = packOp.getDest();
843   if (!packOpDest.hasOneUse())
844     return failure();
845 
846   return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
847                                          packOp.getDest());
848 }
849 
850 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
851 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
852 public:
853   FoldFillWithPack(MLIRContext *context)
854       : OpRewritePattern<tensor::PackOp>(context) {}
855 
856   LogicalResult matchAndRewrite(tensor::PackOp packOp,
857                                 PatternRewriter &rewriter) const override {
858     auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
859     if (failed(fillOp))
860       return failure();
861     rewriter.replaceOp(packOp, fillOp.value().result());
862     return success();
863   }
864 };
865 
866 /// Fold fill with copy.
867 struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
868   using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
869 
870   LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
871                                 PatternRewriter &rewriter) const override {
872     if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
873       rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
874                                           fillOp.getInputs(),
875                                           copyOp.getOutputs());
876       return success();
877     }
878     if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
879       rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
880                                                   fillOp.getOutputs());
881       return success();
882     }
883     return failure();
884   }
885 };
886 
887 /// Fold fill with transpose.
888 struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
889   using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
890 
891   LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
892                                 PatternRewriter &rewriter) const override {
893     if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
894       rewriter.replaceOpWithNewOp<FillOp>(
895           transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
896           transposeOp.getDpsInitOperand(0)->get());
897       return success();
898     }
899     return failure();
900   }
901 };
902 
903 /// Fold a concat with all elements being fills of the same value
904 /// into a fill of the concat result shape.
905 struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
906   using OpRewritePattern::OpRewritePattern;
907 
908   LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
909                                 PatternRewriter &rewriter) const override {
910     auto concatOperands = concatOp.getInputs();
911     if (concatOperands.empty()) {
912       return failure();
913     }
914 
915     auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
916     if (!firstFillOp) {
917       return failure();
918     }
919     // Prefetch the fill value.
920     OpFoldResult firstFillVal =
921         getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
922     // Collect all the outs values for the fill operations.
923     SmallVector<Value> allOuts;
924     allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
925 
926     auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
927       auto fillOp = v.getDefiningOp<linalg::FillOp>();
928       if (!fillOp) {
929         return false;
930       }
931 
932       OpFoldResult fillVal =
933           getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
934       if (fillVal != firstFillVal)
935         return false;
936 
937       allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
938       return true;
939     };
940     if (!llvm::all_of(concatOperands.drop_front(),
941                       isDefinedByCompatibleFillOp)) {
942       return rewriter.notifyMatchFailure(
943           concatOp, "not all operands are defined by a compatible fill op");
944     }
945 
946     Value outsConcat = rewriter.create<tensor::ConcatOp>(
947         concatOp.getLoc(), concatOp.getDim(), allOuts);
948     rewriter.replaceOpWithNewOp<linalg::FillOp>(
949         concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
950     return success();
951   }
952 };
953 
954 } // namespace
955 
956 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
957                                          MLIRContext *context) {
958   results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
959               FoldFillWithPack, FoldFillWithPad,
960               FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
961               FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
962               FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
963 }
964 
965 //===----------------------------------------------------------------------===//
966 // GenericOp
967 //===----------------------------------------------------------------------===//
968 
969 static void buildGenericRegion(
970     OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
971     ValueRange outputs,
972     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
973   SmallVector<Type, 4> blockArgTypes;
974   SmallVector<Location, 4> blockArgLocs;
975   for (ValueRange container : {inputs, outputs}) {
976     for (Value v : container) {
977       Type t = v.getType();
978       blockArgTypes.push_back(
979           isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
980       blockArgLocs.push_back(v.getLoc());
981     }
982   }
983 
984   OpBuilder::InsertionGuard guard(builder);
985   Block *bodyBlock =
986       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
987   bodyBuild(builder, loc, bodyBlock->getArguments());
988 }
989 
990 void GenericOp::getAsmBlockArgumentNames(Region &region,
991                                          OpAsmSetValueNameFn setNameFn) {
992   for (Value v : getRegionInputArgs())
993     setNameFn(v, "in");
994   for (Value v : getRegionOutputArgs())
995     setNameFn(v, "out");
996 }
997 
998 void GenericOp::build(
999     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1000     ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1001     ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1002     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1003     ArrayRef<NamedAttribute> attributes) {
1004   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1005         iteratorTypes, doc, libraryCall);
1006   result.addAttributes(attributes);
1007   if (bodyBuild)
1008     buildGenericRegion(builder, result.location, *result.regions.front(),
1009                        inputs, outputs, bodyBuild);
1010 }
1011 
1012 void GenericOp::build(
1013     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1014     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1015     ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1016     StringRef libraryCall,
1017     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1018     ArrayRef<NamedAttribute> attributes) {
1019   build(builder, result, resultTensorTypes, inputs, outputs,
1020         builder.getAffineMapArrayAttr(indexingMaps),
1021         builder.getArrayAttr(llvm::to_vector(llvm::map_range(
1022             iteratorTypes,
1023             [&](utils::IteratorType iter) -> mlir::Attribute {
1024               return IteratorTypeAttr::get(builder.getContext(), iter);
1025             }))),
1026         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1027         libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1028         bodyBuild, attributes);
1029 }
1030 
1031 void GenericOp::build(
1032     OpBuilder &builder, OperationState &result, ValueRange inputs,
1033     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1034     ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1035     StringRef libraryCall,
1036     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1037     ArrayRef<NamedAttribute> attributes) {
1038   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1039         iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1040 }
1041 
1042 void GenericOp::build(
1043     OpBuilder &builder, OperationState &result, ValueRange inputs,
1044     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1045     ArrayRef<utils::IteratorType> iteratorTypes,
1046     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1047     ArrayRef<NamedAttribute> attributes) {
1048   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1049         /*doc=*/"",
1050         /*libraryCall=*/"", bodyBuild, attributes);
1051 }
1052 
1053 void GenericOp::build(
1054     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1055     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1056     ArrayRef<utils::IteratorType> iteratorTypes,
1057     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1058     ArrayRef<NamedAttribute> attributes) {
1059   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1060         iteratorTypes,
1061         /*doc=*/"",
1062         /*libraryCall=*/"", bodyBuild, attributes);
1063 }
1064 
1065 void GenericOp::print(OpAsmPrinter &p) {
1066   p << " ";
1067 
1068   // Print extra attributes.
1069   auto genericAttrNames = linalgTraitAttrNames();
1070 
1071   llvm::StringSet<> genericAttrNamesSet;
1072   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1073   SmallVector<NamedAttribute, 8> genericAttrs;
1074   for (auto attr : (*this)->getAttrs()) {
1075     if (attr.getName() == getIteratorTypesAttrName()) {
1076       auto iteratorTypes =
1077           llvm::cast<ArrayAttr>(attr.getValue())
1078               .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1079       // Convert IteratorType enums into the string representation. This is
1080       // needed, because tests still use the old format when 'iterator_types'
1081       // attribute is represented as an array of strings.
1082       // TODO: Remove this conversion once tests are fixed.
1083       SmallVector<Attribute> iteratorTypeNames =
1084           llvm::to_vector(llvm::map_range(
1085               iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1086                 return StringAttr::get(getContext(), stringifyIteratorType(t));
1087               }));
1088 
1089       genericAttrs.emplace_back(
1090           getIteratorTypesAttrName(),
1091           ArrayAttr::get(getContext(), iteratorTypeNames));
1092     } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1093       genericAttrs.push_back(attr);
1094     }
1095   }
1096   if (!genericAttrs.empty()) {
1097     auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
1098     p << genericDictAttr;
1099   }
1100 
1101   // Printing is shared with named ops, except for the region and attributes
1102   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1103 
1104   genericAttrNames.push_back("operandSegmentSizes");
1105   genericAttrNamesSet.insert(genericAttrNames.back());
1106 
1107   bool hasExtraAttrs = false;
1108   for (NamedAttribute n : (*this)->getAttrs()) {
1109     if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1110       break;
1111   }
1112   if (hasExtraAttrs) {
1113     p << " attrs = ";
1114     p.printOptionalAttrDict((*this)->getAttrs(),
1115                             /*elidedAttrs=*/genericAttrNames);
1116   }
1117 
1118   // Print region.
1119   if (!getRegion().empty()) {
1120     p << ' ';
1121     p.printRegion(getRegion());
1122   }
1123 
1124   // Print results.
1125   printNamedStructuredOpResults(p, getResultTensors().getTypes());
1126 }
1127 
1128 ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1129   DictionaryAttr dictAttr;
1130   // Parse the core linalg traits that must check into a dictAttr.
1131   // The name is unimportant as we will overwrite result.attributes.
1132   // The core linalg traits must contain the information necessary to pass the
1133   // verifier.
1134   llvm::SMLoc attributeLocation = parser.getCurrentLocation();
1135   if (parser.parseAttribute(dictAttr, "_", result.attributes))
1136     return failure();
1137   result.attributes.assign(dictAttr.getValue().begin(),
1138                            dictAttr.getValue().end());
1139 
1140   // Convert array of string into an array of IteratorType enums. This is
1141   // needed, because tests still use the old format when 'iterator_types'
1142   // attribute is represented as an array of strings.
1143   // TODO: Remove this conversion once tests are fixed.
1144   auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1145       result.attributes.get(getIteratorTypesAttrName(result.name)));
1146   if (!iteratorTypes) {
1147     return parser.emitError(attributeLocation)
1148            << "expected " << getIteratorTypesAttrName(result.name)
1149            << " array attribute";
1150   }
1151 
1152   SmallVector<Attribute> iteratorTypeAttrs;
1153 
1154   for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1155     auto maybeIteratorType = utils::symbolizeIteratorType(s);
1156     if (!maybeIteratorType.has_value())
1157       return parser.emitError(parser.getCurrentLocation())
1158              << "unexpected iterator_type (" << s << ")";
1159 
1160     iteratorTypeAttrs.push_back(
1161         IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
1162   }
1163   result.attributes.set(getIteratorTypesAttrName(result.name),
1164                         parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
1165 
1166   // Parsing is shared with named ops, except for the region.
1167   SmallVector<Type, 1> inputTypes, outputTypes;
1168   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
1169     return failure();
1170 
1171   // Optional attributes may be added.
1172   if (succeeded(parser.parseOptionalKeyword("attrs")))
1173     if (failed(parser.parseEqual()) ||
1174         failed(parser.parseOptionalAttrDict(result.attributes)))
1175       return failure();
1176 
1177   std::unique_ptr<Region> region = std::make_unique<Region>();
1178   if (parser.parseRegion(*region, {}))
1179     return failure();
1180   result.addRegion(std::move(region));
1181 
1182   // Generic ops may specify that a subset of its outputs are tensors. Such
1183   // outputs are specified in the result type.
1184   // TODO: may need to move output parsing before region parsing.
1185   // Need to wait for declarative assembly resolution to decide.
1186   SmallVector<Type, 1> outputTensorsTypes;
1187   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
1188     return failure();
1189   result.addTypes(outputTensorsTypes);
1190 
1191   return success();
1192 }
1193 
1194 static void getGenericEffectsImpl(
1195     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1196         &effects,
1197     LinalgOp linalgOp) {
1198   for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1199     if (!llvm::isa<MemRefType>(operand.getType()))
1200       continue;
1201     effects.emplace_back(
1202         MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1203         /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1204   }
1205 
1206   for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1207     if (!llvm::isa<MemRefType>(operand.get().getType()))
1208       continue;
1209     if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1210       effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1211                            /*effectOnFullRegion=*/true,
1212                            SideEffects::DefaultResource::get());
1213     }
1214     effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1215                          /*effectOnFullRegion=*/true,
1216                          SideEffects::DefaultResource::get());
1217   }
1218 }
1219 
1220 void GenericOp::getEffects(
1221     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1222         &effects) {
1223   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1224 }
1225 
1226 static Speculation::Speculatability
1227 getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
1228   // Operands with value semantics are speculatable, while operands with memory
1229   // semantics are not.
1230   if (!linalgOp.hasPureTensorSemantics())
1231     return Speculation::NotSpeculatable;
1232   // The body of the op can still have speculation in its region.
1233   return Speculation::RecursivelySpeculatable;
1234 }
1235 
1236 Speculation::Speculatability GenericOp::getSpeculatability() {
1237   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1238 }
1239 
1240 LogicalResult GenericOp::verify() { return success(); }
1241 
1242 namespace {
1243 
1244 /// Remove any linalg operation (on tensors) that are just copying
1245 /// the values from inputs to the results. Requirements are
1246 /// 1) All iterator types are parallel
1247 /// 2) The body contains just a yield operation with the yielded values being
1248 ///    the arguments corresponding to the operands.
1249 template <typename OpTy>
1250 struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1251   using OpRewritePattern<OpTy>::OpRewritePattern;
1252 
1253   LogicalResult matchAndRewrite(OpTy linalgOp,
1254                                 PatternRewriter &rewriter) const override {
1255     // All indexing maps must be equal. It follows that they are permutations.
1256     if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1257       return failure();
1258 
1259     // Check that the body of the linalg operation is just a linalg.yield
1260     // operation.
1261     Block &body = linalgOp->getRegion(0).front();
1262     if (!llvm::hasSingleElement(body))
1263       return failure();
1264     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1265     if (!yieldOp)
1266       return failure();
1267 
1268     // In the buffer case, we need to check exact buffer equality.
1269     if (linalgOp.hasPureBufferSemantics()) {
1270       if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1271           linalgOp.getDpsInputOperand(0)->get() ==
1272               linalgOp.getDpsInitOperand(0)->get()) {
1273         rewriter.eraseOp(linalgOp);
1274         return success();
1275       }
1276       return failure();
1277     }
1278 
1279     // Mixed semantics is not supported yet.
1280     if (!linalgOp.hasPureTensorSemantics())
1281       return failure();
1282 
1283     // Get the argument number of the returned values. That is the operand
1284     // number to use for replacing uses of this operation.
1285     SmallVector<Value> returnedArgs;
1286     for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1287       auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1288       if (!yieldArg || yieldArg.getOwner() != &body)
1289         return failure();
1290       unsigned argumentNumber = yieldArg.getArgNumber();
1291       Value returnedArg = linalgOp->getOperand(argumentNumber);
1292       Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1293       // The input can have a different type than the result, e.g. a dynamic
1294       // input dimension can be turned into a static output dimension.
1295       Type returnType = returnedArg.getType();
1296       if (returnType != resultType) {
1297         // Distinguish between sparse conversion or dense tensor casting.
1298         // TODO: unify the two ops?
1299         if (sparse_tensor::getSparseTensorEncoding(returnType) ||
1300             sparse_tensor::getSparseTensorEncoding(resultType))
1301           returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
1302               linalgOp.getLoc(), resultType, returnedArg);
1303         else {
1304           if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1305                                                  resultType))
1306             return failure();
1307           returnedArg = rewriter.create<tensor::CastOp>(
1308               linalgOp.getLoc(), resultType, returnedArg);
1309         }
1310       }
1311       returnedArgs.push_back(returnedArg);
1312     }
1313 
1314     if (returnedArgs.size() != linalgOp->getNumResults())
1315       return failure();
1316     rewriter.replaceOp(linalgOp, returnedArgs);
1317     return success();
1318   }
1319 };
1320 
1321 } // namespace
1322 
1323 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1324                                             MLIRContext *context) {
1325   results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1326 }
1327 
1328 LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1329   return memref::foldMemRefCast(*this);
1330 }
1331 
1332 //===----------------------------------------------------------------------===//
1333 // MapOp
1334 //===----------------------------------------------------------------------===//
1335 
1336 static ParseResult parseDstStyleOp(
1337     OpAsmParser &parser, OperationState &result,
1338     function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1339         nullptr) {
1340   // Parse `ins` and `outs`.
1341   SmallVector<Type, 4> inputTypes, outputTypes;
1342   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1343                                    /*addOperandSegmentSizes=*/false))
1344     return failure();
1345 
1346   // Add result types.
1347   for (Type outputType : outputTypes) {
1348     if (llvm::isa<RankedTensorType>(outputType))
1349       result.addTypes(outputType);
1350   }
1351 
1352   // Parse required attributes.
1353   if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1354     return failure();
1355 
1356   // Parse optional attributes.
1357   if (parser.parseOptionalAttrDict(result.attributes))
1358     return failure();
1359   return success();
1360 }
1361 
1362 void MapOp::getAsmBlockArgumentNames(Region &region,
1363                                      OpAsmSetValueNameFn setNameFn) {
1364   for (Value v : getRegionInputArgs())
1365     setNameFn(v, "in");
1366 }
1367 
1368 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1369   if (!getResults().empty())
1370     setNameFn(getResults().front(), "mapped");
1371 }
1372 
1373 void MapOp::build(
1374     OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1375     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1376     ArrayRef<NamedAttribute> attributes) {
1377   build(builder, result, TypeRange{}, inputs, init);
1378   result.addAttributes(attributes);
1379 
1380   // Add output types for `RankedTensorType` output arguments.
1381   Type initType = init.getType();
1382   if (llvm::isa<RankedTensorType>(initType))
1383     result.addTypes(initType);
1384 
1385   if (bodyBuild)
1386     buildGenericRegion(builder, result.location, *result.regions.front(),
1387                        inputs, /*outputs=*/{}, bodyBuild);
1388 }
1389 
1390 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1391                                  const OperationName &payloadOpName,
1392                                  const NamedAttrList &payloadOpAttrs,
1393                                  ArrayRef<Value> operands,
1394                                  bool initFirst = false) {
1395   OpBuilder b(parser.getContext());
1396   Region *body = result.addRegion();
1397   Block &block = body->emplaceBlock();
1398   b.setInsertionPointToStart(&block);
1399   SmallVector<Value> bbArgs;
1400   for (auto &operand : operands) {
1401     block.addArgument(
1402         llvm::cast<ShapedType>(operand.getType()).getElementType(),
1403         b.getUnknownLoc());
1404   }
1405   SmallVector<Value> payloadOpOperands;
1406   // If initFirst flag is enabled, we consider init as the first position of
1407   // payload operands.
1408   if (initFirst) {
1409     payloadOpOperands.push_back(block.getArguments().back());
1410     for (const auto &arg : block.getArguments().drop_back())
1411       payloadOpOperands.push_back(arg);
1412   } else {
1413     payloadOpOperands = {block.getArguments().begin(),
1414                          block.getArguments().end()};
1415   }
1416 
1417   Operation *payloadOp = b.create(
1418       result.location, b.getStringAttr(payloadOpName.getStringRef()),
1419       payloadOpOperands,
1420       TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1421                     .getElementType()},
1422       payloadOpAttrs);
1423   b.create<YieldOp>(result.location, payloadOp->getResults());
1424 }
1425 
1426 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1427   std::optional<OperationName> payloadOpName;
1428   NamedAttrList payloadOpAttrs;
1429   if (succeeded(parser.parseOptionalLBrace())) {
1430     FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1431     if (failed(operationName))
1432       return failure();
1433     if (parser.parseOptionalAttrDict(payloadOpAttrs))
1434       return failure();
1435     payloadOpName = operationName.value();
1436     if (parser.parseRBrace())
1437       return failure();
1438   }
1439 
1440   if (parseDstStyleOp(parser, result))
1441     return failure();
1442 
1443   if (payloadOpName.has_value()) {
1444     if (!result.operands.empty())
1445       addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1446                            payloadOpAttrs,
1447                            ArrayRef(result.operands).drop_back());
1448     else
1449       result.addRegion();
1450   } else {
1451     SmallVector<OpAsmParser::Argument> regionArgs;
1452     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1453                                  /*allowType=*/true, /*allowAttrs=*/true)) {
1454       return failure();
1455     }
1456     Region *body = result.addRegion();
1457     if (parser.parseRegion(*body, regionArgs))
1458       return failure();
1459   }
1460   return success();
1461 }
1462 
1463 // Retrieve the operation from the body, if it is the only one (except
1464 // yield) and if it gets the same amount of arguments as the body does.
1465 // If initFirst flag is enabled, we check that init takes the first position in
1466 // operands of payload.
1467 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1468   if (body->getOperations().size() != 2)
1469     return nullptr;
1470   Operation &payload = body->getOperations().front();
1471   assert(isa<YieldOp>(body->getOperations().back()));
1472 
1473   if (payload.getNumOperands() == 0 ||
1474       payload.getNumOperands() != body->getNumArguments())
1475     return nullptr;
1476   if (initFirst) {
1477     // check init
1478     if (payload.getOperands().back() != body->getArgument(0))
1479       return nullptr;
1480     // check rest
1481     for (const auto &[operand, bbArg] :
1482          llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1483       if (bbArg != operand)
1484         return nullptr;
1485     }
1486   } else {
1487     for (const auto &[operand, bbArg] :
1488          llvm::zip(payload.getOperands(), body->getArguments())) {
1489       if (bbArg != operand)
1490         return nullptr;
1491     }
1492   }
1493   return &payload;
1494 }
1495 
1496 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1497   SmallVector<StringRef> elidedAttrs;
1498   std::string attrToElide;
1499   p << " { " << payloadOp->getName().getStringRef();
1500   for (const auto &attr : payloadOp->getAttrs()) {
1501     auto fastAttr =
1502         llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1503     if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1504       attrToElide = attr.getName().str();
1505       elidedAttrs.push_back(attrToElide);
1506       break;
1507     }
1508   }
1509   p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1510   p << " }";
1511 }
1512 
1513 void MapOp::print(OpAsmPrinter &p) {
1514   Block *mapper = getBody();
1515   Operation *payloadOp = findPayloadOp(mapper);
1516   if (payloadOp) {
1517     printShortForm(p, payloadOp);
1518   }
1519 
1520   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1521   p.printOptionalAttrDict((*this)->getAttrs());
1522 
1523   if (!payloadOp) {
1524     // Print region if the payload op was not detected.
1525     p.increaseIndent();
1526     p.printNewline();
1527     p << "(";
1528     llvm::interleaveComma(mapper->getArguments(), p,
1529                           [&](auto arg) { p.printRegionArgument(arg); });
1530     p << ") ";
1531 
1532     p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1533     p.decreaseIndent();
1534   }
1535 }
1536 
1537 LogicalResult MapOp::verify() {
1538   auto *bodyBlock = getBody();
1539   auto blockArgs = bodyBlock->getArguments();
1540 
1541   // Checks if the number of `inputs` match the arity of the `mapper` region.
1542   if (getInputs().size() != blockArgs.size())
1543     return emitOpError() << "expects number of operands to match the arity of "
1544                             "mapper, but got: "
1545                          << getInputs().size() << " and " << blockArgs.size();
1546 
1547   // The parameters of mapper should all match the element type of inputs.
1548   for (const auto &[bbArgType, inputArg] :
1549        llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1550     auto inputElemType =
1551         llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1552     if (bbArgType != inputElemType) {
1553       return emitOpError() << "expected element type of input " << inputElemType
1554                            << " to match bbArg type " << bbArgType;
1555     }
1556   }
1557 
1558   // The shape of each input must match the shape of the output.
1559   auto outputShape = getInit().getType().getShape();
1560   for (Type inputArgType : TypeRange{getInputs()}) {
1561     auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1562     if (inputElemShape != outputShape) {
1563       return emitOpError() << "expected shape of input (" << inputElemShape
1564                            << ") to match shape of output (" << outputShape
1565                            << ")";
1566     }
1567   }
1568 
1569   return success();
1570 }
1571 
1572 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1573   int64_t rank = getInit().getType().getRank();
1574   return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1575 }
1576 
1577 ArrayAttr MapOp::getIndexingMaps() {
1578   Builder builder(getContext());
1579   int64_t rank = getInit().getType().getRank();
1580   int64_t numIndexingMaps = getOperands().size();
1581   return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1582       numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1583 }
1584 
1585 void MapOp::getEffects(
1586     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1587         &effects) {
1588   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1589 }
1590 
1591 Speculation::Speculatability MapOp::getSpeculatability() {
1592   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1593 }
1594 
1595 //===----------------------------------------------------------------------===//
1596 // ReduceOp
1597 //===----------------------------------------------------------------------===//
1598 
1599 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1600                                         OpAsmSetValueNameFn setNameFn) {
1601   for (Value v : getRegionInputArgs())
1602     setNameFn(v, "in");
1603   for (Value v : getRegionOutputArgs())
1604     setNameFn(v, "init");
1605 }
1606 
1607 void ReduceOp::getAsmResultNames(
1608     function_ref<void(Value, StringRef)> setNameFn) {
1609   if (!getResults().empty())
1610     setNameFn(getResults().front(), "reduced");
1611 }
1612 
1613 void ReduceOp::build(
1614     OpBuilder &builder, OperationState &result, ValueRange inputs,
1615     ValueRange inits, ArrayRef<int64_t> dimensions,
1616     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1617     ArrayRef<NamedAttribute> attributes) {
1618   build(builder, result, TypeRange{}, inputs, inits, dimensions);
1619   result.addAttributes(attributes);
1620 
1621   // Add output types for `RankedTensorType` output arguments.
1622   for (Value init : inits) {
1623     Type initType = init.getType();
1624     if (llvm::isa<RankedTensorType>(initType))
1625       result.addTypes(initType);
1626   }
1627 
1628   if (bodyBuild)
1629     buildGenericRegion(builder, result.location, *result.regions.front(),
1630                        inputs, inits, bodyBuild);
1631 }
1632 
1633 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1634   int64_t inputRank =
1635       llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1636   SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1637                                                  utils::IteratorType::parallel);
1638   for (int64_t reductionDim : getDimensions())
1639     iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1640   return iteratorTypes;
1641 }
1642 
1643 ArrayAttr ReduceOp::getIndexingMaps() {
1644   int64_t inputRank =
1645       llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1646   SmallVector<AffineMap> affineMaps(
1647       getNumDpsInputs(),
1648       AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1649   AffineMap resultMap =
1650       AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1651           .dropResults(getDimensions());
1652   for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1653     affineMaps.push_back(resultMap);
1654   return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1655 }
1656 
1657 void ReduceOp::getEffects(
1658     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1659         &effects) {
1660   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1661 }
1662 
1663 Speculation::Speculatability ReduceOp::getSpeculatability() {
1664   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1665 }
1666 
1667 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1668                                           NamedAttrList &attributes,
1669                                           StringRef attributeName) {
1670   if (parser.parseKeyword(attributeName) || parser.parseEqual())
1671     return failure();
1672 
1673   attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1674   return success();
1675 }
1676 
1677 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1678   std::optional<OperationName> payloadOpName;
1679   NamedAttrList payloadOpAttrs;
1680   if (succeeded(parser.parseOptionalLBrace())) {
1681     FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1682     if (failed(operationName))
1683       return failure();
1684     if (parser.parseOptionalAttrDict(payloadOpAttrs))
1685       return failure();
1686     payloadOpName = operationName.value();
1687     if (parser.parseRBrace())
1688       return failure();
1689   }
1690 
1691   if (parseDstStyleOp(
1692           parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1693             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1694           }))
1695     return failure();
1696 
1697   if (payloadOpName.has_value()) {
1698     addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1699                          ArrayRef(result.operands), /*initFirst=*/true);
1700   } else {
1701     SmallVector<OpAsmParser::Argument> regionArgs;
1702     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1703                                  /*allowType=*/true, /*allowAttrs=*/true)) {
1704       return failure();
1705     }
1706 
1707     Region *body = result.addRegion();
1708     if (parser.parseRegion(*body, regionArgs))
1709       return failure();
1710   }
1711 
1712   return success();
1713 }
1714 
1715 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1716                                    ArrayRef<int64_t> attributeValue) {
1717   p << ' ' << attributeName << " = [" << attributeValue << "] ";
1718 }
1719 
1720 void ReduceOp::print(OpAsmPrinter &p) {
1721   Block *mapper = getBody();
1722   Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1723   if (payloadOp) {
1724     printShortForm(p, payloadOp);
1725   }
1726 
1727   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1728   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1729   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1730   if (!payloadOp) {
1731     // Print region if the payload op was not detected.
1732     p.increaseIndent();
1733     p.printNewline();
1734     p << "(";
1735     llvm::interleaveComma(mapper->getArguments(), p,
1736                           [&](auto arg) { p.printRegionArgument(arg); });
1737     p << ") ";
1738 
1739     p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1740     p.decreaseIndent();
1741   }
1742 }
1743 
1744 LogicalResult ReduceOp::verify() {
1745   ArrayRef<int64_t> dimensionsRef = getDimensions();
1746 
1747   for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1748     if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1749         llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1750       return emitOpError() << "expects all inputs to have the same shapes. "
1751                               "Shape at input-index "
1752                            << i
1753                            << " is not equal to the shape at input-index 0.";
1754     }
1755   }
1756   for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1757     if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1758         llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1759       return emitOpError() << "expects all outputs to have the same shapes. "
1760                               "Shape at output-index "
1761                            << i
1762                            << " is not equal to the shape at output-index 0.";
1763     }
1764   }
1765   auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1766   auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1767 
1768   DenseSet<int64_t> dimensionsToReduce;
1769   for (int64_t dimension : dimensionsRef) {
1770     if (dimension < 0 || dimension >= inputType.getRank()) {
1771       return emitOpError()
1772              << "dimensions for reduction should be in the range [0, "
1773              << inputType.getRank() - 1 << "].";
1774     }
1775     dimensionsToReduce.insert(dimension);
1776   }
1777 
1778   auto inputDims = inputType.getShape();
1779   auto initDims = initType.getShape();
1780 
1781   // Input dimensions that will be left after the reduction.
1782   SmallVector<int64_t> reducedInputDims;
1783   for (const auto &en : llvm::enumerate(inputDims)) {
1784     if (!dimensionsToReduce.count(en.index()))
1785       reducedInputDims.push_back(en.value());
1786   }
1787 
1788   if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1789     return emitOpError() << "number of dimensions after reduction "
1790                          << reducedInputDims.size()
1791                          << " doesn't match the init rank "
1792                          << initType.getRank();
1793   }
1794 
1795   if (reducedInputDims != initDims)
1796     return emitOpError() << "init dimensions [" << initDims
1797                          << "] doesn't match input dimensions after reduction ["
1798                          << reducedInputDims << "]";
1799 
1800   Block *block = getBody();
1801   if (block->getNumArguments() != this->getNumOperands())
1802     return emitOpError()
1803            << "mismatching number of operands and block arguments";
1804 
1805   // Check that the first block arguments match the element type of the inputs.
1806   for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1807     Type inputElementType =
1808         llvm::cast<ShapedType>(input.getType()).getElementType();
1809     if (inputElementType != bbArg.getType())
1810       return emitOpError()
1811              << "input element type " << inputElementType
1812              << " does not match corresponding block argument type "
1813              << bbArg.getType();
1814   }
1815 
1816   // Check that the last block arguments match the element type of the outputs.
1817   for (auto [output, bbArg] : llvm::zip(
1818            getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1819     auto outputElementType =
1820         llvm::cast<ShapedType>(output.getType()).getElementType();
1821     if (outputElementType != bbArg.getType())
1822       return emitOpError()
1823              << "output element type " << outputElementType
1824              << " does not match corresponding block argument type "
1825              << bbArg.getType();
1826   }
1827   return success();
1828 }
1829 
1830 //===----------------------------------------------------------------------===//
1831 // TransposeOp
1832 //===----------------------------------------------------------------------===//
1833 
1834 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1835                                 Region &region, ValueRange inputs,
1836                                 ValueRange outputs) {
1837   buildGenericRegion(builder, loc, region, inputs, outputs,
1838                      [](OpBuilder &b, Location loc, ValueRange args) {
1839                        if (!args.empty())
1840                          b.create<linalg::YieldOp>(loc, args[0]);
1841                      });
1842 }
1843 
1844 void TransposeOp::build(::mlir::OpBuilder &builder,
1845                         ::mlir::OperationState &result, Value input, Value init,
1846                         DenseI64ArrayAttr permutation,
1847                         ArrayRef<NamedAttribute> attributes) {
1848   result.addOperands(input);
1849   result.addOperands(init);
1850   result.addAttribute(getPermutationAttrName(result.name), permutation);
1851   result.addAttributes(attributes);
1852 
1853   // Add output types for `RankedTensorType` output arguments.
1854   Type initType = init.getType();
1855   if (llvm::isa<RankedTensorType>(initType))
1856     result.addTypes(initType);
1857 
1858   buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1859                       init);
1860 }
1861 
1862 void TransposeOp::build(::mlir::OpBuilder &builder,
1863                         ::mlir::OperationState &result, Value input, Value init,
1864                         ArrayRef<int64_t> permutation,
1865                         ArrayRef<NamedAttribute> attributes) {
1866   build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1867         attributes);
1868 }
1869 
1870 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1871   if (failed(parseDstStyleOp(
1872           parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1873             return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1874           })))
1875     return failure();
1876 
1877   OpBuilder builder(parser.getContext());
1878   buildIdentityRegion(builder, result.location, *result.addRegion(),
1879                       /*inputs=*/result.operands,
1880                       /*outputs=*/{});
1881   return success();
1882 }
1883 
1884 void TransposeOp::getAsmResultNames(
1885     function_ref<void(Value, StringRef)> setNameFn) {
1886   if (!getResults().empty())
1887     setNameFn(getResults().front(), "transposed");
1888 }
1889 
1890 void TransposeOp::print(OpAsmPrinter &p) {
1891   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1892   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1893   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1894 }
1895 
1896 LogicalResult TransposeOp::verify() {
1897   ArrayRef<int64_t> permutationRef = getPermutation();
1898 
1899   if (!isPermutationVector(permutationRef))
1900     return emitOpError("permutation is not valid");
1901 
1902   auto inputType = getInput().getType();
1903   auto initType = getInit().getType();
1904 
1905   int64_t rank = inputType.getRank();
1906 
1907   if (rank != initType.getRank())
1908     return emitOpError() << "input rank " << rank
1909                          << " does not match init rank " << initType.getRank();
1910 
1911   if (rank != static_cast<int64_t>(permutationRef.size()))
1912     return emitOpError() << "size of permutation " << permutationRef.size()
1913                          << " does not match the argument rank " << rank;
1914 
1915   auto inputDims = inputType.getShape();
1916   auto initDims = initType.getShape();
1917 
1918   for (int64_t i = 0; i < rank; ++i) {
1919     int64_t inputDim = inputDims[permutationRef[i]];
1920     int64_t initDim = initDims[i];
1921 
1922     if (inputDim != initDim) {
1923       return emitOpError() << "dim(result, " << i << ") = " << initDim
1924                            << " doesn't match dim(input, permutation[" << i
1925                            << "]) = " << inputDim;
1926     }
1927   }
1928 
1929   return success();
1930 }
1931 
1932 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1933   int64_t rank = getInit().getType().getRank();
1934   return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1935 }
1936 
1937 ArrayAttr TransposeOp::getIndexingMaps() {
1938   Builder builder(getContext());
1939   int64_t rank = getInit().getType().getRank();
1940   return builder.getAffineMapArrayAttr(
1941       {inversePermutation(AffineMap::getPermutationMap(
1942            llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1943        builder.getMultiDimIdentityMap(rank)});
1944 }
1945 
1946 void TransposeOp::getEffects(
1947     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1948         &effects) {
1949   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1950 }
1951 
1952 Speculation::Speculatability TransposeOp::getSpeculatability() {
1953   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1954 }
1955 
1956 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1957                                 SmallVectorImpl<OpFoldResult> &result) {
1958   // Only the tensor type is supported.
1959   if (!isa<TensorType>(getInput().getType()))
1960     return failure();
1961 
1962   // Single dimension transpose.
1963   if (getPermutation().size() == 0) {
1964     result.push_back(getInput());
1965     return success();
1966   }
1967   // Identity permutation.
1968   if (isIdentityPermutation(getPermutation())) {
1969     result.push_back(getInput());
1970     return success();
1971   }
1972 
1973   return failure();
1974 }
1975 
1976 /// Fold transpose with transpose.
1977 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1978   using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1979 
1980   LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1981                                 PatternRewriter &rewriter) const override {
1982     auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1983     if (!defTransposeOp)
1984       return failure();
1985     ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1986     ArrayRef<int64_t> perms = transposeOp.getPermutation();
1987     SmallVector<int64_t> foldedPerms;
1988     foldedPerms.reserve(perms.size());
1989     for (int64_t perm : perms)
1990       foldedPerms.push_back(defPerms[perm]);
1991 
1992     rewriter.replaceOpWithNewOp<TransposeOp>(
1993         transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1994         foldedPerms);
1995     return success();
1996   }
1997 };
1998 
1999 /// This pattern canonicalize transpose by swapping the order of
2000 /// broadcast and transpose:
2001 ///   transpose(broadcast(input)) -> broadcast(transpose(input))
2002 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2003   using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2004 
2005   LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2006                                 PatternRewriter &rewriter) const override {
2007     Value input = transposeOp.getInput();
2008     BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2009     if (!input.hasOneUse() || !broadcastOp)
2010       return failure();
2011 
2012     ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2013     ArrayRef<int64_t> perms = transposeOp.getPermutation();
2014 
2015     // Get new perms and new dimensions.
2016     SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2017     SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2018     SmallVector<int64_t> resultDimensions;
2019     unsigned dimensionSize = dimensions.size();
2020     for (unsigned i = 0; i < dimensionSize; ++i)
2021       resultDimensions.push_back(invertPerm[dimensions[i]]);
2022 
2023     // Create transpose result.
2024     Value broadcastInput = broadcastOp.getInput();
2025     Location loc = transposeOp.getLoc();
2026     MLIRContext *ctx = transposeOp.getContext();
2027     SmallVector<OpFoldResult> dims;
2028     auto broadcastInputTy =
2029         mlir::cast<RankedTensorType>(broadcastInput.getType());
2030     unsigned inputRank = broadcastInputTy.getRank();
2031     for (unsigned i = 0; i < inputRank; ++i) {
2032       if (broadcastInputTy.isDynamicDim(i)) {
2033         dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2034                            ->getResult(0));
2035       } else {
2036         dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2037                                         broadcastInputTy.getDimSize(i)));
2038       }
2039     }
2040     SmallVector<OpFoldResult> transposeResultShapes =
2041         applyPermutation(dims, resultPerms);
2042     Value transposeInit = rewriter.create<tensor::EmptyOp>(
2043         transposeOp.getLoc(), transposeResultShapes,
2044         broadcastInputTy.getElementType());
2045 
2046     // Create broadcast(transpose(input)).
2047     Value transposeResult =
2048         rewriter
2049             .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2050                                  resultPerms)
2051             ->getResult(0);
2052     rewriter.replaceOpWithNewOp<BroadcastOp>(
2053         transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2054     return success();
2055   }
2056 };
2057 
2058 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2059                                               MLIRContext *context) {
2060   results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2061 }
2062 
2063 //===----------------------------------------------------------------------===//
2064 // BroadcastOp
2065 //===----------------------------------------------------------------------===//
2066 
2067 void BroadcastOp::build(::mlir::OpBuilder &builder,
2068                         ::mlir::OperationState &result, Value input, Value init,
2069                         DenseI64ArrayAttr dimensions,
2070                         ArrayRef<NamedAttribute> attributes) {
2071   result.addOperands(input);
2072   result.addOperands(init);
2073   result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2074   result.addAttributes(attributes);
2075 
2076   // Add output types for `RankedTensorType` output arguments.
2077   Type initType = init.getType();
2078   if (llvm::isa<RankedTensorType>(initType))
2079     result.addTypes(initType);
2080 
2081   buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2082                       init);
2083 }
2084 
2085 void BroadcastOp::build(::mlir::OpBuilder &builder,
2086                         ::mlir::OperationState &result, Value input, Value init,
2087                         ArrayRef<int64_t> dimensions,
2088                         ArrayRef<NamedAttribute> attributes) {
2089   build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2090         attributes);
2091 }
2092 
2093 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2094   if (failed(parseDstStyleOp(
2095           parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2096             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2097           })))
2098     return failure();
2099 
2100   OpBuilder builder(parser.getContext());
2101   buildIdentityRegion(builder, result.location, *result.addRegion(),
2102                       /*inputs=*/result.operands,
2103                       /*outputs=*/{});
2104   return success();
2105 }
2106 
2107 void BroadcastOp::getAsmResultNames(
2108     function_ref<void(Value, StringRef)> setNameFn) {
2109   if (!getResults().empty())
2110     setNameFn(getResults().front(), "broadcasted");
2111 }
2112 
2113 void BroadcastOp::print(OpAsmPrinter &p) {
2114   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2115   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2116   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2117 }
2118 
2119 LogicalResult BroadcastOp::verify() {
2120   ArrayRef<int64_t> dimensionsRef = getDimensions();
2121 
2122   auto inputType = getInput().getType();
2123   auto initType = getInit().getType();
2124 
2125   int64_t inputRank = inputType.getRank();
2126   int64_t initRank = initType.getRank();
2127 
2128   auto inputShape = inputType.getShape();
2129   auto initShape = initType.getShape();
2130 
2131   if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2132     return emitOpError() << "input rank plus added dimensions does not "
2133                             "match init rank. input rank: "
2134                          << inputRank
2135                          << ", dimensions size: " << dimensionsRef.size()
2136                          << ", init rank: " << initRank;
2137 
2138   for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2139     if (dim < 0 || dim >= initRank)
2140       return emitOpError() << "dimension " << idx
2141                            << " is out of range. expected range: [0, "
2142                            << initRank - 1 << "], got: " << dim;
2143   }
2144 
2145   // Mapping from input dims to init dims.
2146   SmallVector<int64_t> dimMap;
2147   for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2148     if (!llvm::is_contained(dimensionsRef, dim))
2149       dimMap.push_back(dim);
2150   }
2151 
2152   for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2153     // This dimensions is mapped from the input. Init and input dims should
2154     // match.
2155     if (inputShape[inputDimIdx] != initShape[initDimIdx])
2156       return emitOpError() << "input dim " << inputDimIdx
2157                            << " should match init dim " << initDimIdx
2158                            << ". input: " << inputShape[inputDimIdx]
2159                            << ", init: " << initShape[initDimIdx];
2160   }
2161 
2162   return success();
2163 }
2164 
2165 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2166   int64_t rank = getInit().getType().getRank();
2167   return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2168 }
2169 
2170 ArrayAttr BroadcastOp::getIndexingMaps() {
2171   Builder builder(getContext());
2172   int64_t rank = getInit().getType().getRank();
2173   return builder.getAffineMapArrayAttr(
2174       {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2175        builder.getMultiDimIdentityMap(rank)});
2176 }
2177 
2178 void BroadcastOp::getEffects(
2179     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2180         &effects) {
2181   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2182 }
2183 
2184 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2185   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2186 }
2187 
2188 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2189                                               MLIRContext *context) {
2190   results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2191 }
2192 
2193 //===----------------------------------------------------------------------===//
2194 // YieldOp
2195 //===----------------------------------------------------------------------===//
2196 
2197 void linalg::YieldOp::print(OpAsmPrinter &p) {
2198   if (getNumOperands() > 0)
2199     p << ' ' << getOperands();
2200   p.printOptionalAttrDict((*this)->getAttrs());
2201   if (getNumOperands() > 0)
2202     p << " : " << getOperandTypes();
2203 }
2204 
2205 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2206   SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2207   SmallVector<Type, 2> types;
2208   SMLoc loc = parser.getCurrentLocation();
2209   return failure(parser.parseOperandList(opInfo) ||
2210                  parser.parseOptionalAttrDict(result.attributes) ||
2211                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2212                  parser.resolveOperands(opInfo, types, loc, result.operands));
2213 }
2214 
2215 // Check the operand number and types must match the element types of the
2216 // LinalgOp interface's shaped operands.
2217 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2218   if (op.getNumOperands() != linalgOp.getNumDpsInits())
2219     return op.emitOpError("expected number of yield values (")
2220            << op.getNumOperands()
2221            << ") to match the number of inits / outs operands of the enclosing "
2222            << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2223 
2224   for (OpOperand &opOperand : op->getOpOperands()) {
2225     OpOperand *outputOperand =
2226         linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2227     Type elementType = outputOperand->get().getType();
2228     if (isa<MemRefType, RankedTensorType>(elementType))
2229       elementType = getElementTypeOrSelf(outputOperand->get().getType());
2230     if (opOperand.get().getType() != elementType)
2231       return op.emitOpError("type of yield operand ")
2232              << (opOperand.getOperandNumber() + 1) << " ("
2233              << opOperand.get().getType() << ") doesn't match "
2234              << "the element type of the enclosing linalg.generic op ("
2235              << elementType << ")";
2236   }
2237   return success();
2238 }
2239 
2240 LogicalResult linalg::YieldOp::verify() {
2241   auto *parentOp = (*this)->getParentOp();
2242   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2243     return emitOpError("expected single non-empty parent region");
2244 
2245   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2246     return verifyYield(*this, linalgOp);
2247 
2248   return emitOpError("expected parent op with LinalgOp interface");
2249 }
2250 
2251 //===----------------------------------------------------------------------===//
2252 // IndexOp
2253 //===----------------------------------------------------------------------===//
2254 
2255 LogicalResult IndexOp::verify() {
2256   auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2257   if (!linalgOp)
2258     return emitOpError("expected parent op with LinalgOp interface");
2259   if (linalgOp.getNumLoops() <= getDim())
2260     return emitOpError("expected dim (")
2261            << getDim() << ") to be lower than the number of loops ("
2262            << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2263   return success();
2264 }
2265 
2266 /////// Operations corresponding to library calls defined with Tablegen ////////
2267 
2268 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2269 
2270 #define GET_OP_CLASSES
2271 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2272 
2273 #define GET_OP_CLASSES
2274 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2275 
2276 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2277                                              unsigned rank,
2278                                              MLIRContext *context) {
2279   if (maybeMap)
2280     return *maybeMap;
2281   if (rank == 0)
2282     return AffineMap::get(context);
2283   return AffineMap::getMultiDimIdentityMap(rank, context);
2284 }
2285 
2286 SmallVector<AffineExpr, 4>
2287 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2288                                  MLIRContext *context) {
2289   SmallVector<AffineExpr, 4> res;
2290   res.reserve(num);
2291   for (unsigned i = 0; i < num; ++i)
2292     res.push_back(getAffineDimExpr(startIdx++, context));
2293   return res;
2294 }
2295 
2296 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2297                                                 ArrayRef<AffineExpr> b) {
2298   auto rangeA = llvm::make_range(a.begin(), a.end());
2299   auto rangeB = llvm::make_range(b.begin(), b.end());
2300   auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2301   return llvm::to_vector<4>(concatRanges);
2302 }
2303 
2304 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2305   if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2306     ss << "view";
2307     for (auto size : memref.getShape())
2308       if (size < 0)
2309         ss << "sx";
2310       else
2311         ss << size << "x";
2312     if (failed(appendMangledType(ss, memref.getElementType())))
2313       return failure();
2314     if (auto as = memref.getMemorySpace()) {
2315       if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2316         ss << "as" << attr.getInt();
2317       else
2318         return failure();
2319     }
2320     return success();
2321   }
2322   if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2323     ss << "vector";
2324     llvm::interleave(
2325         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2326     if (failed(appendMangledType(ss, vec.getElementType())))
2327       return failure();
2328     return success();
2329   }
2330   if (t.isSignlessIntOrIndexOrFloat()) {
2331     ss << t;
2332     return success();
2333   }
2334   return failure();
2335 }
2336 
2337 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2338   assert(isa<LinalgOp>(op));
2339   std::string name(op->getName().getStringRef().str());
2340   std::string fun = "";
2341   for (NamedAttribute kv : op->getAttrs()) {
2342     if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2343       fun = stringifyEnum(ufa.getValue()).str() + "_";
2344     } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2345       fun = stringifyEnum(bfa.getValue()).str() + "_";
2346     }
2347   }
2348   name.reserve(128);
2349   std::replace(name.begin(), name.end(), '.', '_');
2350   llvm::raw_string_ostream ss(name);
2351   ss << "_" << fun;
2352   for (Type t : op->getOperandTypes()) {
2353     if (failed(appendMangledType(ss, t)))
2354       return std::string();
2355     ss << "_";
2356   }
2357   name.pop_back();
2358   return name;
2359 }
2360 
2361 //===----------------------------------------------------------------------===//
2362 // Canonicalizers and Folders.
2363 //===----------------------------------------------------------------------===//
2364 
2365 namespace {
2366 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2367   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2368 
2369   LogicalResult matchAndRewrite(LinalgOp op,
2370                                 PatternRewriter &rewriter) const override {
2371     for (OpOperand &opOperand : op->getOpOperands()) {
2372       // Linalg "inputs" may be either tensor or memref type.
2373       // tensor<0xelt_type> is a convention that may not always mean
2374       // "0 iterations". Only erase in cases we see memref<...x0x...>.
2375       auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2376       if (!mt)
2377         continue;
2378       if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2379         rewriter.eraseOp(op);
2380         return success();
2381       }
2382     }
2383     return failure();
2384   }
2385 };
2386 
2387 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2388 /// result that is more static than the linalg op.
2389 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2390   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2391 
2392   LogicalResult matchAndRewrite(tensor::CastOp castOp,
2393                                 PatternRewriter &rewriter) const override {
2394     if (!tensor::canFoldIntoProducerOp(castOp))
2395       return failure();
2396 
2397     auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2398     if (!linalgOp)
2399       return failure();
2400 
2401     // Cast can be in conditionally reachable region, if which case folding will
2402     // generate invalid code. Only conservatively fold ops in same block for
2403     // now.
2404     if (castOp->getBlock() != linalgOp->getBlock())
2405       return failure();
2406 
2407     OpBuilder::InsertionGuard guard(rewriter);
2408     rewriter.setInsertionPoint(linalgOp);
2409 
2410     Location loc = linalgOp.getLoc();
2411     OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2412     unsigned resultNumber = resultValue.getResultNumber();
2413     auto resultType =
2414         llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2415     // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2416     // going from a more dynamic shape to a less dynamic shape. If the producer
2417     // for this cast, i.e. producer of the out operand, is also an operation
2418     // that folds with tensor.cast consumer (like this pattern), the cast will
2419     // continue to propagate as far up the stack as it can go.
2420     OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2421     Value newOperand =
2422         rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2423     SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2424     SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2425                                       linalgOp.getDpsInits().end());
2426     outputOperands[resultNumber] = newOperand;
2427     newOperands.append(outputOperands.begin(), outputOperands.end());
2428 
2429     SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2430                                   linalgOp->result_type_end());
2431     resultTypes[resultNumber] = resultType;
2432     Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2433 
2434     // Create a tensor.cast operation back to the original type.
2435     Value castBack = rewriter.create<tensor::CastOp>(
2436         loc, resultValue.getType(), newOp->getResult(resultNumber));
2437 
2438     SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2439     results[resultNumber] = castBack;
2440     rewriter.replaceOp(linalgOp, results);
2441     rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2442     return success();
2443   }
2444 };
2445 
2446 /// For each of the operand in `operands` this function maps the static sizes of
2447 /// dimensions to their affine dim expressions.
2448 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2449                         llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2450   for (OpOperand &opOperand : operands) {
2451     if (linalgOp.isScalar(&opOperand))
2452       continue;
2453     Value src = opOperand.get();
2454     auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2455     auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2456 
2457     // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2458     // `tensor.cast` operation and source of the cast operation has a static
2459     // shape, then assign it to the `sourceShape`.
2460     auto *parentOp = src.getDefiningOp();
2461     ArrayRef<int64_t> sourceShape = sourceType.getShape();
2462     if (parentOp) {
2463       if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2464         Value castSource = castOp.getSource();
2465         auto castSourceType =
2466             llvm::dyn_cast<RankedTensorType>(castSource.getType());
2467         if (castSourceType && castSourceType.hasStaticShape())
2468           sourceShape = castSourceType.getShape();
2469       }
2470     }
2471 
2472     // If the source shape's dimension has a static shape, map the affine dim
2473     // expression to the known static size.
2474     for (unsigned i = 0; i < sourceShape.size(); i++) {
2475       if (sourceType.isDynamicDim(i))
2476         continue;
2477       if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2478         affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2479     }
2480   }
2481 }
2482 
2483 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2484 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2485 /// their result types is stored in `resultTypes`. If `opOperand` requires no
2486 /// change then `changeNeeded` is false and same operand is added in the
2487 /// `newOperands` list.
2488 static void createNewOperandWithStaticSizes(
2489     Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2490     llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2491     SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2492     bool &changeNeeded) {
2493   Value src = opOperand->get();
2494   newOperands.push_back(src);
2495   if (linalgOp.isScalar(opOperand))
2496     return;
2497   auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2498   Type resultType = sourceType;
2499   if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2500     resultTypes.push_back(resultType);
2501     return;
2502   }
2503   ArrayRef<int64_t> sourceShape = sourceType.getShape();
2504   AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2505   SmallVector<int64_t> newShape;
2506   // If operand is updated with new shape, `newOperandNeeded` will be
2507   // true.
2508   bool newOperandNeeded = false;
2509   for (unsigned i = 0; i < sourceShape.size(); i++) {
2510     int64_t dimShape = sourceShape[i];
2511     AffineExpr dimExpr = sourceMap.getResult(i);
2512     if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2513       newShape.push_back(dimShape);
2514       continue;
2515     }
2516     // Dimension has a dynamic shape and corresponding affine dim
2517     // expression is present in the map. So assign the size for the
2518     // given affine dim expression to the dimension.
2519     newShape.push_back(affineExprToSize[dimExpr]);
2520     newOperandNeeded = true;
2521   }
2522   resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2523   if (newOperandNeeded) {
2524     changeNeeded = true;
2525     // Get the new operand value given its size and element type by
2526     // casting it.
2527     Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2528     unsigned index = opOperand->getOperandNumber();
2529     newOperands[index] = newOperand;
2530   }
2531   if (linalgOp.isDpsInit(opOperand))
2532     resultTypes.push_back(resultType);
2533 }
2534 
2535 /// Static shapes for the operands can be inferred if any one of the operands
2536 /// have a static shape. This can be done by referring to the affine dim
2537 /// expressions for the operand.
2538 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2539   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2540 
2541   LogicalResult matchAndRewrite(LinalgOp linalgOp,
2542                                 PatternRewriter &rewriter) const override {
2543     if (!linalgOp.hasPureTensorSemantics())
2544       return failure();
2545 
2546     // Maps must be projected permutations.
2547     if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2548           return !map.isProjectedPermutation();
2549         }))
2550       return failure();
2551 
2552     // Maps affine dim expressions to the static size of that dimension.
2553     llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2554     Location loc = linalgOp.getLoc();
2555 
2556     // For each of the affine dim expression, check if the size is known. If
2557     // known add that in the map.
2558     populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2559 
2560     SmallVector<Value> newOperands;
2561     SmallVector<Type> resultTypes;
2562 
2563     // `changeNeeded` is `false` if the operands of `linalgOp` require no
2564     // change in their types.
2565     bool changeNeeded = false;
2566     newOperands.reserve(linalgOp->getNumOperands());
2567     resultTypes.reserve(linalgOp.getNumDpsInits());
2568 
2569     // Iterate over all the operands and update the static sizes.
2570     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2571       createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2572                                       affineExprToSize, linalgOp, newOperands,
2573                                       resultTypes, changeNeeded);
2574     }
2575 
2576     // If the generic op has all the required static information, no
2577     // canonicalization needed.
2578     if (!changeNeeded)
2579       return failure();
2580 
2581     // Clone op.
2582     Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2583     SmallVector<Value> replacements;
2584     replacements.reserve(newOp->getNumResults());
2585     for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2586       Value newResult = std::get<1>(it);
2587       Value oldResult = std::get<0>(it);
2588       Type newType = newResult.getType();
2589       Type oldType = oldResult.getType();
2590       replacements.push_back(
2591           (newType != oldType)
2592               ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2593               : newResult);
2594     }
2595     rewriter.replaceOp(linalgOp, replacements);
2596     return success();
2597   }
2598 };
2599 
2600 } // namespace
2601 
2602 // All named ops canonicalizers and folders are auto-generated in the
2603 // .cpp.inc.
2604 
2605 //===----------------------------------------------------------------------===//
2606 // SoftmaxOp
2607 //===----------------------------------------------------------------------===//
2608 
2609 LogicalResult SoftmaxOp::verify() {
2610   ShapedType inputType = getInputOperandType();
2611   ShapedType outputType = getOutputOperandType();
2612 
2613   ArrayRef<int64_t> inputShape = inputType.getShape();
2614   ArrayRef<int64_t> outputShape = outputType.getShape();
2615   if (failed(verifyCompatibleShape(inputShape, outputShape)))
2616     return emitOpError("incompatible output shape");
2617 
2618   int64_t inputRank = getInputOperandRank();
2619   int64_t dimension = getDimension();
2620   if ((dimension < 0) || (dimension >= inputRank))
2621     return emitOpError("incorrect dimension specified");
2622 
2623   return success();
2624 }
2625 
2626 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2627   int64_t operandRank = getInputOperandRank();
2628   SmallVector<Range> loopBounds(operandRank);
2629   Location loc = getLoc();
2630   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2631   Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2632   Value source = getInput();
2633   for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2634     loopBounds[dim].offset = zero;
2635     loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2636     loopBounds[dim].stride = one;
2637   }
2638   return loopBounds;
2639 }
2640 
2641 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2642   SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2643                                                  utils::IteratorType::parallel);
2644   iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2645   return iteratorTypes;
2646 }
2647 
2648 FailureOr<TilingResult>
2649 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2650                                   ArrayRef<OpFoldResult> offsets,
2651                                   ArrayRef<OpFoldResult> sizes) {
2652   int64_t rank = getInputOperandRank();
2653   auto oneAttr = builder.getI64IntegerAttr(1);
2654   SmallVector<OpFoldResult> strides(rank, oneAttr);
2655   SmallVector<Value> tiledOperands;
2656   Operation *inputSlice =
2657       getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2658   if (!inputSlice) {
2659     return emitOpError("failed to compute input slice");
2660   }
2661   tiledOperands.emplace_back(inputSlice->getResult(0));
2662   Operation *outputSlice =
2663       getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2664   if (!outputSlice) {
2665     return emitOpError("failed to compute output slice");
2666   }
2667   tiledOperands.emplace_back(outputSlice->getResult(0));
2668 
2669   SmallVector<Type, 4> resultTypes;
2670   if (hasPureTensorSemantics())
2671     resultTypes.push_back(tiledOperands[1].getType());
2672   Operation *tiledOp =
2673       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2674 
2675   return TilingResult{
2676       {tiledOp},
2677       SmallVector<Value>(tiledOp->getResults()),
2678       llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2679 }
2680 
2681 LogicalResult SoftmaxOp::getResultTilePosition(
2682     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2683     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2684     SmallVector<OpFoldResult> &resultSizes) {
2685   if (resultNumber == 0) {
2686     resultOffsets.assign(offsets.begin(), offsets.end());
2687     resultSizes.assign(sizes.begin(), sizes.end());
2688     return success();
2689   }
2690   return failure();
2691 }
2692 
2693 // cast(dynamic) -> static.
2694 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2695   return memref::foldMemRefCast(*this);
2696 }
2697 
2698 LogicalResult
2699 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2700                              ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2701   SmallVector<OpFoldResult> shapes;
2702   Location loc = getOperation()->getLoc();
2703   IRRewriter rewriter(b);
2704   auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2705   auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2706   for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2707     if (!outputShapedType.isDynamicDim(dim)) {
2708       // Static dim: Return IntegerAttr.
2709       shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2710     } else {
2711       // Dynamic dim: Return Value.
2712       OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2713       shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2714     }
2715   }
2716   reifiedReturnShapes.emplace_back(std::move(shapes));
2717   return success();
2718 }
2719 
2720 void SoftmaxOp::getEffects(
2721     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2722         &effects) {
2723   for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2724     if (!llvm::isa<MemRefType>(operand.getType()))
2725       continue;
2726     effects.emplace_back(MemoryEffects::Read::get(),
2727                          &getOperation()->getOpOperand(index), /*stage=*/0,
2728                          /*effectOnFullRegion=*/true,
2729                          SideEffects::DefaultResource::get());
2730   }
2731 
2732   for (OpOperand &operand : getDpsInitsMutable()) {
2733     if (!llvm::isa<MemRefType>(operand.get().getType()))
2734       continue;
2735     effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2736                          /*effectOnFullRegion=*/true,
2737                          SideEffects::DefaultResource::get());
2738     effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2739                          /*effectOnFullRegion=*/true,
2740                          SideEffects::DefaultResource::get());
2741   }
2742 }
2743 
2744 // Helper functions for softmax decomposition.
2745 // @{
2746 
2747 // Helper function to produce the iterator types (reduction or parallel) and
2748 // affine maps for the iterators used in the decomposition of softmax.
2749 // This method creates:
2750 // If allParallel == true:
2751 // - iterator type: {parallel, ..., parallel}
2752 // - affine maps:
2753 // -- identity with inputRank dimensions.
2754 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2755 //    where N == inputRank.
2756 //
2757 // If allParallel == false:
2758 // - iterator type at dim(i) == parallel for i != \p dim and
2759 //   dim(dim) == reduction.
2760 // - affine map:
2761 // -- identity with inputRank dimensions.
2762 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2763 //    where N == inputRank.
2764 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2765 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2766                                     int64_t dim, bool allParallel = false) {
2767   SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2768                                                  utils::IteratorType::parallel);
2769   if (!allParallel)
2770     iteratorTypes[dim] = utils::IteratorType::reduction;
2771   MLIRContext *ctxt = builder.getContext();
2772   auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2773   SmallVector<AffineExpr, 2> affineExprs;
2774   for (int i = 0; i < inputRank; i++) {
2775     if (i != dim)
2776       affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2777   }
2778   auto reductionMap =
2779       AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2780   SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2781   return std::make_tuple(iteratorTypes, indexingMaps);
2782 }
2783 
2784 // Helper function to produce a linalg.generic that computes a reduction on
2785 // dimension \p dim with the operation type \p T.
2786 template <typename T>
2787 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2788                     int64_t dim) {
2789   auto inputType = cast<ShapedType>(input.getType());
2790   ArrayRef<int64_t> inputShape = inputType.getShape();
2791   int64_t inputRank = inputShape.size();
2792   auto [iteratorTypes, indexingMaps] =
2793       computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2794   assert(indexingMaps.size() == 2 &&
2795          "We should have two maps: 1 for the input, 1 for the output");
2796   assert(indexingMaps[0].isIdentity() && "input map should be identity");
2797 
2798   auto genericOp = builder.create<linalg::GenericOp>(
2799       loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2800       [&](OpBuilder &b, Location loc, ValueRange args) {
2801         Value result = b.create<T>(loc, args[0], args[1]);
2802         b.create<linalg::YieldOp>(loc, result);
2803       });
2804   return genericOp.getResult(0);
2805 }
2806 
2807 /// Produce a linalg generic that computes the second step of the softmax
2808 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2809 /// on dimension \p dim.
2810 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2811                               Value max, Value output, int64_t dim) {
2812   auto inputType = cast<ShapedType>(input.getType());
2813   ArrayRef<int64_t> inputShape = inputType.getShape();
2814   int64_t inputRank = inputShape.size();
2815   auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2816       builder, inputRank, dim, /*allParallel=*/true);
2817   assert(indexingMaps.size() == 2 && "We should have one map for each input");
2818   assert(indexingMaps[0].isIdentity() && "input map should be identity");
2819   // Add the affine map for the output argument.
2820   indexingMaps.push_back(indexingMaps[0]);
2821   auto genericOp = builder.create<linalg::GenericOp>(
2822       loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2823       iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2824         Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2825         Value result = b.create<math::ExpOp>(loc, diff);
2826         b.create<linalg::YieldOp>(loc, result);
2827       });
2828   return genericOp.getResult(0);
2829 }
2830 
2831 /// Produce a linalg generic that computes the final step of the softmax
2832 /// decomposition.
2833 /// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2834 ///   yield  n / d
2835 /// }
2836 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2837                         Value denominator, Value output, int64_t dim) {
2838   auto inputType = cast<ShapedType>(numerator.getType());
2839   ArrayRef<int64_t> inputShape = inputType.getShape();
2840   int64_t inputRank = inputShape.size();
2841   auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2842       builder, inputRank, dim, /*allParallel=*/true);
2843   assert(indexingMaps.size() == 2 &&
2844          "We should have one map for each input (2)");
2845   assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2846   // Add the affine map for the output tensor.
2847   indexingMaps.push_back(indexingMaps[0]);
2848   auto genericOp = builder.create<linalg::GenericOp>(
2849       loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2850       indexingMaps, iteratorTypes,
2851       [&](OpBuilder &b, Location loc, ValueRange args) {
2852         Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2853         b.create<linalg::YieldOp>(loc, result);
2854       });
2855   return genericOp.getResult(0);
2856 }
2857 // @} End helper functions for softmax decomposition.
2858 
2859 /// Given an N-dimensional tensor x, this method converts
2860 /// softmax(x) to the following sequence of operations:
2861 ///
2862 /// 1. Compute the max of x along dimension d. This results
2863 ///    in a N-1 dimensional tensor m.
2864 ///    m = max(x, dim = d)
2865 ///
2866 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2867 ///    a N dimensional tensor z.
2868 ///    z = exp(x - m)
2869 ///
2870 /// 3. Compute the sum of z along dimension d. This results in
2871 ///    a N-1 dimensional tensor l.
2872 ///    l = sum(z, dim = d)
2873 ///
2874 /// 4. Divide z and l. This gives the N-dimensional softmax.
2875 ///    softmax = z / l
2876 ///
2877 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2878   OpBuilder::InsertionGuard guard(b);
2879   b.setInsertionPoint(*this);
2880   Location loc = getLoc();
2881   Value input = getInput();
2882   ShapedType inputType = getInputOperandType();
2883   Type elementType = inputType.getElementType();
2884   int64_t reductionDim = getDimension();
2885   SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2886   Value output = getOutput();
2887   dims.erase(dims.begin() + reductionDim);
2888   // Step 1: Compute max along dim.
2889   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2890   Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2891                                                  elementType, b, loc,
2892                                                  /*useOnlyFiniteValue=*/true);
2893   Value neutralForMaxFInit =
2894       b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2895           .result();
2896   Value max =
2897       reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2898 
2899   // Step 2: Subtract max from input and exponentiate.
2900   Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2901 
2902   // Step 3: Compute sum along dim.
2903   Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2904                                        b, loc, /*useOnlyFiniteValue=*/true);
2905   Value zeroInit =
2906       b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2907   Value denominator =
2908       reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2909 
2910   // Step 4: Compute softmax.
2911   Value result =
2912       buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2913   return SmallVector<Value>{result};
2914 }
2915 
2916 //===----------------------------------------------------------------------===//
2917 // WinogradFilterTransformOp
2918 //===----------------------------------------------------------------------===//
2919 
2920 LogicalResult WinogradFilterTransformOp::verify() {
2921   auto filterType = cast<ShapedType>(getFilter().getType());
2922   ArrayRef<int64_t> filterShape = filterType.getShape();
2923   int64_t filterH = filterShape[getFilterHDim()];
2924   int64_t filterW = filterShape[getFilterWDim()];
2925   int64_t r = getR();
2926   int64_t m = getM();
2927 
2928   if (filterH != r && filterH != 1)
2929     return emitOpError("expect filter height either equals to r or 1");
2930   if (filterW != r && filterW != 1)
2931     return emitOpError("expect filter width either equals to r or 1");
2932   if (filterH == 1 && filterW == 1)
2933     return emitOpError("expect either filter height or width equals to r");
2934 
2935   SmallVector<int64_t> expectedOutputShape;
2936   expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2937   expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2938   expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2939   expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2940 
2941   auto outputType = cast<ShapedType>(getOutput().getType());
2942   ArrayRef<int64_t> outputShape = outputType.getShape();
2943   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2944     return emitOpError("the output shape is not expected");
2945   }
2946   return success();
2947 }
2948 
2949 SmallVector<Range>
2950 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2951   Location loc = getLoc();
2952   IntegerAttr zeroAttr = builder.getIndexAttr(0);
2953   IntegerAttr oneAttr = builder.getIndexAttr(1);
2954   Value filter = getFilter();
2955   int64_t filterRank = getFilterOperandRank();
2956   SmallVector<Range> loopBounds(filterRank);
2957   for (unsigned dim = 0; dim < filterRank; ++dim) {
2958     loopBounds[dim].offset = zeroAttr;
2959     loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2960     loopBounds[dim].stride = oneAttr;
2961   }
2962   return loopBounds;
2963 }
2964 
2965 SmallVector<utils::IteratorType>
2966 WinogradFilterTransformOp::getLoopIteratorTypes() {
2967   int64_t filterRank = getFilterOperandRank();
2968   SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2969                                                  utils::IteratorType::parallel);
2970   return iteratorTypes;
2971 }
2972 
2973 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2974     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2975     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2976     SmallVector<OpFoldResult> &resultSizes) {
2977   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2978   ShapedType filterType = getFilterOperandType();
2979   ArrayRef<int64_t> filterShape = filterType.getShape();
2980   int64_t filterH = filterShape[getFilterHDim()];
2981   int64_t filterW = filterShape[getFilterWDim()];
2982   int64_t m = getM();
2983   int64_t r = getR();
2984   int64_t alpha = m + r - 1;
2985   int64_t alphaH = filterH != 1 ? alpha : 1;
2986   int64_t alphaW = filterW != 1 ? alpha : 1;
2987   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
2988   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
2989 
2990   resultOffsets.append(
2991       {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
2992   resultSizes.append(
2993       {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
2994 
2995   return success();
2996 }
2997 
2998 /// Implement tiling for winograd_filter_transform
2999 /// The input of winograd_filter_transform is (F, KH, KW, C).
3000 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
3001 /// Users can specify the tile sizes of F and C.
3002 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3003 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3004 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3005     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3006     ArrayRef<OpFoldResult> sizes) {
3007   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3008   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3009   ShapedType filterType = getFilterOperandType();
3010   ArrayRef<int64_t> filterShape = filterType.getShape();
3011   int64_t filterH = filterShape[getFilterHDim()];
3012   int64_t filterW = filterShape[getFilterWDim()];
3013   IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3014   IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3015   SmallVector<Value> tiledOperands;
3016   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3017 
3018   sliceOffsets.append(
3019       {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3020   sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3021                      sizes[getFilterCDim()]});
3022   int64_t filterRank = getFilterOperandRank();
3023   SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3024   Location loc = getLoc();
3025   auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3026       loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3027   tiledOperands.emplace_back(filterSlice);
3028 
3029   SmallVector<OpFoldResult> resultOffsets, resultSizes;
3030   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3031                                    resultSizes)))
3032     return failure();
3033 
3034   int64_t outputRank = getOutputOperandRank();
3035   SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3036   auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3037       loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3038   tiledOperands.emplace_back(outputSlice);
3039 
3040   SmallVector<Type> resultTypes;
3041   resultTypes.push_back(tiledOperands[1].getType());
3042   Operation *tiledOp =
3043       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3044 
3045   return TilingResult{
3046       {tiledOp},
3047       SmallVector<Value>(tiledOp->getResults()),
3048       llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3049 }
3050 
3051 //===----------------------------------------------------------------------===//
3052 // WinogradInputTransformOp
3053 //===----------------------------------------------------------------------===//
3054 
3055 LogicalResult WinogradInputTransformOp::verify() {
3056   auto inputType = cast<ShapedType>(getInput().getType());
3057   ArrayRef<int64_t> inputShape = inputType.getShape();
3058   int64_t inputH = inputShape[getInputHDim()];
3059   int64_t inputW = inputShape[getInputWDim()];
3060   int m = getM();
3061   int r = getR();
3062   int64_t tileSize = m + r - 1;
3063 
3064   auto outputType = cast<ShapedType>(getOutput().getType());
3065   ArrayRef<int64_t> outputShape = outputType.getShape();
3066   bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3067   bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3068 
3069   SmallVector<int64_t> expectedOutputShape(6, inputH);
3070   if (ShapedType::isDynamic(inputH)) {
3071     expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3072     expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3073   } else {
3074     expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3075     expectedOutputShape[getOutputTileHDim()] =
3076         leftTransform ? (inputH - (r - 1)) / m : inputH;
3077   }
3078   if (ShapedType::isDynamic(inputW)) {
3079     expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3080     expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3081   } else {
3082     expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3083     expectedOutputShape[getOutputTileWDim()] =
3084         rightTransform ? (inputW - (r - 1)) / m : inputW;
3085   }
3086   expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3087   expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3088 
3089   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3090     return emitOpError("the output shape is not expected");
3091   }
3092   return success();
3093 }
3094 
3095 SmallVector<Range>
3096 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3097   Location loc = getLoc();
3098   IntegerAttr zeroAttr = builder.getIndexAttr(0);
3099   IntegerAttr oneAttr = builder.getIndexAttr(1);
3100   Value output = getOutput();
3101   int64_t outputRank = getOutputOperandRank();
3102   SmallVector<Range> loopBounds(outputRank);
3103   for (unsigned dim = 0; dim < outputRank; ++dim) {
3104     loopBounds[dim].offset = zeroAttr;
3105     // alphaH, alphaW, tileH, tileW, N, C
3106     loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3107     loopBounds[dim].stride = oneAttr;
3108   }
3109   return loopBounds;
3110 }
3111 
3112 SmallVector<utils::IteratorType>
3113 WinogradInputTransformOp::getLoopIteratorTypes() {
3114   int64_t outputRank = getOutputOperandRank();
3115   SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3116                                                  utils::IteratorType::parallel);
3117   return iteratorTypes;
3118 }
3119 
3120 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3121     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3122     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3123     SmallVector<OpFoldResult> &resultSizes) {
3124   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3125   ShapedType outputType = getOutputOperandType();
3126   ArrayRef<int64_t> outputShape = outputType.getShape();
3127   int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3128   int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3129 
3130   int64_t m = getM();
3131   int64_t r = getR();
3132   int64_t alpha = m + r - 1;
3133   int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3134   int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3135 
3136   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3137   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3138 
3139   resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3140                         offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3141                         offsets[getOutputCDim()]});
3142   resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3143                       sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3144                       sizes[getOutputCDim()]});
3145 
3146   return success();
3147 }
3148 
3149 /// Implement tiling for winograd_input_transform
3150 /// The input of winograd_input_transform is (N, H, W, C).
3151 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3152 /// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3153 /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3154 /// the values for the sizes of tileH, tileW, N, C for one tile.
3155 FailureOr<TilingResult>
3156 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3157                                                  ArrayRef<OpFoldResult> offsets,
3158                                                  ArrayRef<OpFoldResult> sizes) {
3159   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3160   int64_t m = getM();
3161   int64_t r = getR();
3162 
3163   ShapedType outputType = getOutputOperandType();
3164   ArrayRef<int64_t> outputShape = outputType.getShape();
3165   int64_t alphaH = outputShape[getOutputAlphaHDim()];
3166   int64_t alphaW = outputShape[getOutputAlphaWDim()];
3167 
3168   Location loc = getLoc();
3169   MLIRContext *context = builder.getContext();
3170   auto identityAffineMap =
3171       AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3172   auto offsetAffineMap =
3173       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3174   Value mappedOffsetH = affine::makeComposedAffineApply(
3175       builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3176       offsets[getOutputTileHDim()]);
3177   Value mappedOffsetW = affine::makeComposedAffineApply(
3178       builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3179       offsets[getOutputTileWDim()]);
3180   auto sizeAffineMap = AffineMap::get(
3181       1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3182   Value mappedSizeH = affine::makeComposedAffineApply(
3183       builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3184   Value mappedSizeW = affine::makeComposedAffineApply(
3185       builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3186 
3187   SmallVector<Value> tiledOperands;
3188   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3189 
3190   OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3191   OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3192   sliceOffsets.append(
3193       {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3194   OpFoldResult sizeH =
3195       alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3196   OpFoldResult sizeW =
3197       alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3198   sliceSizes.append(
3199       {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3200   int64_t inputRank = getInputOperandRank();
3201   SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3202   auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3203       loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3204   tiledOperands.emplace_back(inputSlice);
3205 
3206   SmallVector<OpFoldResult> resultOffsets, resultSizes;
3207   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3208                                    resultSizes)))
3209     return failure();
3210 
3211   int64_t outputRank = getOutputOperandRank();
3212   SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3213   auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3214       loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3215   tiledOperands.emplace_back(outputSlice);
3216 
3217   SmallVector<Type> resultTypes;
3218   resultTypes.push_back(tiledOperands[1].getType());
3219   Operation *tiledOp =
3220       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3221 
3222   return TilingResult{
3223       {tiledOp},
3224       SmallVector<Value>(tiledOp->getResults()),
3225       llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3226 }
3227 
3228 //===----------------------------------------------------------------------===//
3229 // WinogradOutputTransformOp
3230 //===----------------------------------------------------------------------===//
3231 
3232 LogicalResult WinogradOutputTransformOp::verify() {
3233   auto valueType = cast<ShapedType>(getValue().getType());
3234   ArrayRef<int64_t> valueShape = valueType.getShape();
3235   int64_t valueH = valueShape[getValueAlphaHDim()];
3236   int64_t valueW = valueShape[getValueAlphaWDim()];
3237   int64_t valueTileH = valueShape[getValueTileHDim()];
3238   int64_t valueTileW = valueShape[getValueTileWDim()];
3239   int m = getM();
3240   int r = getR();
3241   bool leftTransform = valueH != 1;
3242   bool rightTransform = valueW != 1;
3243 
3244   int64_t outputRank = getOutputOperandRank();
3245   SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3246   if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3247     expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3248   } else {
3249     if (valueH != (leftTransform ? m + r - 1 : 1))
3250       return emitOpError("expect input height equals to input tile size");
3251     expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3252   }
3253   if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3254     expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3255   } else {
3256     if (valueW != (rightTransform ? m + r - 1 : 1))
3257       return emitOpError("expect input width equals to input tile size");
3258     expectedOutputShape[getOutputWDim()] =
3259         (rightTransform ? m : 1) * valueTileW;
3260   }
3261   expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3262   expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3263 
3264   auto outputType = cast<ShapedType>(getOutput().getType());
3265   ArrayRef<int64_t> outputShape = outputType.getShape();
3266   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3267     return emitOpError("the output shape is not expected");
3268   }
3269   return success();
3270 }
3271 
3272 SmallVector<Range>
3273 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3274   Location loc = getLoc();
3275   IntegerAttr zeroAttr = builder.getIndexAttr(0);
3276   IntegerAttr oneAttr = builder.getIndexAttr(1);
3277   Value value = getValue();
3278   int64_t valueRank = getValueOperandRank();
3279   SmallVector<Range> loopBounds(valueRank);
3280   for (unsigned dim = 0; dim < valueRank; ++dim) {
3281     loopBounds[dim].offset = zeroAttr;
3282     // alphaH, alphaW, tileH, tileW, N, F
3283     loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3284     loopBounds[dim].stride = oneAttr;
3285   }
3286   return loopBounds;
3287 }
3288 
3289 SmallVector<utils::IteratorType>
3290 WinogradOutputTransformOp::getLoopIteratorTypes() {
3291   int64_t valueRank = getValueOperandRank();
3292   SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3293                                                  utils::IteratorType::parallel);
3294   return iteratorTypes;
3295 }
3296 
3297 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3298     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3299     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3300     SmallVector<OpFoldResult> &resultSizes) {
3301   int64_t m = getM();
3302 
3303   Location loc = getLoc();
3304   MLIRContext *context = builder.getContext();
3305   auto identityAffineMap =
3306       AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3307   auto affineMap =
3308       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3309 
3310   ShapedType valueType = getValueOperandType();
3311   ArrayRef<int64_t> valueShape = valueType.getShape();
3312   int64_t valueH = valueShape[0];
3313   int64_t valueW = valueShape[1];
3314   Value mappedOffsetH = affine::makeComposedAffineApply(
3315       builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3316       offsets[getValueTileHDim()]);
3317   Value mappedOffsetW = affine::makeComposedAffineApply(
3318       builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3319       offsets[getValueTileWDim()]);
3320   Value mappedSizeH = affine::makeComposedAffineApply(
3321       builder, loc, affineMap, sizes[getValueTileHDim()]);
3322   Value mappedSizeW = affine::makeComposedAffineApply(
3323       builder, loc, affineMap, sizes[getValueTileWDim()]);
3324 
3325   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3326   OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3327   OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3328   OpFoldResult sizeH =
3329       valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3330   OpFoldResult sizeW =
3331       valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3332 
3333   resultOffsets.append(
3334       {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3335   resultSizes.append(
3336       {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3337   return success();
3338 }
3339 
3340 /// Implement tiling for winograd_output_transform
3341 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3342 /// F). The output of winograd_output_transform is (N, H, W, F) Users can
3343 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3344 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3345 /// for the sizes of tileH, tileW, N, F for one tile.
3346 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3347     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3348     ArrayRef<OpFoldResult> sizes) {
3349   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3350   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3351   Location loc = getLoc();
3352   SmallVector<Value> tiledOperands;
3353   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3354 
3355   ShapedType valueType = getValueOperandType();
3356   ArrayRef<int64_t> valueShape = valueType.getShape();
3357   int64_t alphaH = valueShape[getValueAlphaHDim()];
3358   int64_t alphaW = valueShape[getValueAlphaWDim()];
3359   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3360   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3361 
3362   sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3363                        offsets[getValueTileWDim()], offsets[getValueNDim()],
3364                        offsets[getValueFDim()]});
3365   sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3366                      sizes[getValueTileWDim()], sizes[getValueNDim()],
3367                      sizes[getValueFDim()]});
3368   int64_t valueRank = getValueOperandRank();
3369   SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3370   auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3371       loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3372   tiledOperands.emplace_back(valueSlice);
3373 
3374   SmallVector<OpFoldResult> resultOffsets, resultSizes;
3375   if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3376                                    resultSizes)))
3377     return failure();
3378 
3379   int64_t outputRank = getOutputOperandRank();
3380   SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3381   auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3382       loc, getOutput(), resultOffsets, resultSizes, strides);
3383   tiledOperands.emplace_back(outputSlice);
3384 
3385   SmallVector<Type> resultTypes;
3386   resultTypes.push_back(tiledOperands[1].getType());
3387   Operation *tiledOp =
3388       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3389 
3390   return TilingResult{
3391       {tiledOp},
3392       SmallVector<Value>(tiledOp->getResults()),
3393       llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3394 }
3395 
3396 //===----------------------------------------------------------------------===//
3397 // LinalgDialect
3398 //===----------------------------------------------------------------------===//
3399 
3400 void LinalgDialect::getCanonicalizationPatterns(
3401     RewritePatternSet &results) const {
3402   results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3403               InferStaticShapeOfOperands>(getContext());
3404 }
3405 
3406 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
3407                                               Attribute value, Type type,
3408                                               Location loc) {
3409   return arith::ConstantOp::materialize(builder, value, type, loc);
3410 }
3411 
3412 /// Returns true if the result AffineExpr of the \p explicitMap is same as \p
3413 /// defaultMap.
3414 static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
3415   auto explicitRange = explictMap.getResults();
3416   auto defaultRange = defaultMap.getResults();
3417   DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3418   DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3419   llvm::set_union(explicitSet, defaultSet);
3420   return explicitSet == defaultSet;
3421 }
3422 
3423 /// Returns true if the \p explictMap is broadcasted with respect to the
3424 /// \p defaultMap.
3425 static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3426   return explictMap.getNumResults() < defaultMap.getNumResults();
3427 }
3428 
3429 /// Verifies the broadcast and transpose semantic sepecified by the explicit
3430 /// indexing map for the MatmulOp \p op for each operand specified by \p
3431 /// opIndex.
3432 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3433                                                   unsigned opIndex) {
3434   SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3435   SmallVector<AffineMap, 3> defaultIndexingMaps =
3436       matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3437 
3438   auto opIndexingMap = opIndexingMaps[opIndex];
3439   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3440   // Check general validity of indexing map results.
3441   if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3442     return matmulOp->emitOpError()
3443            << "Unexpected dim expression in map result.";
3444 
3445   // Check if the requested broadcast is valid.
3446   if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3447     if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3448       return matmulOp->emitOpError()
3449              << "Invalid broadcast requested, should be (d2).";
3450     }
3451     return success();
3452   }
3453   return success();
3454 }
3455 
3456 namespace mlir {
3457 namespace linalg {
3458 
3459 //===----------------------------------------------------------------------===//
3460 // MatMulOp
3461 //===----------------------------------------------------------------------===//
3462 
3463 /// Returns a list of AffineMap with the typical matmul indexing charactristic.
3464 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3465   AffineExpr d0, d1, d2;
3466   SmallVector<AffineMap> indexingMaps;
3467   bindDims(context, d0, d1, d2);
3468   indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3469   indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3470   indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3471   return indexingMaps;
3472 }
3473 
3474 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3475   return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3476                                           utils::IteratorType::parallel,
3477                                           utils::IteratorType::reduction};
3478 }
3479 
3480 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3481 
3482 std::string MatmulOp::getLibraryCallName() {
3483   return generateLibraryCallName(getOperation());
3484 }
3485 
3486 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3487 
3488 /// Check if the op has broadcast and/or transpose semantic. Returns true if
3489 /// the user defined indexing maps are not equal to default map.
3490 bool MatmulOp::hasUserDefinedMaps() {
3491   SmallVector<AffineMap, 3> defaultMaps =
3492       getDefaultIndexingMaps(this->getContext());
3493   SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3494   return defaultMaps != explicitMaps;
3495 }
3496 
3497 /// Implements the block region builder for the MatmulOp. This is called by
3498 /// 'fillStructuredOpRegion'.
3499 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3500                              ArrayRef<NamedAttribute> attrs) {
3501   assert(3 > 0 && block.getNumArguments() == 3 &&
3502          "MatmulOp regionBuilder expects 3 (>=0) args");
3503   RegionBuilderHelper helper(b, block);
3504   SmallVector<Value> yields;
3505 
3506   TypeFn castVal = TypeFn::cast_signed;
3507   auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3508     return attr.getName() == "cast";
3509   });
3510   if (castIter != attrs.end()) {
3511     if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3512       castVal = attr.getValue();
3513   }
3514 
3515   Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3516                                     block.getArgument(0));
3517   Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3518                                     block.getArgument(1));
3519   Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3520   Value value4 =
3521       helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3522   yields.push_back(value4);
3523   helper.yieldOutputs(yields);
3524 }
3525 
3526 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
3527 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3528   assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3529   AffineExpr exp = bcastMap.getResult(0);
3530   // Invalid map if the common dimension of matmul not found.
3531   return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
3532 }
3533 
3534 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3535   if (parser.parseOptionalKeyword("indexing_maps"))
3536     return {nullptr}; // Success in case indexing_maps was not provided.
3537 
3538   ArrayAttr arrayAttr;
3539   if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3540     return failure();
3541 
3542   if (llvm::any_of(arrayAttr,
3543                    [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3544     return parser.emitError(parser.getCurrentLocation())
3545            << "element of indexing_maps array is not an affine_map";
3546 
3547   return arrayAttr;
3548 }
3549 
3550 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3551   FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3552   if (failed(indexingMapsAttr))
3553     return failure();
3554 
3555   if (*indexingMapsAttr == nullptr) {
3556     auto indexingMapAttrs = llvm::map_to_vector(
3557         MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3558         [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3559     indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3560   }
3561 
3562   result.addAttribute("indexing_maps", *indexingMapsAttr);
3563   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3564                                 MatmulOp::getRegionBuilder());
3565 }
3566 
3567 void MatmulOp::print(OpAsmPrinter &p) {
3568   SmallVector<StringRef, 3> elidedAttrs = {
3569       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3570   printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3571                          elidedAttrs);
3572 
3573   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
3574       MatmulOp::getDefaultIndexingMaps(getContext()),
3575       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3576   if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3577     p << " indexing_maps = [";
3578     llvm::interleaveComma(getIndexingMaps(), p,
3579                           [&](Attribute attr) { p.printAttribute(attr); });
3580     p << "]";
3581   }
3582 }
3583 
3584 /// Verify the user defined indexing maps.
3585 LogicalResult MatmulOp::verify() {
3586   // Verification of pure matmul is handled by verifyStructuredOpInterface().
3587   if (!hasUserDefinedMaps())
3588     return success();
3589 
3590   for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3591     if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3592       return failure();
3593   }
3594   return success();
3595 }
3596 
3597 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3598   return memref::foldMemRefCast(*this);
3599 }
3600 
3601 void MatmulOp::getEffects(
3602     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3603         &effects) {
3604   if (hasPureTensorSemantics())
3605     return;
3606   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3607 }
3608 
3609 Speculation::Speculatability MatmulOp::getSpeculatability() {
3610   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3611 }
3612 
3613 //===----------------------------------------------------------------------===//
3614 // ContractOp
3615 //===----------------------------------------------------------------------===//
3616 
3617 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3618   AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3619   // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3620   // domains are all the same, and each implements a projected permutation.
3621   // Each iteration space dim must occur for at least one operand and either
3622   // takes part in a contraction/reduction or else has parallel iteration type.
3623   // We have that a dim is a contraction/reduction dim if and only if the dim
3624   // occurs for the output operand. We use this fact for fast inference:
3625   // NB: In case we allow dims to occur solely for one input, the above still
3626   //     holds: per the einsum semantics, these are reduction dims as well.
3627   SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3628   for (auto result : outAffineMap.getResults()) {
3629     auto dimExpr = dyn_cast<AffineDimExpr>(result);
3630     assert(dimExpr && "affine_map is a projected permutation");
3631     dimsInOutput[dimExpr.getPosition()] = true;
3632   }
3633 
3634   SmallVector<utils::IteratorType> iteratorTypes;
3635   for (auto dimOccursInOutput : dimsInOutput)
3636     iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3637                                               : utils::IteratorType::reduction);
3638 
3639   return iteratorTypes;
3640 }
3641 
3642 unsigned ContractOp::getNumRegionArgs() { return 3; }
3643 
3644 /// Implement block region builder, which is called by 'fillStructuredOpRegion'.
3645 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3646                                ArrayRef<NamedAttribute> attrs) {
3647   assert(block.getNumArguments() == 3 &&
3648          "ContractOp regionBuilder expects 3 args");
3649   RegionBuilderHelper helper(b, block);
3650 
3651   TypeFn castSignedness = TypeFn::cast_signed;
3652   auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3653     return attr.getName() == "cast";
3654   });
3655   if (castIter != attrs.end()) {
3656     if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3657       castSignedness = attr.getValue();
3658   }
3659 
3660   // TODO: Support fields with operators besides mult & add.
3661   Type outType = block.getArgument(2).getType();
3662   Value lhsAtOutType =
3663       helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3664   Value rhsAtOutType =
3665       helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3666   Value productAtOutType =
3667       helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3668   Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3669                                       productAtOutType);
3670   helper.yieldOutputs({result});
3671 }
3672 
3673 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3674   FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3675   if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3676     return parser.emitError(parser.getCurrentLocation(),
3677                             "expected 'indexing_maps' attribute");
3678   result.addAttribute("indexing_maps", *indexingMapsAttr);
3679 
3680   return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3681                                 regionBuilder);
3682 }
3683 
3684 void ContractOp::print(OpAsmPrinter &p) {
3685   p << " indexing_maps = [";
3686   llvm::interleaveComma(getIndexingMaps(), p,
3687                         [&](Attribute attr) { p.printAttribute(attr); });
3688   p << "]";
3689   printNamedStructuredOp(
3690       p, getOperation(), getInputs(), getOutputs(),
3691       /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3692 }
3693 
3694 LogicalResult ContractOp::verify() {
3695   int iterationSpaceDims = -1;
3696   // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3697   // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3698   // access an input operand (so occurrence count can be at most 2) and
3699   // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3700   SmallVector<size_t> inOccurrences;
3701   SmallVector<size_t> outOccurrences;
3702 
3703   // A helper so that for each operand's affine_map and type we check that ...
3704   auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3705                                    bool isInput) -> LogicalResult {
3706     // ... the affine_map is a projected permutation;
3707     if (!affineMap.isProjectedPermutation())
3708       return emitError("provided affine_map is not a projected permutation");
3709 
3710     // ... the rank of the affine_map's results and corresponding type match;
3711     if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3712       if (affineMap.getNumResults() != shapedType.getRank())
3713         return emitError("ranks of shaped operand and results of corresponding "
3714                          "affine_map differ");
3715     } else if (affineMap.getNumResults() != 0) {
3716       return emitError("affine_map specifies shaped access while operand has "
3717                        "non-shaped type");
3718     }
3719 
3720     // ... the rank of the affine_map's domain is the same as those seen prior;
3721     if (iterationSpaceDims == -1) {
3722       iterationSpaceDims = affineMap.getNumDims();
3723       inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3724       outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3725     } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3726       return emitError("iteration spaces of provided affine_maps differ");
3727     }
3728 
3729     // ... update counts of dims used to access either an input or the output.
3730     for (AffineExpr affineExpr : affineMap.getResults()) {
3731       auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3732       if (!affineDimExpr)
3733         llvm_unreachable("affine_map is a projected permutation");
3734 
3735       if (isInput)
3736         inOccurrences[affineDimExpr.getPosition()] += 1;
3737       else
3738         outOccurrences[affineDimExpr.getPosition()] += 1;
3739     }
3740 
3741     return success();
3742   };
3743 
3744   for (auto &&[affineMap, operandType, isInput] :
3745        llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3746                  SmallVector<bool>{true, true, false})) {
3747     if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3748       return failure(); // NB: checkAffineMapAndType will emit relevant error.
3749   }
3750 
3751   bool hasContractingDim = false;
3752   for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3753     size_t inOccCount = inOccurrences[dimIndex];
3754     size_t outOccCount = outOccurrences[dimIndex];
3755 
3756     // We have a contracting dim if and only if ...
3757     hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3758 
3759     if (inOccCount == 0 && outOccCount == 0)
3760       return emitError() << "iteration space dim at index " << dimIndex
3761                          << " not used to access any operand";
3762 
3763     // NB: We disallow a dim which occurs for only one input operand and not
3764     //     for the output. In terms of einsum semantics such dims have a
3765     //     sensible meaning - namely an additional reduction per each such dim.
3766     //     By contrast, the ContractionOpInterface does not know about this
3767     //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3768     //     while vector.contract's verifier accepts dims of this kind many of
3769     //     its lowerings give up on encountering these dims.
3770     // TODO: Remove following once we have comprehensive support for input-only
3771     //       reduction dims, at both the linalg- and vector-dialect levels.
3772     if (inOccCount == 1 && outOccCount != 1)
3773       return emitError()
3774              << "iteration space dim at index " << dimIndex
3775              << " is neither a contracting dim nor of parallel iteration type";
3776   }
3777 
3778   if (!hasContractingDim)
3779     return emitError("'indexing_maps' do not specify a contracting dimension");
3780 
3781   return success();
3782 }
3783 
3784 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3785   return memref::foldMemRefCast(*this);
3786 }
3787 
3788 void ContractOp::getEffects(
3789     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3790         &effects) {
3791   if (hasPureTensorSemantics())
3792     return;
3793   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3794 }
3795 
3796 Speculation::Speculatability ContractOp::getSpeculatability() {
3797   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3798 }
3799 
3800 } // namespace linalg
3801 } // namespace mlir
3802