//===- LinalgOps.cpp - Implementation of the linalg operations -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of
/// the `source`.
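/// For example (illustrative): a `tensor<8x8xf32>` source produces a
/// tensor.extract_slice, a `memref<8x8xf32>` source produces a
/// memref.subview, and any other type yields nullptr.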
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &,
                                                Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created
/// by ods-gen and by manually defined C++ ops. It is called by both builders
/// and parsers and creates a block with arguments corresponding to the
/// elemental types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.
  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and
/// `attributes`. The result types are derived automatically if
/// `resultTensorTypes` is none. The body of the operation is filled using
/// `regionBuilder`. All ods-gen created structured operations use the method
/// to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
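  // Note (illustrative): when `resultTensorTypes` is absent, each
  // ranked-tensor output contributes its type as an op result below; memref
  // outputs contribute no results.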
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and
/// by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it
    // to the discardable attributes dictionary where it is handled by the
    // generic Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes,
    ArrayRef<NamedAttribute> attrs, RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));
  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
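  // An illustrative printed form of a named structured op (not taken from a
  // test):
  //
  //   linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
  //                 outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>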
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code.
//
// Helpers to build the unary, binary, and type conversion functions defined
// by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that
// uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr.
// This can be extended later if it becomes possible to fail construction of
// the region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of
// integer and floating point types, but they will not internally cast within
// bit widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or
// need to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(
          arg.getLoc(), ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
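  // For example (illustrative): BinaryFn::add lowers to complex.add for
  // complex operands, arith.addf for floats, arith.ori for i1, and
  // arith.addi for other integers, mirroring the dispatch below.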
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
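  // For example (illustrative): casting i8 -> i32 with cast_signed
  // sign-extends while cast_unsigned zero-extends; the actual op selection is
  // delegated to convertScalarToDtype in `cast` below.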
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc,
                                             ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
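///
/// For example (illustrative IR, not taken from a test):
///
///   %fill = linalg.fill ins(%cst : f32)
///       outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
///   %r = tensor.collapse_shape %fill [[0, 1]]
///       : tensor<4x4xf32> into tensor<16xf32>
///
/// folds to
///
///   %init2 = tensor.collapse_shape %init [[0, 1]]
///       : tensor<4x4xf32> into tensor<16xf32>
///   %r = linalg.fill ins(%cst : f32)
///       outs(%init2 : tensor<16xf32>) -> tensor<16xf32>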
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
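        // For example (illustrative): offset 2, size 3, stride 2 accesses
        // indices {2, 4, 6}, so the inclusive range is
        // [2, 2 + (3 - 1) * 2] = [2, 6].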
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
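    // Each compatible fill's init collected below feeds a new tensor.concat,
    // which then becomes the init of the combined fill.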
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithPack,
              FoldFillWithPad, FoldFillWithTensorExtract,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr()
                            : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
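  // A minimal illustrative form of the syntax being parsed (not taken from a
  // test):
  //
  //   linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
  //                                    affine_map<(d0) -> (d0)>],
  //                   iterator_types = ["parallel"]}
  //       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
  //   ^bb0(%a: f32, %b: f32):
  //     linalg.yield %a : f32
  //   } -> tensor<8xf32>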
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index),
        /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with
  // memory semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operation (on tensors) that is just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it takes the same number of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position
// in the operands of the payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // check init
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(),
                   body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}

void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}

void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}

LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` matches the arity of the `mapper`
  // region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input "
                           << inputElemType << " to match bbArg type "
                           << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
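  // For example (illustrative): mapping over tensor<4x8xf32> inputs requires
  // the init, and hence the result, to also have shape 4x8.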
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}

SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}

void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MapOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}

void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}

SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(
      inputRank, utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}

void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability ReduceOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result,
          [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}

static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}

void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}

LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after "
                            "reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the
  // inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the
  // outputs.
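  // For example (illustrative): with outs(%init : tensor<4xf32>), the
  // trailing block argument must have type f32.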
for (auto [output, bbArg] : llvm::zip( getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) { auto outputElementType = llvm::cast(output.getType()).getElementType(); if (outputElementType != bbArg.getType()) return emitOpError() << "output element type " << outputElementType << " does not match corresponding block argument type " << bbArg.getType(); } return success(); } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs) { buildGenericRegion(builder, loc, region, inputs, outputs, [](OpBuilder &b, Location loc, ValueRange args) { if (!args.empty()) b.create(loc, args[0]); }); } void TransposeOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &result, Value input, Value init, DenseI64ArrayAttr permutation, ArrayRef attributes) { result.addOperands(input); result.addOperands(init); result.addAttribute(getPermutationAttrName(result.name), permutation); result.addAttributes(attributes); // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); if (llvm::isa(initType)) result.addTypes(initType); buildIdentityRegion(builder, result.location, *result.addRegion(), input, init); } void TransposeOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &result, Value input, Value init, ArrayRef permutation, ArrayRef attributes) { build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation), attributes); } ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { if (failed(parseDstStyleOp( parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "permutation"); }))) return failure(); OpBuilder builder(parser.getContext()); buildIdentityRegion(builder, result.location, *result.addRegion(), /*inputs=*/result.operands, /*outputs=*/{}); return success(); } void TransposeOp::getAsmResultNames( function_ref setNameFn) { if (!getResults().empty()) setNameFn(getResults().front(), "transposed"); } void TransposeOp::print(OpAsmPrinter &p) { printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); } LogicalResult TransposeOp::verify() { ArrayRef permutationRef = getPermutation(); if (!isPermutationVector(permutationRef)) return emitOpError("permutation is not valid"); auto inputType = getInput().getType(); auto initType = getInit().getType(); int64_t rank = inputType.getRank(); if (rank != initType.getRank()) return emitOpError() << "input rank " << rank << " does not match init rank " << initType.getRank(); if (rank != static_cast(permutationRef.size())) return emitOpError() << "size of permutation " << permutationRef.size() << " does not match the argument rank " << rank; auto inputDims = inputType.getShape(); auto initDims = initType.getShape(); for (int64_t i = 0; i < rank; ++i) { int64_t inputDim = inputDims[permutationRef[i]]; int64_t initDim = initDims[i]; if (inputDim != initDim) { return emitOpError() << "dim(result, " << i << ") = " << initDim << " doesn't match dim(input, permutation[" << i << "]) = " << inputDim; } } return success(); } SmallVector TransposeOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); return 
SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}

void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability TransposeOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Only the tensor type is supported.
  if (!isa<TensorType>(getInput().getType()))
    return failure();

  // Rank-0 input: the permutation is empty and the transpose is a no-op.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  if (isIdentityPermutation(getPermutation())) {
    result.push_back(getInput());
    return success();
  }

  return failure();
}

/// Fold transpose with transpose.
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
      return failure();
    ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();
    SmallVector<int64_t> foldedPerms;
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);

    rewriter.replaceOpWithNewOp<TransposeOp>(
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
        foldedPerms);
    return success();
  }
};

/// This pattern canonicalizes transpose by swapping the order of
/// broadcast and transpose:
///   transpose(broadcast(input)) -> broadcast(transpose(input))
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    SmallVector<OpFoldResult> dims;
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
                           ->getResult(0));
      } else {
        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
                                        broadcastInputTy.getDimSize(i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(dims, resultPerms);
    Value transposeInit = rewriter.create<tensor::EmptyOp>(
        transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
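    // For reference, the overall rewrite looks like this (an illustrative
    // sketch with hypothetical values):
    //   %b = linalg.broadcast ins(%in : tensor<2x4xf32>)
    //          outs(%e0 : tensor<2x3x4xf32>) dimensions = [1]
    //   %t = linalg.transpose ins(%b : tensor<2x3x4xf32>)
    //          outs(%e1 : tensor<4x3x2xf32>) permutation = [2, 1, 0]
    // becomes a transpose of %in with permutation = [1, 0] followed by a
    // broadcast with dimensions = [1].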
Value transposeResult = rewriter .create(loc, broadcastOp.getInput(), transposeInit, resultPerms) ->getResult(0); rewriter.replaceOpWithNewOp( transposeOp, transposeResult, transposeOp.getInit(), resultDimensions); return success(); } }; void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// void BroadcastOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &result, Value input, Value init, DenseI64ArrayAttr dimensions, ArrayRef attributes) { result.addOperands(input); result.addOperands(init); result.addAttribute(getDimensionsAttrName(result.name), dimensions); result.addAttributes(attributes); // Add output types for `RankedTensorType` output arguments. Type initType = init.getType(); if (llvm::isa(initType)) result.addTypes(initType); buildIdentityRegion(builder, result.location, *result.addRegion(), input, init); } void BroadcastOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &result, Value input, Value init, ArrayRef dimensions, ArrayRef attributes) { build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions), attributes); } ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) { if (failed(parseDstStyleOp( parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); }))) return failure(); OpBuilder builder(parser.getContext()); buildIdentityRegion(builder, result.location, *result.addRegion(), /*inputs=*/result.operands, /*outputs=*/{}); return success(); } void BroadcastOp::getAsmResultNames( function_ref setNameFn) { if (!getResults().empty()) setNameFn(getResults().front(), "broadcasted"); } void BroadcastOp::print(OpAsmPrinter &p) { printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); } LogicalResult BroadcastOp::verify() { ArrayRef dimensionsRef = getDimensions(); auto inputType = getInput().getType(); auto initType = getInit().getType(); int64_t inputRank = inputType.getRank(); int64_t initRank = initType.getRank(); auto inputShape = inputType.getShape(); auto initShape = initType.getShape(); if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank) return emitOpError() << "input rank plus added dimensions does not " "match init rank. input rank: " << inputRank << ", dimensions size: " << dimensionsRef.size() << ", init rank: " << initRank; for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) { if (dim < 0 || dim >= initRank) return emitOpError() << "dimension " << idx << " is out of range. expected range: [0, " << initRank - 1 << "], got: " << dim; } // Mapping from input dims to init dims. SmallVector dimMap; for (auto dim : llvm::seq(0, initRank)) { if (!llvm::is_contained(dimensionsRef, dim)) dimMap.push_back(dim); } for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) { // This dimensions is mapped from the input. Init and input dims should // match. if (inputShape[inputDimIdx] != initShape[initDimIdx]) return emitOpError() << "input dim " << inputDimIdx << " should match init dim " << initDimIdx << ". 
input: " << inputShape[inputDimIdx] << ", init: " << initShape[initDimIdx]; } return success(); } SmallVector BroadcastOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); return SmallVector(rank, utils::IteratorType::parallel); } ArrayAttr BroadcastOp::getIndexingMaps() { Builder builder(getContext()); int64_t rank = getInit().getType().getRank(); return builder.getAffineMapArrayAttr( {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()), builder.getMultiDimIdentityMap(rank)}); } void BroadcastOp::getEffects( SmallVectorImpl> &effects) { getGenericEffectsImpl(effects, cast(getOperation())); } Speculation::Speculatability BroadcastOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast(getOperation())); } void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// void linalg::YieldOp::print(OpAsmPrinter &p) { if (getNumOperands() > 0) p << ' ' << getOperands(); p.printOptionalAttrDict((*this)->getAttrs()); if (getNumOperands() > 0) p << " : " << getOperandTypes(); } ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector opInfo; SmallVector types; SMLoc loc = parser.getCurrentLocation(); return failure(parser.parseOperandList(opInfo) || parser.parseOptionalAttrDict(result.attributes) || (!opInfo.empty() && parser.parseColonTypeList(types)) || parser.resolveOperands(opInfo, types, loc, result.operands)); } // Check the operand number and types must match the element types of the // LinalgOp interface's shaped operands. static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { if (op.getNumOperands() != linalgOp.getNumDpsInits()) return op.emitOpError("expected number of yield values (") << op.getNumOperands() << ") to match the number of inits / outs operands of the enclosing " << "LinalgOp (" << linalgOp.getNumDpsInits() << ")"; for (OpOperand &opOperand : op->getOpOperands()) { OpOperand *outputOperand = linalgOp.getDpsInitOperand(opOperand.getOperandNumber()); Type elementType = outputOperand->get().getType(); if (isa(elementType)) elementType = getElementTypeOrSelf(outputOperand->get().getType()); if (opOperand.get().getType() != elementType) return op.emitOpError("type of yield operand ") << (opOperand.getOperandNumber() + 1) << " (" << opOperand.get().getType() << ") doesn't match " << "the element type of the enclosing linalg.generic op (" << elementType << ")"; } return success(); } LogicalResult linalg::YieldOp::verify() { auto *parentOp = (*this)->getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return emitOpError("expected single non-empty parent region"); if (auto linalgOp = dyn_cast(parentOp)) return verifyYield(*this, linalgOp); return emitOpError("expected parent op with LinalgOp interface"); } //===----------------------------------------------------------------------===// // IndexOp //===----------------------------------------------------------------------===// LogicalResult IndexOp::verify() { auto linalgOp = dyn_cast((*this)->getParentOp()); if (!linalgOp) return emitOpError("expected parent op with LinalgOp interface"); if (linalgOp.getNumLoops() <= getDim()) return emitOpError("expected dim (") << getDim() << ") to be lower than the number of loops (" << linalgOp.getNumLoops() << ") of the enclosing 
LinalgOp"; return success(); } /////// Operations corresponding to library calls defined with Tablegen //////// #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" AffineMap mlir::linalg::extractOrIdentityMap(std::optional maybeMap, unsigned rank, MLIRContext *context) { if (maybeMap) return *maybeMap; if (rank == 0) return AffineMap::get(context); return AffineMap::getMultiDimIdentityMap(rank, context); } SmallVector mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context) { SmallVector res; res.reserve(num); for (unsigned i = 0; i < num; ++i) res.push_back(getAffineDimExpr(startIdx++, context)); return res; } SmallVector mlir::linalg::concat(ArrayRef a, ArrayRef b) { auto rangeA = llvm::make_range(a.begin(), a.end()); auto rangeB = llvm::make_range(b.begin(), b.end()); auto concatRanges = llvm::concat(rangeA, rangeB); return llvm::to_vector<4>(concatRanges); } static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) { if (auto memref = llvm::dyn_cast(t)) { ss << "view"; for (auto size : memref.getShape()) if (size < 0) ss << "sx"; else ss << size << "x"; if (failed(appendMangledType(ss, memref.getElementType()))) return failure(); if (auto as = memref.getMemorySpace()) { if (auto attr = llvm::dyn_cast(as)) ss << "as" << attr.getInt(); else return failure(); } return success(); } if (auto vec = llvm::dyn_cast(t)) { ss << "vector"; llvm::interleave( vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); if (failed(appendMangledType(ss, vec.getElementType()))) return failure(); return success(); } if (t.isSignlessIntOrIndexOrFloat()) { ss << t; return success(); } return failure(); } std::string mlir::linalg::generateLibraryCallName(Operation *op) { assert(isa(op)); std::string name(op->getName().getStringRef().str()); std::string fun = ""; for (NamedAttribute kv : op->getAttrs()) { if (UnaryFnAttr ufa = llvm::dyn_cast(kv.getValue())) { fun = stringifyEnum(ufa.getValue()).str() + "_"; } else if (BinaryFnAttr bfa = llvm::dyn_cast(kv.getValue())) { fun = stringifyEnum(bfa.getValue()).str() + "_"; } } name.reserve(128); std::replace(name.begin(), name.end(), '.', '_'); llvm::raw_string_ostream ss(name); ss << "_" << fun; for (Type t : op->getOperandTypes()) { if (failed(appendMangledType(ss, t))) return std::string(); ss << "_"; } name.pop_back(); return name; } //===----------------------------------------------------------------------===// // Canonicalizers and Folders. //===----------------------------------------------------------------------===// namespace { struct EraseDeadLinalgOp : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { for (OpOperand &opOperand : op->getOpOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. auto mt = llvm::dyn_cast(opOperand.get().getType()); if (!mt) continue; if (llvm::is_contained(op.getShape(&opOperand), 0)) { rewriter.eraseOp(op); return success(); } } return failure(); } }; /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has /// result that is more static than the linalg op. 
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // The cast can be in a conditionally reachable region, in which case
    // folding will generate invalid code. Only conservatively fold ops in the
    // same block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = rewriter.create<tensor::CastOp>(
        loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};

/// For each operand in `operands`, this function maps the static sizes of its
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // a `tensor.cast` operation and the source of the cast operation has a
    // static shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}

/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change, `changeNeeded` is left untouched and the same operand is added to
/// the `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // The dimension has a dynamic shape and the corresponding affine dim
    // expression is present in the map. So assign the size of the given
    // affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}

/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand.
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each affine dim expression, check if the size is known. If known,
    // add it to the map.
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
for (OpOperand &opOperand : linalgOp->getOpOperands()) { createNewOperandWithStaticSizes(loc, rewriter, &opOperand, affineExprToSize, linalgOp, newOperands, resultTypes, changeNeeded); } // If the generic op has all the required static information, no // canonicalization needed. if (!changeNeeded) return failure(); // Clone op. Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) { Value newResult = std::get<1>(it); Value oldResult = std::get<0>(it); Type newType = newResult.getType(); Type oldType = oldResult.getType(); replacements.push_back( (newType != oldType) ? rewriter.create(loc, oldType, newResult) : newResult); } rewriter.replaceOp(linalgOp, replacements); return success(); } }; } // namespace // All named ops canonicalizers and folders are auto-generated in the // .cpp.inc. //===----------------------------------------------------------------------===// // SoftmaxOp //===----------------------------------------------------------------------===// LogicalResult SoftmaxOp::verify() { ShapedType inputType = getInputOperandType(); ShapedType outputType = getOutputOperandType(); ArrayRef inputShape = inputType.getShape(); ArrayRef outputShape = outputType.getShape(); if (failed(verifyCompatibleShape(inputShape, outputShape))) return emitOpError("incompatible output shape"); int64_t inputRank = getInputOperandRank(); int64_t dimension = getDimension(); if ((dimension < 0) || (dimension >= inputRank)) return emitOpError("incorrect dimension specified"); return success(); } SmallVector SoftmaxOp::getIterationDomain(OpBuilder &builder) { int64_t operandRank = getInputOperandRank(); SmallVector loopBounds(operandRank); Location loc = getLoc(); Value zero = builder.create(loc, 0); Value one = builder.create(loc, 1); Value source = getInput(); for (auto dim : llvm::seq(0, operandRank)) { loopBounds[dim].offset = zero; loopBounds[dim].size = getDimValue(builder, loc, source, dim); loopBounds[dim].stride = one; } return loopBounds; } SmallVector SoftmaxOp::getLoopIteratorTypes() { SmallVector iteratorTypes(getInputOperandRank(), utils::IteratorType::parallel); iteratorTypes[getDimension()] = utils::IteratorType::reduction; return iteratorTypes; } FailureOr SoftmaxOp::getTiledImplementation(OpBuilder &builder, ArrayRef offsets, ArrayRef sizes) { int64_t rank = getInputOperandRank(); auto oneAttr = builder.getI64IntegerAttr(1); SmallVector strides(rank, oneAttr); SmallVector tiledOperands; Operation *inputSlice = getSlice(builder, getLoc(), getInput(), offsets, sizes, strides); if (!inputSlice) { return emitOpError("failed to compute input slice"); } tiledOperands.emplace_back(inputSlice->getResult(0)); Operation *outputSlice = getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides); if (!outputSlice) { return emitOpError("failed to compute output slice"); } tiledOperands.emplace_back(outputSlice->getResult(0)); SmallVector resultTypes; if (hasPureTensorSemantics()) resultTypes.push_back(tiledOperands[1].getType()); Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); return TilingResult{ {tiledOp}, SmallVector(tiledOp->getResults()), llvm::to_vector(ArrayRef{inputSlice, outputSlice})}; } LogicalResult SoftmaxOp::getResultTilePosition( OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector &resultOffsets, SmallVector &resultSizes) { if (resultNumber == 0) { 
resultOffsets.assign(offsets.begin(), offsets.end()); resultSizes.assign(sizes.begin(), sizes.end()); return success(); } return failure(); } // cast(dynamic) -> static. LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl &) { return memref::foldMemRefCast(*this); } LogicalResult SoftmaxOp::reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { SmallVector shapes; Location loc = getOperation()->getLoc(); IRRewriter rewriter(b); auto inputShapedType = llvm::cast(getInputOperandType()); auto outputShapedType = llvm::cast(getOutputOperandType()); for (int64_t dim : llvm::seq(0, getOutputOperandRank())) { if (!outputShapedType.isDynamicDim(dim)) { // Static dim: Return IntegerAttr. shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim))); } else { // Dynamic dim: Return Value. OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim); shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr)); } } reifiedReturnShapes.emplace_back(std::move(shapes)); return success(); } void SoftmaxOp::getEffects( SmallVectorImpl> &effects) { for (auto [index, operand] : llvm::enumerate(getDpsInputs())) { if (!llvm::isa(operand.getType())) continue; effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(index), /*stage=*/0, /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); } for (OpOperand &operand : getDpsInitsMutable()) { if (!llvm::isa(operand.get().getType())) continue; effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0, /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0, /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); } } // Helper functions for softmax decomposition. // @{ // Helper function to produce the iterator types (reduction or parallel) and // affine maps for the iterators used in the decomposition of softmax. // This method creates: // If allParallel == true: // - iterator type: {parallel, ..., parallel} // - affine maps: // -- identity with inputRank dimensions. // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), // where N == inputRank. // // If allParallel == false: // - iterator type at dim(i) == parallel for i != \p dim and // dim(dim) == reduction. // - affine map: // -- identity with inputRank dimensions. // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), // where N == inputRank. static std::tuple, SmallVector> computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel = false) { SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); if (!allParallel) iteratorTypes[dim] = utils::IteratorType::reduction; MLIRContext *ctxt = builder.getContext(); auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt); SmallVector affineExprs; for (int i = 0; i < inputRank; i++) { if (i != dim) affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt)); } auto reductionMap = AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt); SmallVector indexingMaps{identityMap, reductionMap}; return std::make_tuple(iteratorTypes, indexingMaps); } // Helper function to produce a linalg.generic that computes a reduction on // dimension \p dim with the operation type \p T. 
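// For instance (illustrative), with inputRank == 2 and dim == 1, the helper
// above produces iterator types [parallel, reduction] and the maps
//   (d0, d1) -> (d0, d1)   for the input, and
//   (d0, d1) -> (d0)       for the output,
// so instantiating reduce<arith::AddFOp> below yields a row-wise sum.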
template static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim) { auto inputType = cast(input.getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(builder, inputRank, dim); assert(indexingMaps.size() == 2 && "We should have two maps: 1 for the input, 1 for the output"); assert(indexingMaps[0].isIdentity() && "input map should be identity"); auto genericOp = builder.create( loc, output.getType(), input, output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value result = b.create(loc, args[0], args[1]); b.create(loc, result); }); return genericOp.getResult(0); } /// Produce a linalg generic that computes the second step of the softmax /// decomposition: res = exp(input - max), where \p max is the max of \p input /// on dimension \p dim. static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim) { auto inputType = cast(input.getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( builder, inputRank, dim, /*allParallel=*/true); assert(indexingMaps.size() == 2 && "We should have one map for each input"); assert(indexingMaps[0].isIdentity() && "input map should be identity"); // Add the affine map for the output argument. indexingMaps.push_back(indexingMaps[0]); auto genericOp = builder.create( loc, input.getType(), ValueRange{input, max}, output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value diff = b.create(loc, args[0], args[1]); Value result = b.create(loc, diff); b.create(loc, result); }); return genericOp.getResult(0); } /// Produce a linalg generic that computes the final step of the softmax /// decomposition. /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) { /// yield n / d /// } static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim) { auto inputType = cast(numerator.getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( builder, inputRank, dim, /*allParallel=*/true); assert(indexingMaps.size() == 2 && "We should have one map for each input (2)"); assert(indexingMaps[0].isIdentity() && "Numerator map should be identity"); // Add the affine map for the output tensor. indexingMaps.push_back(indexingMaps[0]); auto genericOp = builder.create( loc, numerator.getType(), ValueRange{numerator, denominator}, output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value result = b.create(loc, args[0], args[1]); b.create(loc, result); }); return genericOp.getResult(0); } // @} End helper functions for softmax decomposition. /// Given an N-dimensional tensor x, this method converts /// softmax(x) to the following sequence of operations: /// /// 1. Compute the max of x along dimension d. This results /// in a N-1 dimensional tensor m. /// m = max(x, dim = d) /// /// 2. Subtract a broadcasted m from x and exponentiate. This results in /// a N dimensional tensor z. /// z = exp(x - m) /// /// 3. Compute the sum of z along dimension d. This results in /// a N-1 dimensional tensor l. /// l = sum(z, dim = d) /// /// 4. Divide z and l. This gives the N-dimensional softmax. 
///      softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max = reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit,
                                       reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}

//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradFilterTransformOp::verify() {
  auto filterType = cast<ShapedType>(getFilter().getType());
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t r = getR();
  int64_t m = getM();

  if (filterH != r && filterH != 1)
    return emitOpError("expected filter height to be either r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expected filter width to be either r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expected either filter height or width to equal r");

  SmallVector<int64_t> expectedOutputShape;
  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ?
m + r - 1 : 1); expectedOutputShape.push_back(filterShape[getFilterCDim()]); expectedOutputShape.push_back(filterShape[getFilterFDim()]); auto outputType = cast(getOutput().getType()); ArrayRef outputShape = outputType.getShape(); if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } return success(); } SmallVector WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { Location loc = getLoc(); IntegerAttr zeroAttr = builder.getIndexAttr(0); IntegerAttr oneAttr = builder.getIndexAttr(1); Value filter = getFilter(); int64_t filterRank = getFilterOperandRank(); SmallVector loopBounds(filterRank); for (unsigned dim = 0; dim < filterRank; ++dim) { loopBounds[dim].offset = zeroAttr; loopBounds[dim].size = getDimValue(builder, loc, filter, dim); loopBounds[dim].stride = oneAttr; } return loopBounds; } SmallVector WinogradFilterTransformOp::getLoopIteratorTypes() { int64_t filterRank = getFilterOperandRank(); SmallVector iteratorTypes(filterRank, utils::IteratorType::parallel); return iteratorTypes; } LogicalResult WinogradFilterTransformOp::getResultTilePosition( OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector &resultOffsets, SmallVector &resultSizes) { IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); ShapedType filterType = getFilterOperandType(); ArrayRef filterShape = filterType.getShape(); int64_t filterH = filterShape[getFilterHDim()]; int64_t filterW = filterShape[getFilterWDim()]; int64_t m = getM(); int64_t r = getR(); int64_t alpha = m + r - 1; int64_t alphaH = filterH != 1 ? alpha : 1; int64_t alphaW = filterW != 1 ? alpha : 1; IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH); IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW); resultOffsets.append( {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]}); resultSizes.append( {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]}); return success(); } /// Implement tiling for winograd_filter_transform /// The input of winograd_filter_transform is (F, KH, KW, C). /// The output of winograd_filter_transform is (alphaH, alphaW, C, F) /// Users can specify the tile sizes of F and C. /// `offsets` are the values for the offsets of F, KH, KW, C for one tile. /// `sizes` are the values for the sizes of F, KH, KW, C for one tile. 
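/// For example (illustrative), with m = 4 and r = 3 a full 3x3 filter tile
/// maps to an (alphaH, alphaW) = (6, 6) output tile, since alpha = m + r - 1,
/// while the F and C slice sizes carry over unchanged from `sizes`.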
FailureOr WinogradFilterTransformOp::getTiledImplementation( OpBuilder &builder, ArrayRef offsets, ArrayRef sizes) { IntegerAttr oneAttr = builder.getI64IntegerAttr(1); IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); ShapedType filterType = getFilterOperandType(); ArrayRef filterShape = filterType.getShape(); int64_t filterH = filterShape[getFilterHDim()]; int64_t filterW = filterShape[getFilterWDim()]; IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH); IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW); SmallVector tiledOperands; SmallVector sliceOffsets, sliceSizes; sliceOffsets.append( {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]}); sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr, sizes[getFilterCDim()]}); int64_t filterRank = getFilterOperandRank(); SmallVector filterStrides(filterRank, oneAttr); Location loc = getLoc(); auto filterSlice = builder.create( loc, getFilter(), sliceOffsets, sliceSizes, filterStrides); tiledOperands.emplace_back(filterSlice); SmallVector resultOffsets, resultSizes; if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, resultSizes))) return failure(); int64_t outputRank = getOutputOperandRank(); SmallVector outputStrides(outputRank, oneAttr); auto outputSlice = builder.create( loc, getOutput(), resultOffsets, resultSizes, outputStrides); tiledOperands.emplace_back(outputSlice); SmallVector resultTypes; resultTypes.push_back(tiledOperands[1].getType()); Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); return TilingResult{ {tiledOp}, SmallVector(tiledOp->getResults()), llvm::to_vector(ArrayRef{filterSlice, outputSlice})}; } //===----------------------------------------------------------------------===// // WinogradInputTransformOp //===----------------------------------------------------------------------===// LogicalResult WinogradInputTransformOp::verify() { auto inputType = cast(getInput().getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputH = inputShape[getInputHDim()]; int64_t inputW = inputShape[getInputWDim()]; int m = getM(); int r = getR(); int64_t tileSize = m + r - 1; auto outputType = cast(getOutput().getType()); ArrayRef outputShape = outputType.getShape(); bool leftTransform = outputShape[getOutputAlphaHDim()] != 1; bool rightTransform = outputShape[getOutputAlphaWDim()] != 1; SmallVector expectedOutputShape(6, inputH); if (ShapedType::isDynamic(inputH)) { expectedOutputShape[getOutputAlphaHDim()] = tileSize; expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic; } else { expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1; expectedOutputShape[getOutputTileHDim()] = leftTransform ? (inputH - (r - 1)) / m : inputH; } if (ShapedType::isDynamic(inputW)) { expectedOutputShape[getOutputAlphaWDim()] = tileSize; expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic; } else { expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1; expectedOutputShape[getOutputTileWDim()] = rightTransform ? 
(inputW - (r - 1)) / m : inputW; } expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()]; expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()]; if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } return success(); } SmallVector WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { Location loc = getLoc(); IntegerAttr zeroAttr = builder.getIndexAttr(0); IntegerAttr oneAttr = builder.getIndexAttr(1); Value output = getOutput(); int64_t outputRank = getOutputOperandRank(); SmallVector loopBounds(outputRank); for (unsigned dim = 0; dim < outputRank; ++dim) { loopBounds[dim].offset = zeroAttr; // alphaH, alphaW, tileH, tileW, N, C loopBounds[dim].size = getDimValue(builder, loc, output, dim); loopBounds[dim].stride = oneAttr; } return loopBounds; } SmallVector WinogradInputTransformOp::getLoopIteratorTypes() { int64_t outputRank = getOutputOperandRank(); SmallVector iteratorTypes(outputRank, utils::IteratorType::parallel); return iteratorTypes; } LogicalResult WinogradInputTransformOp::getResultTilePosition( OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector &resultOffsets, SmallVector &resultSizes) { IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); ShapedType outputType = getOutputOperandType(); ArrayRef outputShape = outputType.getShape(); int64_t outputAlphaH = outputShape[getOutputAlphaHDim()]; int64_t outputAlphaW = outputShape[getOutputAlphaWDim()]; int64_t m = getM(); int64_t r = getR(); int64_t alpha = m + r - 1; int64_t alphaH = outputAlphaH != 1 ? alpha : 1; int64_t alphaW = outputAlphaW != 1 ? alpha : 1; IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH); IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW); resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()], offsets[getOutputTileWDim()], offsets[getOutputNDim()], offsets[getOutputCDim()]}); resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()], sizes[getOutputTileWDim()], sizes[getOutputNDim()], sizes[getOutputCDim()]}); return success(); } /// Implement tiling for winograd_input_transform /// The input of winograd_input_transform is (N, H, W, C). /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N, /// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are /// the values for the sizes of tileH, tileW, N, C for one tile. FailureOr WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, ArrayRef offsets, ArrayRef sizes) { IntegerAttr oneAttr = builder.getI64IntegerAttr(1); int64_t m = getM(); int64_t r = getR(); ShapedType outputType = getOutputOperandType(); ArrayRef outputShape = outputType.getShape(); int64_t alphaH = outputShape[getOutputAlphaHDim()]; int64_t alphaW = outputShape[getOutputAlphaWDim()]; Location loc = getLoc(); MLIRContext *context = builder.getContext(); auto identityAffineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context); auto offsetAffineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); Value mappedOffsetH = affine::makeComposedAffineApply( builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap), offsets[getOutputTileHDim()]); Value mappedOffsetW = affine::makeComposedAffineApply( builder, loc, (alphaW != 1 ? 
offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  int m = getM();
  int r = getR();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expected input height to equal the input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expected input width to equal the input tile size");
    expectedOutputShape[getOutputWDim()] = (rightTransform ?
m : 1) * valueTileW; } expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()]; expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()]; auto outputType = cast(getOutput().getType()); ArrayRef outputShape = outputType.getShape(); if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } return success(); } SmallVector WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) { Location loc = getLoc(); IntegerAttr zeroAttr = builder.getIndexAttr(0); IntegerAttr oneAttr = builder.getIndexAttr(1); Value value = getValue(); int64_t valueRank = getValueOperandRank(); SmallVector loopBounds(valueRank); for (unsigned dim = 0; dim < valueRank; ++dim) { loopBounds[dim].offset = zeroAttr; // alphaH, alphaW, tileH, tileW, N, F loopBounds[dim].size = getDimValue(builder, loc, value, dim); loopBounds[dim].stride = oneAttr; } return loopBounds; } SmallVector WinogradOutputTransformOp::getLoopIteratorTypes() { int64_t valueRank = getValueOperandRank(); SmallVector iteratorTypes(valueRank, utils::IteratorType::parallel); return iteratorTypes; } LogicalResult WinogradOutputTransformOp::getResultTilePosition( OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector &resultOffsets, SmallVector &resultSizes) { int64_t m = getM(); Location loc = getLoc(); MLIRContext *context = builder.getContext(); auto identityAffineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context); auto affineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); ShapedType valueType = getValueOperandType(); ArrayRef valueShape = valueType.getShape(); int64_t valueH = valueShape[0]; int64_t valueW = valueShape[1]; Value mappedOffsetH = affine::makeComposedAffineApply( builder, loc, (valueH != 1 ? affineMap : identityAffineMap), offsets[getValueTileHDim()]); Value mappedOffsetW = affine::makeComposedAffineApply( builder, loc, (valueW != 1 ? affineMap : identityAffineMap), offsets[getValueTileWDim()]); Value mappedSizeH = affine::makeComposedAffineApply( builder, loc, affineMap, sizes[getValueTileHDim()]); Value mappedSizeW = affine::makeComposedAffineApply( builder, loc, affineMap, sizes[getValueTileWDim()]); IntegerAttr oneAttr = builder.getI64IntegerAttr(1); OpFoldResult offsetH = OpFoldResult(mappedOffsetH); OpFoldResult offsetW = OpFoldResult(mappedOffsetW); OpFoldResult sizeH = valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr); OpFoldResult sizeW = valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr); resultOffsets.append( {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]}); resultSizes.append( {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]}); return success(); } /// Implement tiling for winograd_output_transform /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N, /// F). The output of winograd_output_transform is (N, H, W, F) Users can /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values /// for the sizes of tileH, tileW, N, F for one tile. 
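/// For example (illustrative), with m = 4 a slice of `tileH` x `tileW` value
/// tiles covers an output region of (4 * tileH) x (4 * tileW) pixels, so the
/// result offsets and sizes for H and W are the tile offsets and sizes scaled
/// by m.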
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// Returns true if the result AffineExpr of the \p explicitMap is the same as
/// \p defaultMap.
static bool isValidResultDimExprs(AffineMap explicitMap,
                                  AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}

/// Returns true if the \p explicitMap is broadcast with respect to the
/// \p defaultMap.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}

/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map for the MatmulOp \p op for each operand specified by \p
/// opIndex.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
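  // E.g. (illustrative) for the LHS of a matmul, the default map is
  // (d0, d1, d2) -> (d0, d2); a transposed map (d0, d1, d2) -> (d2, d0) is
  // accepted here, while (d0, d1, d2) -> (d0, d1) would be rejected because
  // d1 does not appear in the default LHS map.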
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
    return success();
  }
  return success();
}

namespace mlir {
namespace linalg {

//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//

/// Returns a list of AffineMap with the typical matmul indexing
/// characteristics.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Check if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // Invalid map if the common dimension of matmul is not found.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
    return {nullptr}; // Success in case indexing_maps was not provided.
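  // Otherwise the attribute is expected to look like, e.g. (illustrative,
  // matching the default matmul maps above):
  //   indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
  //                    affine_map<(d0, d1, d2) -> (d2, d1)>,
  //                    affine_map<(d0, d1, d2) -> (d0, d1)>]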
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // Invalid map if the common (contracting) dimension of matmul is not found.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
    return {nullptr}; // Success in case indexing_maps was not provided.

  ArrayAttr arrayAttr;
  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
    return failure();

  if (llvm::any_of(arrayAttr, [](auto elt) {
        return !dyn_cast<AffineMapAttr>(elt);
      }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";

  return arrayAttr;
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}

void MatmulOp::print(OpAsmPrinter &p) {
  // Print non-default indexing maps before ins/outs so the output re-parses:
  // MatmulOp::parse consumes the attribute first.
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
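// Round-trip behavior in a nutshell (illustrative): a matmul that uses the
// default maps prints without the attribute,
//
//   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
//
// and the parser re-materializes the default maps, while user-specified maps
// print back as "indexing_maps = [...]" ahead of ins/outs, matching the parser
// above.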
//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
  // domains are all the same, and each implements a projected permutation.
  // Each iteration space dim must occur for at least one operand and either
  // takes part in a contraction/reduction or else has parallel iteration type.
  // We have that a dim is a contraction/reduction dim if and only if the dim
  // occurs for the output operand. We use this fact for fast inference:
  // NB: In case we allow dims to occur solely for one input, the above still
  //     holds: per the einsum semantics, these are reduction dims as well.
  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  SmallVector<utils::IteratorType> iteratorTypes;
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput
                                ? utils::IteratorType::parallel
                                : utils::IteratorType::reduction);

  return iteratorTypes;
}

unsigned ContractOp::getNumRegionArgs() { return 3; }

/// Implements the block region builder, which is called by
/// 'fillStructuredOpRegion'.
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                               ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }

  // TODO: Support fields with operators besides mult & add.
  Type outType = block.getArgument(2).getType();
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType =
      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType);
  helper.yieldOutputs({result});
}

ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);

  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
                                regionBuilder);
}

void ContractOp::print(OpAsmPrinter &p) {
  p << " indexing_maps = [";
  llvm::interleaveComma(getIndexingMaps(), p,
                        [&](Attribute attr) { p.printAttribute(attr); });
  p << "]";
  printNamedStructuredOp(
      p, getOperation(), getInputs(), getOutputs(),
      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
}
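// Illustrative example (not from the original source): a plain matmul written
// as a linalg.contract, which the verifier below accepts. Dims d0/d1 are
// parallel (they occur in the output map) and d2 is the contracting dim (it
// occurs in both inputs but not the output):
//
//   %0 = linalg.contract
//       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                        affine_map<(d0, d1, d2) -> (d2, d1)>,
//                        affine_map<(d0, d1, d2) -> (d0, d1)>]
//       ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//       outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>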
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
  // access an input operand (so occurrence count can be at most 2) and
  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // A helper so that for each operand's affine_map and type we check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims;
       dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // We have a contracting dim if and only if ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: We disallow a dim which occurs for only one input operand and not
    //     for the output. In terms of einsum semantics such dims have a
    //     sensible meaning - namely an additional reduction per each such dim.
    //     By contrast, the ContractionOpInterface does not know about this
    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
    //     while vector.contract's verifier accepts dims of this kind, many of
    //     its lowerings give up on encountering these dims.
    // TODO: Remove the following once we have comprehensive support for
    //       input-only reduction dims, at both the linalg- and vector-dialect
    //       levels.
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}

LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability ContractOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir