//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of
/// the `source`.
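/// For example (illustrative): slicing a `tensor<?x?xf32>` source creates a
/// `tensor.extract_slice`, slicing a `memref<?x?xf32>` creates a
/// `memref.subview`, and any other source type yields `nullptr`.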
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`.
/// All ods-gen created structured operations use this method to implement
/// their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
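/// A typical textual form handled here looks like (illustrative):
///   ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
///   outs(%c : tensor<4x16xf32>)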
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it
    // to the discardable attributes dictionary where it is handled by the
    // generic Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: Consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. They are helpers that build the unary, binary, and type conversion
// functions defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc
// for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr.
// This can be extended later if it becomes possible to fail construction of
// the region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of
// integer and floating point types, but they will not internally cast within
// bit widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or
// need to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
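  // For example (illustrative): `add` on two f32 values lowers to
  // `arith.addf`, on two i32 values to `arith.addi`, on two i1 values to
  // `arith.ori`, and on two complex values to `complex.add`, mirroring the
  // dispatch in the switch below.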
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
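  // For example (illustrative): `select` expects an i1 condition followed by
  // two branch values of matching numeric type and maps to `arith.select`.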
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
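  // For example (illustrative): a signed cast of an i8 operand to i32 is
  // expected to produce `arith.extsi`, and the unsigned variant
  // `arith.extui`; the actual op selection is delegated to
  // convertScalarToDtype.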
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
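    // Illustrative example of the rewrite: padding a filled tensor
    //   %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<4xf32>)
    //   %1 = tensor.pad %0 low[2] high[2] { ... yield %cst }
    // becomes a single linalg.fill of an appropriately sized tensor.empty.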
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this,
      // we cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
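    // Offset bookkeeping below (illustrative): inserting the pad's source at
    // offset %o with low padding %l is equivalent to inserting the original
    // source at offset %o + %l, which is what the affine apply computes.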
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old
    // offsets plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
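/// Illustrative example: packing the result of
///   %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<16x16xf32>)
/// with a padding value equal to %cst (or with no padding value at all)
/// is replaced by filling the pack's destination directly.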
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(isa<MemRefType, RankedTensorType>(t)
                                  ? getElementTypeOrSelf(t)
                                  : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty()
            ? StringAttr()
            : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when the
      // 'iterator_types' attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when the 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: May need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with
  // memory semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
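    // Illustrative example of an op this pattern folds away:
    //   %r = linalg.generic {indexing_maps = [#id, #id],
    //                        iterator_types = ["parallel"]}
    //     ins(%x : tensor<?xf32>) outs(%y : tensor<?xf32>) {
    //   ^bb0(%in: f32, %out: f32):
    //     linalg.yield %in : f32
    //   } -> tensor<?xf32>
    // which simply forwards %x to its result.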
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion and dense tensor casting.
        // TODO: Unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
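// For example (illustrative), this is what allows the short form
//   %m = linalg.map { arith.addf } ins(%a, %b : ...) outs(%init : ...)
// to be printed with the trivial region elided.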
1467 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1468   if (body->getOperations().size() != 2)
1469     return nullptr;
1470   Operation &payload = body->getOperations().front();
1471   assert(isa<YieldOp>(body->getOperations().back()));
1472
1473   if (payload.getNumOperands() == 0 ||
1474       payload.getNumOperands() != body->getNumArguments())
1475     return nullptr;
1476   if (initFirst) {
1477     // Check the init operand.
1478     if (payload.getOperands().back() != body->getArgument(0))
1479       return nullptr;
1480     // Check the remaining operands.
1481     for (const auto &[operand, bbArg] :
1482          llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1483       if (bbArg != operand)
1484         return nullptr;
1485     }
1486   } else {
1487     for (const auto &[operand, bbArg] :
1488          llvm::zip(payload.getOperands(), body->getArguments())) {
1489       if (bbArg != operand)
1490         return nullptr;
1491     }
1492   }
1493   return &payload;
1494 }
1495
1496 static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1497   SmallVector<StringRef> elidedAttrs;
1498   std::string attrToElide;
1499   p << " { " << payloadOp->getName().getStringRef();
1500   for (const auto &attr : payloadOp->getAttrs()) {
1501     auto fastAttr =
1502         llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1503     if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1504       attrToElide = attr.getName().str();
1505       elidedAttrs.push_back(attrToElide);
1506       break;
1507     }
1508   }
1509   p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1510   p << " }";
1511 }
1512
1513 void MapOp::print(OpAsmPrinter &p) {
1514   Block *mapper = getBody();
1515   Operation *payloadOp = findPayloadOp(mapper);
1516   if (payloadOp) {
1517     printShortForm(p, payloadOp);
1518   }
1519
1520   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1521   p.printOptionalAttrDict((*this)->getAttrs());
1522
1523   if (!payloadOp) {
1524     // Print region if the payload op was not detected.
1525     p.increaseIndent();
1526     p.printNewline();
1527     p << "(";
1528     llvm::interleaveComma(mapper->getArguments(), p,
1529                           [&](auto arg) { p.printRegionArgument(arg); });
1530     p << ") ";
1531
1532     p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1533     p.decreaseIndent();
1534   }
1535 }
1536
1537 LogicalResult MapOp::verify() {
1538   auto *bodyBlock = getBody();
1539   auto blockArgs = bodyBlock->getArguments();
1540
1541   // Check that the number of `inputs` matches the arity of the `mapper` region.
1542   if (getInputs().size() != blockArgs.size())
1543     return emitOpError() << "expects number of operands to match the arity of "
1544                             "mapper, but got: "
1545                          << getInputs().size() << " and " << blockArgs.size();
1546
1547   // The parameters of the mapper should all match the element types of the inputs.
1548   for (const auto &[bbArgType, inputArg] :
1549        llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1550     auto inputElemType =
1551         llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1552     if (bbArgType != inputElemType) {
1553       return emitOpError() << "expected element type of input " << inputElemType
1554                            << " to match bbArg type " << bbArgType;
1555     }
1556   }
1557
1558   // The shape of each input must match the shape of the output.
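  // For example (shapes illustrative), inputs of type tensor<8x16xf32>
  // require an init whose shape is also 8x16.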
1559   auto outputShape = getInit().getType().getShape();
1560   for (Type inputArgType : TypeRange{getInputs()}) {
1561     auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1562     if (inputElemShape != outputShape) {
1563       return emitOpError() << "expected shape of input (" << inputElemShape
1564                            << ") to match shape of output (" << outputShape
1565                            << ")";
1566     }
1567   }
1568
1569   return success();
1570 }
1571
1572 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1573   int64_t rank = getInit().getType().getRank();
1574   return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1575 }
1576
1577 ArrayAttr MapOp::getIndexingMaps() {
1578   Builder builder(getContext());
1579   int64_t rank = getInit().getType().getRank();
1580   int64_t numIndexingMaps = getOperands().size();
1581   return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1582       numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1583 }
1584
1585 void MapOp::getEffects(
1586     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1587         &effects) {
1588   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1589 }
1590
1591 Speculation::Speculatability MapOp::getSpeculatability() {
1592   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1593 }
1594
1595 //===----------------------------------------------------------------------===//
1596 // ReduceOp
1597 //===----------------------------------------------------------------------===//
1598
1599 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1600                                         OpAsmSetValueNameFn setNameFn) {
1601   for (Value v : getRegionInputArgs())
1602     setNameFn(v, "in");
1603   for (Value v : getRegionOutputArgs())
1604     setNameFn(v, "init");
1605 }
1606
1607 void ReduceOp::getAsmResultNames(
1608     function_ref<void(Value, StringRef)> setNameFn) {
1609   if (!getResults().empty())
1610     setNameFn(getResults().front(), "reduced");
1611 }
1612
1613 void ReduceOp::build(
1614     OpBuilder &builder, OperationState &result, ValueRange inputs,
1615     ValueRange inits, ArrayRef<int64_t> dimensions,
1616     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1617     ArrayRef<NamedAttribute> attributes) {
1618   build(builder, result, TypeRange{}, inputs, inits, dimensions);
1619   result.addAttributes(attributes);
1620
1621   // Add output types for `RankedTensorType` output arguments.
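  // Note that memref inits contribute no result types: with only memref
  // outs, the op has zero results and buffer semantics.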
1622 for (Value init : inits) { 1623 Type initType = init.getType(); 1624 if (llvm::isa<RankedTensorType>(initType)) 1625 result.addTypes(initType); 1626 } 1627 1628 if (bodyBuild) 1629 buildGenericRegion(builder, result.location, *result.regions.front(), 1630 inputs, inits, bodyBuild); 1631 } 1632 1633 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() { 1634 int64_t inputRank = 1635 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank(); 1636 SmallVector<utils::IteratorType> iteratorTypes(inputRank, 1637 utils::IteratorType::parallel); 1638 for (int64_t reductionDim : getDimensions()) 1639 iteratorTypes[reductionDim] = utils::IteratorType::reduction; 1640 return iteratorTypes; 1641 } 1642 1643 ArrayAttr ReduceOp::getIndexingMaps() { 1644 int64_t inputRank = 1645 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank(); 1646 SmallVector<AffineMap> affineMaps( 1647 getNumDpsInputs(), 1648 AffineMap::getMultiDimIdentityMap(inputRank, getContext())); 1649 AffineMap resultMap = 1650 AffineMap::getMultiDimIdentityMap(inputRank, getContext()) 1651 .dropResults(getDimensions()); 1652 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i) 1653 affineMaps.push_back(resultMap); 1654 return Builder(getContext()).getAffineMapArrayAttr(affineMaps); 1655 } 1656 1657 void ReduceOp::getEffects( 1658 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 1659 &effects) { 1660 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation())); 1661 } 1662 1663 Speculation::Speculatability ReduceOp::getSpeculatability() { 1664 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); 1665 } 1666 1667 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, 1668 NamedAttrList &attributes, 1669 StringRef attributeName) { 1670 if (parser.parseKeyword(attributeName) || parser.parseEqual()) 1671 return failure(); 1672 1673 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{})); 1674 return success(); 1675 } 1676 1677 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1678 std::optional<OperationName> payloadOpName; 1679 NamedAttrList payloadOpAttrs; 1680 if (succeeded(parser.parseOptionalLBrace())) { 1681 FailureOr<OperationName> operationName = parser.parseCustomOperationName(); 1682 if (failed(operationName)) 1683 return failure(); 1684 if (parser.parseOptionalAttrDict(payloadOpAttrs)) 1685 return failure(); 1686 payloadOpName = operationName.value(); 1687 if (parser.parseRBrace()) 1688 return failure(); 1689 } 1690 1691 if (parseDstStyleOp( 1692 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { 1693 return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); 1694 })) 1695 return failure(); 1696 1697 if (payloadOpName.has_value()) { 1698 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, 1699 ArrayRef(result.operands), /*initFirst=*/true); 1700 } else { 1701 SmallVector<OpAsmParser::Argument> regionArgs; 1702 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, 1703 /*allowType=*/true, /*allowAttrs=*/true)) { 1704 return failure(); 1705 } 1706 1707 Region *body = result.addRegion(); 1708 if (parser.parseRegion(*body, regionArgs)) 1709 return failure(); 1710 } 1711 1712 return success(); 1713 } 1714 1715 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, 1716 ArrayRef<int64_t> attributeValue) { 1717 p << ' ' << attributeName << " = [" << attributeValue << "] "; 1718 } 1719 1720 void ReduceOp::print(OpAsmPrinter &p) { 1721 Block 
*mapper = getBody(); 1722 Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); 1723 if (payloadOp) { 1724 printShortForm(p, payloadOp); 1725 } 1726 1727 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); 1728 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); 1729 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); 1730 if (!payloadOp) { 1731 // Print region if the payload op was not detected. 1732 p.increaseIndent(); 1733 p.printNewline(); 1734 p << "("; 1735 llvm::interleaveComma(mapper->getArguments(), p, 1736 [&](auto arg) { p.printRegionArgument(arg); }); 1737 p << ") "; 1738 1739 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); 1740 p.decreaseIndent(); 1741 } 1742 } 1743 1744 LogicalResult ReduceOp::verify() { 1745 ArrayRef<int64_t> dimensionsRef = getDimensions(); 1746 1747 for (int64_t i = 1; i < getNumDpsInputs(); ++i) { 1748 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() != 1749 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) { 1750 return emitOpError() << "expects all inputs to have the same shapes. " 1751 "Shape at input-index " 1752 << i 1753 << " is not equal to the shape at input-index 0."; 1754 } 1755 } 1756 for (int64_t i = 1; i < getNumDpsInits(); ++i) { 1757 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() != 1758 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) { 1759 return emitOpError() << "expects all outputs to have the same shapes. " 1760 "Shape at output-index " 1761 << i 1762 << " is not equal to the shape at output-index 0."; 1763 } 1764 } 1765 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType()); 1766 auto initType = llvm::cast<ShapedType>(getInits()[0].getType()); 1767 1768 DenseSet<int64_t> dimensionsToReduce; 1769 for (int64_t dimension : dimensionsRef) { 1770 if (dimension < 0 || dimension >= inputType.getRank()) { 1771 return emitOpError() 1772 << "dimensions for reduction should be in the range [0, " 1773 << inputType.getRank() - 1 << "]."; 1774 } 1775 dimensionsToReduce.insert(dimension); 1776 } 1777 1778 auto inputDims = inputType.getShape(); 1779 auto initDims = initType.getShape(); 1780 1781 // Input dimensions that will be left after the reduction. 1782 SmallVector<int64_t> reducedInputDims; 1783 for (const auto &en : llvm::enumerate(inputDims)) { 1784 if (!dimensionsToReduce.count(en.index())) 1785 reducedInputDims.push_back(en.value()); 1786 } 1787 1788 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) { 1789 return emitOpError() << "number of dimensions after reduction " 1790 << reducedInputDims.size() 1791 << " doesn't match the init rank " 1792 << initType.getRank(); 1793 } 1794 1795 if (reducedInputDims != initDims) 1796 return emitOpError() << "init dimensions [" << initDims 1797 << "] doesn't match input dimensions after reduction [" 1798 << reducedInputDims << "]"; 1799 1800 Block *block = getBody(); 1801 if (block->getNumArguments() != this->getNumOperands()) 1802 return emitOpError() 1803 << "mismatching number of operands and block arguments"; 1804 1805 // Check that the first block arguments match the element type of the inputs. 
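  // The block arguments are ordered (input elements..., init elements...),
  // so the zip below only walks the input-element prefix.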
1806   for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1807     Type inputElementType =
1808         llvm::cast<ShapedType>(input.getType()).getElementType();
1809     if (inputElementType != bbArg.getType())
1810       return emitOpError()
1811              << "input element type " << inputElementType
1812              << " does not match corresponding block argument type "
1813              << bbArg.getType();
1814   }
1815
1816   // Check that the last block arguments match the element type of the outputs.
1817   for (auto [output, bbArg] : llvm::zip(
1818            getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1819     auto outputElementType =
1820         llvm::cast<ShapedType>(output.getType()).getElementType();
1821     if (outputElementType != bbArg.getType())
1822       return emitOpError()
1823              << "output element type " << outputElementType
1824              << " does not match corresponding block argument type "
1825              << bbArg.getType();
1826   }
1827   return success();
1828 }
1829
1830 //===----------------------------------------------------------------------===//
1831 // TransposeOp
1832 //===----------------------------------------------------------------------===//
1833
1834 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1835                                 Region &region, ValueRange inputs,
1836                                 ValueRange outputs) {
1837   buildGenericRegion(builder, loc, region, inputs, outputs,
1838                      [](OpBuilder &b, Location loc, ValueRange args) {
1839                        if (!args.empty())
1840                          b.create<linalg::YieldOp>(loc, args[0]);
1841                      });
1842 }
1843
1844 void TransposeOp::build(::mlir::OpBuilder &builder,
1845                         ::mlir::OperationState &result, Value input, Value init,
1846                         DenseI64ArrayAttr permutation,
1847                         ArrayRef<NamedAttribute> attributes) {
1848   result.addOperands(input);
1849   result.addOperands(init);
1850   result.addAttribute(getPermutationAttrName(result.name), permutation);
1851   result.addAttributes(attributes);
1852
1853   // Add output types for `RankedTensorType` output arguments.
1854 Type initType = init.getType(); 1855 if (llvm::isa<RankedTensorType>(initType)) 1856 result.addTypes(initType); 1857 1858 buildIdentityRegion(builder, result.location, *result.addRegion(), input, 1859 init); 1860 } 1861 1862 void TransposeOp::build(::mlir::OpBuilder &builder, 1863 ::mlir::OperationState &result, Value input, Value init, 1864 ArrayRef<int64_t> permutation, 1865 ArrayRef<NamedAttribute> attributes) { 1866 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation), 1867 attributes); 1868 } 1869 1870 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 1871 if (failed(parseDstStyleOp( 1872 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { 1873 return parseDenseI64ArrayAttr(parser, attributes, "permutation"); 1874 }))) 1875 return failure(); 1876 1877 OpBuilder builder(parser.getContext()); 1878 buildIdentityRegion(builder, result.location, *result.addRegion(), 1879 /*inputs=*/result.operands, 1880 /*outputs=*/{}); 1881 return success(); 1882 } 1883 1884 void TransposeOp::getAsmResultNames( 1885 function_ref<void(Value, StringRef)> setNameFn) { 1886 if (!getResults().empty()) 1887 setNameFn(getResults().front(), "transposed"); 1888 } 1889 1890 void TransposeOp::print(OpAsmPrinter &p) { 1891 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); 1892 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); 1893 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); 1894 } 1895 1896 LogicalResult TransposeOp::verify() { 1897 ArrayRef<int64_t> permutationRef = getPermutation(); 1898 1899 if (!isPermutationVector(permutationRef)) 1900 return emitOpError("permutation is not valid"); 1901 1902 auto inputType = getInput().getType(); 1903 auto initType = getInit().getType(); 1904 1905 int64_t rank = inputType.getRank(); 1906 1907 if (rank != initType.getRank()) 1908 return emitOpError() << "input rank " << rank 1909 << " does not match init rank " << initType.getRank(); 1910 1911 if (rank != static_cast<int64_t>(permutationRef.size())) 1912 return emitOpError() << "size of permutation " << permutationRef.size() 1913 << " does not match the argument rank " << rank; 1914 1915 auto inputDims = inputType.getShape(); 1916 auto initDims = initType.getShape(); 1917 1918 for (int64_t i = 0; i < rank; ++i) { 1919 int64_t inputDim = inputDims[permutationRef[i]]; 1920 int64_t initDim = initDims[i]; 1921 1922 if (inputDim != initDim) { 1923 return emitOpError() << "dim(result, " << i << ") = " << initDim 1924 << " doesn't match dim(input, permutation[" << i 1925 << "]) = " << inputDim; 1926 } 1927 } 1928 1929 return success(); 1930 } 1931 1932 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() { 1933 int64_t rank = getInit().getType().getRank(); 1934 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel); 1935 } 1936 1937 ArrayAttr TransposeOp::getIndexingMaps() { 1938 Builder builder(getContext()); 1939 int64_t rank = getInit().getType().getRank(); 1940 return builder.getAffineMapArrayAttr( 1941 {inversePermutation(AffineMap::getPermutationMap( 1942 llvm::to_vector_of<unsigned>(getPermutation()), getContext())), 1943 builder.getMultiDimIdentityMap(rank)}); 1944 } 1945 1946 void TransposeOp::getEffects( 1947 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 1948 &effects) { 1949 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation())); 1950 } 1951 1952 Speculation::Speculatability TransposeOp::getSpeculatability() { 1953 
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1954 }
1955
1956 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1957                                 SmallVectorImpl<OpFoldResult> &result) {
1958   // Only the tensor type is supported.
1959   if (!isa<TensorType>(getInput().getType()))
1960     return failure();
1961
1962   // Rank-0 tensor: an empty permutation is trivially the identity.
1963   if (getPermutation().size() == 0) {
1964     result.push_back(getInput());
1965     return success();
1966   }
1967   // Identity permutation.
1968   if (isIdentityPermutation(getPermutation())) {
1969     result.push_back(getInput());
1970     return success();
1971   }
1972
1973   return failure();
1974 }
1975
1976 /// Fold transpose with transpose.
1977 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1978   using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1979
1980   LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1981                                 PatternRewriter &rewriter) const override {
1982     auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1983     if (!defTransposeOp)
1984       return failure();
1985     ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1986     ArrayRef<int64_t> perms = transposeOp.getPermutation();
1987     SmallVector<int64_t> foldedPerms;
1988     foldedPerms.reserve(perms.size());
1989     for (int64_t perm : perms)
1990       foldedPerms.push_back(defPerms[perm]);
1991
1992     rewriter.replaceOpWithNewOp<TransposeOp>(
1993         transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1994         foldedPerms);
1995     return success();
1996   }
1997 };
1998
1999 /// This pattern canonicalizes transpose by swapping the order of
2000 /// broadcast and transpose:
2001 ///   transpose(broadcast(input)) -> broadcast(transpose(input))
2002 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2003   using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2004
2005   LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2006                                 PatternRewriter &rewriter) const override {
2007     Value input = transposeOp.getInput();
2008     BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2009     if (!input.hasOneUse() || !broadcastOp)
2010       return failure();
2011
2012     ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2013     ArrayRef<int64_t> perms = transposeOp.getPermutation();
2014
2015     // Get the new permutation and the new broadcast dimensions.
2016     SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2017     SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2018     SmallVector<int64_t> resultDimensions;
2019     unsigned dimensionSize = dimensions.size();
2020     for (unsigned i = 0; i < dimensionSize; ++i)
2021       resultDimensions.push_back(invertPerm[dimensions[i]]);
2022
2023     // Create transpose result.
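    // Build the init of the new transpose from the broadcast input's sizes:
    // tensor.dim ops for dynamic dimensions, IntegerAttrs for static ones,
    // permuted by resultPerms.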
2024 Value broadcastInput = broadcastOp.getInput(); 2025 Location loc = transposeOp.getLoc(); 2026 MLIRContext *ctx = transposeOp.getContext(); 2027 SmallVector<OpFoldResult> dims; 2028 auto broadcastInputTy = 2029 mlir::cast<RankedTensorType>(broadcastInput.getType()); 2030 unsigned inputRank = broadcastInputTy.getRank(); 2031 for (unsigned i = 0; i < inputRank; ++i) { 2032 if (broadcastInputTy.isDynamicDim(i)) { 2033 dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i) 2034 ->getResult(0)); 2035 } else { 2036 dims.push_back(IntegerAttr::get(IndexType::get(ctx), 2037 broadcastInputTy.getDimSize(i))); 2038 } 2039 } 2040 SmallVector<OpFoldResult> transposeResultShapes = 2041 applyPermutation(dims, resultPerms); 2042 Value transposeInit = rewriter.create<tensor::EmptyOp>( 2043 transposeOp.getLoc(), transposeResultShapes, 2044 broadcastInputTy.getElementType()); 2045 2046 // Create broadcast(transpose(input)). 2047 Value transposeResult = 2048 rewriter 2049 .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit, 2050 resultPerms) 2051 ->getResult(0); 2052 rewriter.replaceOpWithNewOp<BroadcastOp>( 2053 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions); 2054 return success(); 2055 } 2056 }; 2057 2058 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2059 MLIRContext *context) { 2060 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context); 2061 } 2062 2063 //===----------------------------------------------------------------------===// 2064 // BroadcastOp 2065 //===----------------------------------------------------------------------===// 2066 2067 void BroadcastOp::build(::mlir::OpBuilder &builder, 2068 ::mlir::OperationState &result, Value input, Value init, 2069 DenseI64ArrayAttr dimensions, 2070 ArrayRef<NamedAttribute> attributes) { 2071 result.addOperands(input); 2072 result.addOperands(init); 2073 result.addAttribute(getDimensionsAttrName(result.name), dimensions); 2074 result.addAttributes(attributes); 2075 2076 // Add output types for `RankedTensorType` output arguments. 
2077   Type initType = init.getType();
2078   if (llvm::isa<RankedTensorType>(initType))
2079     result.addTypes(initType);
2080
2081   buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2082                       init);
2083 }
2084
2085 void BroadcastOp::build(::mlir::OpBuilder &builder,
2086                         ::mlir::OperationState &result, Value input, Value init,
2087                         ArrayRef<int64_t> dimensions,
2088                         ArrayRef<NamedAttribute> attributes) {
2089   build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2090         attributes);
2091 }
2092
2093 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2094   if (failed(parseDstStyleOp(
2095           parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2096             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2097           })))
2098     return failure();
2099
2100   OpBuilder builder(parser.getContext());
2101   buildIdentityRegion(builder, result.location, *result.addRegion(),
2102                       /*inputs=*/result.operands,
2103                       /*outputs=*/{});
2104   return success();
2105 }
2106
2107 void BroadcastOp::getAsmResultNames(
2108     function_ref<void(Value, StringRef)> setNameFn) {
2109   if (!getResults().empty())
2110     setNameFn(getResults().front(), "broadcasted");
2111 }
2112
2113 void BroadcastOp::print(OpAsmPrinter &p) {
2114   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2115   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2116   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2117 }
2118
2119 LogicalResult BroadcastOp::verify() {
2120   ArrayRef<int64_t> dimensionsRef = getDimensions();
2121
2122   auto inputType = getInput().getType();
2123   auto initType = getInit().getType();
2124
2125   int64_t inputRank = inputType.getRank();
2126   int64_t initRank = initType.getRank();
2127
2128   auto inputShape = inputType.getShape();
2129   auto initShape = initType.getShape();
2130
2131   if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2132     return emitOpError() << "input rank plus added dimensions does not "
2133                             "match init rank. input rank: "
2134                          << inputRank
2135                          << ", dimensions size: " << dimensionsRef.size()
2136                          << ", init rank: " << initRank;
2137
2138   for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2139     if (dim < 0 || dim >= initRank)
2140       return emitOpError() << "dimension " << idx
2141                            << " is out of range. expected range: [0, "
2142                            << initRank - 1 << "], got: " << dim;
2143   }
2144
2145   // Mapping from input dims to init dims.
2146   SmallVector<int64_t> dimMap;
2147   for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2148     if (!llvm::is_contained(dimensionsRef, dim))
2149       dimMap.push_back(dim);
2150   }
2151
2152   for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2153     // This dimension is mapped from the input; the init and input dims should
2154     // match.
2155     if (inputShape[inputDimIdx] != initShape[initDimIdx])
2156       return emitOpError() << "input dim " << inputDimIdx
2157                            << " should match init dim " << initDimIdx
2158                            << ". input: " << inputShape[inputDimIdx]
2159                            << ", init: " << initShape[initDimIdx];
2160   }
2161
2162   return success();
2163 }
2164
2165 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2166   int64_t rank = getInit().getType().getRank();
2167   return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2168 }
2169
2170 ArrayAttr BroadcastOp::getIndexingMaps() {
2171   Builder builder(getContext());
2172   int64_t rank = getInit().getType().getRank();
2173   return builder.getAffineMapArrayAttr(
2174       {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2175        builder.getMultiDimIdentityMap(rank)});
2176 }
2177
2178 void BroadcastOp::getEffects(
2179     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2180         &effects) {
2181   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2182 }
2183
2184 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2185   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2186 }
2187
2188 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2189                                               MLIRContext *context) {
2190   results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2191 }
2192
2193 //===----------------------------------------------------------------------===//
2194 // YieldOp
2195 //===----------------------------------------------------------------------===//
2196
2197 void linalg::YieldOp::print(OpAsmPrinter &p) {
2198   if (getNumOperands() > 0)
2199     p << ' ' << getOperands();
2200   p.printOptionalAttrDict((*this)->getAttrs());
2201   if (getNumOperands() > 0)
2202     p << " : " << getOperandTypes();
2203 }
2204
2205 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2206   SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2207   SmallVector<Type, 2> types;
2208   SMLoc loc = parser.getCurrentLocation();
2209   return failure(parser.parseOperandList(opInfo) ||
2210                  parser.parseOptionalAttrDict(result.attributes) ||
2211                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2212                  parser.resolveOperands(opInfo, types, loc, result.operands));
2213 }
2214
2215 // Check that the number and types of the yield operands match the element
2216 // types of the LinalgOp interface's output operands.
2217 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { 2218 if (op.getNumOperands() != linalgOp.getNumDpsInits()) 2219 return op.emitOpError("expected number of yield values (") 2220 << op.getNumOperands() 2221 << ") to match the number of inits / outs operands of the enclosing " 2222 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")"; 2223 2224 for (OpOperand &opOperand : op->getOpOperands()) { 2225 OpOperand *outputOperand = 2226 linalgOp.getDpsInitOperand(opOperand.getOperandNumber()); 2227 Type elementType = outputOperand->get().getType(); 2228 if (isa<MemRefType, RankedTensorType>(elementType)) 2229 elementType = getElementTypeOrSelf(outputOperand->get().getType()); 2230 if (opOperand.get().getType() != elementType) 2231 return op.emitOpError("type of yield operand ") 2232 << (opOperand.getOperandNumber() + 1) << " (" 2233 << opOperand.get().getType() << ") doesn't match " 2234 << "the element type of the enclosing linalg.generic op (" 2235 << elementType << ")"; 2236 } 2237 return success(); 2238 } 2239 2240 LogicalResult linalg::YieldOp::verify() { 2241 auto *parentOp = (*this)->getParentOp(); 2242 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) 2243 return emitOpError("expected single non-empty parent region"); 2244 2245 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp)) 2246 return verifyYield(*this, linalgOp); 2247 2248 return emitOpError("expected parent op with LinalgOp interface"); 2249 } 2250 2251 //===----------------------------------------------------------------------===// 2252 // IndexOp 2253 //===----------------------------------------------------------------------===// 2254 2255 LogicalResult IndexOp::verify() { 2256 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp()); 2257 if (!linalgOp) 2258 return emitOpError("expected parent op with LinalgOp interface"); 2259 if (linalgOp.getNumLoops() <= getDim()) 2260 return emitOpError("expected dim (") 2261 << getDim() << ") to be lower than the number of loops (" 2262 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; 2263 return success(); 2264 } 2265 2266 /////// Operations corresponding to library calls defined with Tablegen //////// 2267 2268 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" 2269 2270 #define GET_OP_CLASSES 2271 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 2272 2273 #define GET_OP_CLASSES 2274 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 2275 2276 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap, 2277 unsigned rank, 2278 MLIRContext *context) { 2279 if (maybeMap) 2280 return *maybeMap; 2281 if (rank == 0) 2282 return AffineMap::get(context); 2283 return AffineMap::getMultiDimIdentityMap(rank, context); 2284 } 2285 2286 SmallVector<AffineExpr, 4> 2287 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, 2288 MLIRContext *context) { 2289 SmallVector<AffineExpr, 4> res; 2290 res.reserve(num); 2291 for (unsigned i = 0; i < num; ++i) 2292 res.push_back(getAffineDimExpr(startIdx++, context)); 2293 return res; 2294 } 2295 2296 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a, 2297 ArrayRef<AffineExpr> b) { 2298 auto rangeA = llvm::make_range(a.begin(), a.end()); 2299 auto rangeB = llvm::make_range(b.begin(), b.end()); 2300 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB); 2301 return llvm::to_vector<4>(concatRanges); 2302 } 2303 2304 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) { 2305 if 
(auto memref = llvm::dyn_cast<MemRefType>(t)) {
2306     ss << "view";
2307     for (auto size : memref.getShape())
2308       if (size < 0)
2309         ss << "sx";
2310       else
2311         ss << size << "x";
2312     if (failed(appendMangledType(ss, memref.getElementType())))
2313       return failure();
2314     if (auto as = memref.getMemorySpace()) {
2315       if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2316         ss << "as" << attr.getInt();
2317       else
2318         return failure();
2319     }
2320     return success();
2321   }
2322   if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2323     ss << "vector";
2324     llvm::interleave(
2325         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2326     if (failed(appendMangledType(ss, vec.getElementType())))
2327       return failure();
2328     return success();
2329   }
2330   if (t.isSignlessIntOrIndexOrFloat()) {
2331     ss << t;
2332     return success();
2333   }
2334   return failure();
2335 }
2336
2337 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2338   assert(isa<LinalgOp>(op));
2339   std::string name(op->getName().getStringRef().str());
2340   std::string fun = "";
2341   for (NamedAttribute kv : op->getAttrs()) {
2342     if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2343       fun = stringifyEnum(ufa.getValue()).str() + "_";
2344     } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2345       fun = stringifyEnum(bfa.getValue()).str() + "_";
2346     }
2347   }
2348   name.reserve(128);
2349   std::replace(name.begin(), name.end(), '.', '_');
2350   llvm::raw_string_ostream ss(name);
2351   ss << "_" << fun;
2352   for (Type t : op->getOperandTypes()) {
2353     if (failed(appendMangledType(ss, t)))
2354       return std::string();
2355     ss << "_";
2356   }
2357   name.pop_back();
2358   return name;
2359 }
2360
2361 //===----------------------------------------------------------------------===//
2362 // Canonicalizers and Folders.
2363 //===----------------------------------------------------------------------===//
2364
2365 namespace {
2366 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2367   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2368
2369   LogicalResult matchAndRewrite(LinalgOp op,
2370                                 PatternRewriter &rewriter) const override {
2371     for (OpOperand &opOperand : op->getOpOperands()) {
2372       // Linalg "inputs" may be either tensor or memref type.
2373       // tensor<0xelt_type> is a convention that may not always mean
2374       // "0 iterations". Only erase in cases we see memref<...x0x...>.
2375       auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2376       if (!mt)
2377         continue;
2378       if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2379         rewriter.eraseOp(op);
2380         return success();
2381       }
2382     }
2383     return failure();
2384   }
2385 };
2386
2387 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2388 /// result that is more static than the linalg op.
2389 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2390   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2391
2392   LogicalResult matchAndRewrite(tensor::CastOp castOp,
2393                                 PatternRewriter &rewriter) const override {
2394     if (!tensor::canFoldIntoProducerOp(castOp))
2395       return failure();
2396
2397     auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2398     if (!linalgOp)
2399       return failure();
2400
2401     // The cast can be in a conditionally reachable region, in which case
2402     // folding will generate invalid code. Only conservatively fold ops in the
2403     // same block for now.
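    // Sketch of the rewrite (types illustrative):
    //   %1 = linalg.generic ... -> tensor<?x?xf32>
    //   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x8xf32>
    // becomes a clone of the generic op whose out operand is cast to
    // tensor<4x8xf32>, followed by a tensor.cast back to the original type
    // for any remaining users.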
2404     if (castOp->getBlock() != linalgOp->getBlock())
2405       return failure();
2406
2407     OpBuilder::InsertionGuard guard(rewriter);
2408     rewriter.setInsertionPoint(linalgOp);
2409
2410     Location loc = linalgOp.getLoc();
2411     OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2412     unsigned resultNumber = resultValue.getResultNumber();
2413     auto resultType =
2414         llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2415     // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2416     // going from a more dynamic shape to a less dynamic shape. If the producer
2417     // for this cast, i.e. producer of the out operand, is also an operation
2418     // that folds with tensor.cast consumer (like this pattern), the cast will
2419     // continue to propagate as far up the stack as it can go.
2420     OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2421     Value newOperand =
2422         rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2423     SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2424     SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2425                                       linalgOp.getDpsInits().end());
2426     outputOperands[resultNumber] = newOperand;
2427     newOperands.append(outputOperands.begin(), outputOperands.end());
2428
2429     SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2430                                   linalgOp->result_type_end());
2431     resultTypes[resultNumber] = resultType;
2432     Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2433
2434     // Create a tensor.cast operation back to the original type.
2435     Value castBack = rewriter.create<tensor::CastOp>(
2436         loc, resultValue.getType(), newOp->getResult(resultNumber));
2437
2438     SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2439     results[resultNumber] = castBack;
2440     rewriter.replaceOp(linalgOp, results);
2441     rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2442     return success();
2443   }
2444 };
2445
2446 /// For each operand in `operands`, this function maps the static sizes of its
2447 /// dimensions to their affine dim expressions.
2448 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2449                         llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2450   for (OpOperand &opOperand : operands) {
2451     if (linalgOp.isScalar(&opOperand))
2452       continue;
2453     Value src = opOperand.get();
2454     auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2455     auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2456
2457     // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2458     // a `tensor.cast` operation and the source of the cast has a static shape,
2459     // then use that shape as the `sourceShape`.
2460     auto *parentOp = src.getDefiningOp();
2461     ArrayRef<int64_t> sourceShape = sourceType.getShape();
2462     if (parentOp) {
2463       if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2464         Value castSource = castOp.getSource();
2465         auto castSourceType =
2466             llvm::dyn_cast<RankedTensorType>(castSource.getType());
2467         if (castSourceType && castSourceType.hasStaticShape())
2468           sourceShape = castSourceType.getShape();
2469       }
2470     }
2471
2472     // If the source shape's dimension has a static shape, map the affine dim
2473     // expression to the known static size.
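    // For example (illustrative), with an indexing map (d0, d1) -> (d0, d1)
    // and a source of type tensor<4x?xf32>, this records d0 -> 4 and skips
    // the dynamic d1.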
2474     for (unsigned i = 0; i < sourceShape.size(); i++) {
2475       if (sourceType.isDynamicDim(i))
2476         continue;
2477       if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2478         affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2479     }
2480   }
2481 }
2482
2483 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2484 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2485 /// their result types are stored in `resultTypes`. If `opOperand` requires no
2486 /// change, `changeNeeded` is left untouched and the same operand is added to
2487 /// the `newOperands` list.
2488 static void createNewOperandWithStaticSizes(
2489     Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2490     llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2491     SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2492     bool &changeNeeded) {
2493   Value src = opOperand->get();
2494   newOperands.push_back(src);
2495   if (linalgOp.isScalar(opOperand))
2496     return;
2497   auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2498   Type resultType = sourceType;
2499   if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2500     resultTypes.push_back(resultType);
2501     return;
2502   }
2503   ArrayRef<int64_t> sourceShape = sourceType.getShape();
2504   AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2505   SmallVector<int64_t> newShape;
2506   // If the operand is updated with a new shape, `newOperandNeeded` will be
2507   // true.
2508   bool newOperandNeeded = false;
2509   for (unsigned i = 0; i < sourceShape.size(); i++) {
2510     int64_t dimShape = sourceShape[i];
2511     AffineExpr dimExpr = sourceMap.getResult(i);
2512     if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2513       newShape.push_back(dimShape);
2514       continue;
2515     }
2516     // The dimension has a dynamic shape and the corresponding affine dim
2517     // expression is present in the map. So assign the size for the given
2518     // affine dim expression to the dimension.
2519     newShape.push_back(affineExprToSize[dimExpr]);
2520     newOperandNeeded = true;
2521   }
2522   resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2523   if (newOperandNeeded) {
2524     changeNeeded = true;
2525     // Get the new operand value given its size and element type by
2526     // casting it.
2527     Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2528     unsigned index = opOperand->getOperandNumber();
2529     newOperands[index] = newOperand;
2530   }
2531   if (linalgOp.isDpsInit(opOperand))
2532     resultTypes.push_back(resultType);
2533 }
2534
2535 /// Static shapes for the operands can be inferred if any one of the operands
2536 /// has a static shape. This can be done by referring to the affine dim
2537 /// expressions of the operand.
2538 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2539   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2540
2541   LogicalResult matchAndRewrite(LinalgOp linalgOp,
2542                                 PatternRewriter &rewriter) const override {
2543     if (!linalgOp.hasPureTensorSemantics())
2544       return failure();
2545
2546     // Maps must be projected permutations.
2547     if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2548           return !map.isProjectedPermutation();
2549         }))
2550       return failure();
2551
2552     // Maps affine dim expressions to the static size of that dimension.
2553     llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2554     Location loc = linalgOp.getLoc();
2555
2556     // For each affine dim expression, check whether the size is known. If it
2557     // is known, add it to the map.
2558     populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2559
2560     SmallVector<Value> newOperands;
2561     SmallVector<Type> resultTypes;
2562
2563     // `changeNeeded` is `false` if the operands of `linalgOp` require no
2564     // change in their types.
2565     bool changeNeeded = false;
2566     newOperands.reserve(linalgOp->getNumOperands());
2567     resultTypes.reserve(linalgOp.getNumDpsInits());
2568
2569     // Iterate over all the operands and update the static sizes.
2570     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2571       createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2572                                       affineExprToSize, linalgOp, newOperands,
2573                                       resultTypes, changeNeeded);
2574     }
2575
2576     // If the generic op already has all the required static information, no
2577     // canonicalization is needed.
2578     if (!changeNeeded)
2579       return failure();
2580
2581     // Clone op.
2582     Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2583     SmallVector<Value> replacements;
2584     replacements.reserve(newOp->getNumResults());
2585     for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2586       Value newResult = std::get<1>(it);
2587       Value oldResult = std::get<0>(it);
2588       Type newType = newResult.getType();
2589       Type oldType = oldResult.getType();
2590       replacements.push_back(
2591           (newType != oldType)
2592               ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2593               : newResult);
2594     }
2595     rewriter.replaceOp(linalgOp, replacements);
2596     return success();
2597   }
2598 };
2599
2600 } // namespace
2601
2602 // All named ops' canonicalizers and folders are auto-generated in the
2603 // .cpp.inc.
2604 2605 //===----------------------------------------------------------------------===// 2606 // SoftmaxOp 2607 //===----------------------------------------------------------------------===// 2608 2609 LogicalResult SoftmaxOp::verify() { 2610 ShapedType inputType = getInputOperandType(); 2611 ShapedType outputType = getOutputOperandType(); 2612 2613 ArrayRef<int64_t> inputShape = inputType.getShape(); 2614 ArrayRef<int64_t> outputShape = outputType.getShape(); 2615 if (failed(verifyCompatibleShape(inputShape, outputShape))) 2616 return emitOpError("incompatible output shape"); 2617 2618 int64_t inputRank = getInputOperandRank(); 2619 int64_t dimension = getDimension(); 2620 if ((dimension < 0) || (dimension >= inputRank)) 2621 return emitOpError("incorrect dimension specified"); 2622 2623 return success(); 2624 } 2625 2626 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) { 2627 int64_t operandRank = getInputOperandRank(); 2628 SmallVector<Range> loopBounds(operandRank); 2629 Location loc = getLoc(); 2630 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); 2631 Value one = builder.create<arith::ConstantIndexOp>(loc, 1); 2632 Value source = getInput(); 2633 for (auto dim : llvm::seq<int64_t>(0, operandRank)) { 2634 loopBounds[dim].offset = zero; 2635 loopBounds[dim].size = getDimValue(builder, loc, source, dim); 2636 loopBounds[dim].stride = one; 2637 } 2638 return loopBounds; 2639 } 2640 2641 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() { 2642 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(), 2643 utils::IteratorType::parallel); 2644 iteratorTypes[getDimension()] = utils::IteratorType::reduction; 2645 return iteratorTypes; 2646 } 2647 2648 FailureOr<TilingResult> 2649 SoftmaxOp::getTiledImplementation(OpBuilder &builder, 2650 ArrayRef<OpFoldResult> offsets, 2651 ArrayRef<OpFoldResult> sizes) { 2652 int64_t rank = getInputOperandRank(); 2653 auto oneAttr = builder.getI64IntegerAttr(1); 2654 SmallVector<OpFoldResult> strides(rank, oneAttr); 2655 SmallVector<Value> tiledOperands; 2656 Operation *inputSlice = 2657 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides); 2658 if (!inputSlice) { 2659 return emitOpError("failed to compute input slice"); 2660 } 2661 tiledOperands.emplace_back(inputSlice->getResult(0)); 2662 Operation *outputSlice = 2663 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides); 2664 if (!outputSlice) { 2665 return emitOpError("failed to compute output slice"); 2666 } 2667 tiledOperands.emplace_back(outputSlice->getResult(0)); 2668 2669 SmallVector<Type, 4> resultTypes; 2670 if (hasPureTensorSemantics()) 2671 resultTypes.push_back(tiledOperands[1].getType()); 2672 Operation *tiledOp = 2673 mlir::clone(builder, getOperation(), resultTypes, tiledOperands); 2674 2675 return TilingResult{ 2676 {tiledOp}, 2677 SmallVector<Value>(tiledOp->getResults()), 2678 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})}; 2679 } 2680 2681 LogicalResult SoftmaxOp::getResultTilePosition( 2682 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, 2683 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, 2684 SmallVector<OpFoldResult> &resultSizes) { 2685 if (resultNumber == 0) { 2686 resultOffsets.assign(offsets.begin(), offsets.end()); 2687 resultSizes.assign(sizes.begin(), sizes.end()); 2688 return success(); 2689 } 2690 return failure(); 2691 } 2692 2693 // cast(dynamic) -> static. 
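// That is, `memref.cast` operands that merely erase static shape information
// are folded away so the op sees the more static type (e.g. a cast of
// memref<16xf32> to memref<?xf32>; example illustrative).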
2694 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { 2695 return memref::foldMemRefCast(*this); 2696 } 2697 2698 LogicalResult 2699 SoftmaxOp::reifyResultShapes(OpBuilder &b, 2700 ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 2701 SmallVector<OpFoldResult> shapes; 2702 Location loc = getOperation()->getLoc(); 2703 IRRewriter rewriter(b); 2704 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType()); 2705 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType()); 2706 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) { 2707 if (!outputShapedType.isDynamicDim(dim)) { 2708 // Static dim: Return IntegerAttr. 2709 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim))); 2710 } else { 2711 // Dynamic dim: Return Value. 2712 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim); 2713 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr)); 2714 } 2715 } 2716 reifiedReturnShapes.emplace_back(std::move(shapes)); 2717 return success(); 2718 } 2719 2720 void SoftmaxOp::getEffects( 2721 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 2722 &effects) { 2723 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) { 2724 if (!llvm::isa<MemRefType>(operand.getType())) 2725 continue; 2726 effects.emplace_back(MemoryEffects::Read::get(), 2727 &getOperation()->getOpOperand(index), /*stage=*/0, 2728 /*effectOnFullRegion=*/true, 2729 SideEffects::DefaultResource::get()); 2730 } 2731 2732 for (OpOperand &operand : getDpsInitsMutable()) { 2733 if (!llvm::isa<MemRefType>(operand.get().getType())) 2734 continue; 2735 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0, 2736 /*effectOnFullRegion=*/true, 2737 SideEffects::DefaultResource::get()); 2738 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0, 2739 /*effectOnFullRegion=*/true, 2740 SideEffects::DefaultResource::get()); 2741 } 2742 } 2743 2744 // Helper functions for softmax decomposition. 2745 // @{ 2746 2747 // Helper function to produce the iterator types (reduction or parallel) and 2748 // affine maps for the iterators used in the decomposition of softmax. 2749 // This method creates: 2750 // If allParallel == true: 2751 // - iterator type: {parallel, ..., parallel} 2752 // - affine maps: 2753 // -- identity with inputRank dimensions. 2754 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), 2755 // where N == inputRank. 2756 // 2757 // If allParallel == false: 2758 // - iterator type at dim(i) == parallel for i != \p dim and 2759 // dim(dim) == reduction. 2760 // - affine map: 2761 // -- identity with inputRank dimensions. 2762 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), 2763 // where N == inputRank. 
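// For example (illustrative), inputRank == 3 and dim == 1 give:
//   - identity map:  (d0, d1, d2) -> (d0, d1, d2)
//   - reduction map: (d0, d1, d2) -> (d0, d2)
// with iterator types {parallel, reduction, parallel} when allParallel is
// false.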
2764 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>> 2765 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, 2766 int64_t dim, bool allParallel = false) { 2767 SmallVector<utils::IteratorType> iteratorTypes(inputRank, 2768 utils::IteratorType::parallel); 2769 if (!allParallel) 2770 iteratorTypes[dim] = utils::IteratorType::reduction; 2771 MLIRContext *ctxt = builder.getContext(); 2772 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt); 2773 SmallVector<AffineExpr, 2> affineExprs; 2774 for (int i = 0; i < inputRank; i++) { 2775 if (i != dim) 2776 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt)); 2777 } 2778 auto reductionMap = 2779 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt); 2780 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap}; 2781 return std::make_tuple(iteratorTypes, indexingMaps); 2782 } 2783 2784 // Helper function to produce a linalg.generic that computes a reduction on 2785 // dimension \p dim with the operation type \p T. 2786 template <typename T> 2787 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, 2788 int64_t dim) { 2789 auto inputType = cast<ShapedType>(input.getType()); 2790 ArrayRef<int64_t> inputShape = inputType.getShape(); 2791 int64_t inputRank = inputShape.size(); 2792 auto [iteratorTypes, indexingMaps] = 2793 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim); 2794 assert(indexingMaps.size() == 2 && 2795 "We should have two maps: 1 for the input, 1 for the output"); 2796 assert(indexingMaps[0].isIdentity() && "input map should be identity"); 2797 2798 auto genericOp = builder.create<linalg::GenericOp>( 2799 loc, output.getType(), input, output, indexingMaps, iteratorTypes, 2800 [&](OpBuilder &b, Location loc, ValueRange args) { 2801 Value result = b.create<T>(loc, args[0], args[1]); 2802 b.create<linalg::YieldOp>(loc, result); 2803 }); 2804 return genericOp.getResult(0); 2805 } 2806 2807 /// Produce a linalg generic that computes the second step of the softmax 2808 /// decomposition: res = exp(input - max), where \p max is the max of \p input 2809 /// on dimension \p dim. 2810 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, 2811 Value max, Value output, int64_t dim) { 2812 auto inputType = cast<ShapedType>(input.getType()); 2813 ArrayRef<int64_t> inputShape = inputType.getShape(); 2814 int64_t inputRank = inputShape.size(); 2815 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( 2816 builder, inputRank, dim, /*allParallel=*/true); 2817 assert(indexingMaps.size() == 2 && "We should have one map for each input"); 2818 assert(indexingMaps[0].isIdentity() && "input map should be identity"); 2819 // Add the affine map for the output argument. 2820 indexingMaps.push_back(indexingMaps[0]); 2821 auto genericOp = builder.create<linalg::GenericOp>( 2822 loc, input.getType(), ValueRange{input, max}, output, indexingMaps, 2823 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { 2824 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]); 2825 Value result = b.create<math::ExpOp>(loc, diff); 2826 b.create<linalg::YieldOp>(loc, result); 2827 }); 2828 return genericOp.getResult(0); 2829 } 2830 2831 /// Produce a linalg generic that computes the final step of the softmax 2832 /// decomposition. 
2833 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) { 2834 /// yield n / d 2835 /// } 2836 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, 2837 Value denominator, Value output, int64_t dim) { 2838 auto inputType = cast<ShapedType>(numerator.getType()); 2839 ArrayRef<int64_t> inputShape = inputType.getShape(); 2840 int64_t inputRank = inputShape.size(); 2841 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( 2842 builder, inputRank, dim, /*allParallel=*/true); 2843 assert(indexingMaps.size() == 2 && 2844 "We should have one map for each input (2)"); 2845 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity"); 2846 // Add the affine map for the output tensor. 2847 indexingMaps.push_back(indexingMaps[0]); 2848 auto genericOp = builder.create<linalg::GenericOp>( 2849 loc, numerator.getType(), ValueRange{numerator, denominator}, output, 2850 indexingMaps, iteratorTypes, 2851 [&](OpBuilder &b, Location loc, ValueRange args) { 2852 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]); 2853 b.create<linalg::YieldOp>(loc, result); 2854 }); 2855 return genericOp.getResult(0); 2856 } 2857 // @} End helper functions for softmax decomposition. 2858 2859 /// Given an N-dimensional tensor x, this method converts 2860 /// softmax(x) to the following sequence of operations: 2861 /// 2862 /// 1. Compute the max of x along dimension d. This results 2863 /// in a N-1 dimensional tensor m. 2864 /// m = max(x, dim = d) 2865 /// 2866 /// 2. Subtract a broadcasted m from x and exponentiate. This results in 2867 /// a N dimensional tensor z. 2868 /// z = exp(x - m) 2869 /// 2870 /// 3. Compute the sum of z along dimension d. This results in 2871 /// a N-1 dimensional tensor l. 2872 /// l = sum(z, dim = d) 2873 /// 2874 /// 4. Divide z and l. This gives the N-dimensional softmax. 2875 /// softmax = z / l 2876 /// 2877 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { 2878 OpBuilder::InsertionGuard guard(b); 2879 b.setInsertionPoint(*this); 2880 Location loc = getLoc(); 2881 Value input = getInput(); 2882 ShapedType inputType = getInputOperandType(); 2883 Type elementType = inputType.getElementType(); 2884 int64_t reductionDim = getDimension(); 2885 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input); 2886 Value output = getOutput(); 2887 dims.erase(dims.begin() + reductionDim); 2888 // Step 1: Compute max along dim. 2889 Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType); 2890 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf, 2891 elementType, b, loc, 2892 /*useOnlyFiniteValue=*/true); 2893 Value neutralForMaxFInit = 2894 b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) 2895 .result(); 2896 Value max = 2897 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim); 2898 2899 // Step 2: Subtract max from input and exponentiate. 2900 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim); 2901 2902 // Step 3: Compute sum along dim. 2903 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, 2904 b, loc, /*useOnlyFiniteValue=*/true); 2905 Value zeroInit = 2906 b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result(); 2907 Value denominator = 2908 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim); 2909 2910 // Step 4: Compute softmax. 
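  // buildDivOp divides the numerator by the broadcast denominator
  // elementwise, reusing this op's original output as the init.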
2911   Value result =
2912       buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2913   return SmallVector<Value>{result};
2914 }
2915
2916 //===----------------------------------------------------------------------===//
2917 // WinogradFilterTransformOp
2918 //===----------------------------------------------------------------------===//
2919
2920 LogicalResult WinogradFilterTransformOp::verify() {
2921   auto filterType = cast<ShapedType>(getFilter().getType());
2922   ArrayRef<int64_t> filterShape = filterType.getShape();
2923   int64_t filterH = filterShape[getFilterHDim()];
2924   int64_t filterW = filterShape[getFilterWDim()];
2925   int64_t r = getR();
2926   int64_t m = getM();
2927
2928   if (filterH != r && filterH != 1)
2929     return emitOpError("expected filter height to be either r or 1");
2930   if (filterW != r && filterW != 1)
2931     return emitOpError("expected filter width to be either r or 1");
2932   if (filterH == 1 && filterW == 1)
2933     return emitOpError("expected either filter height or filter width to equal r");
2934
2935   SmallVector<int64_t> expectedOutputShape;
2936   expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2937   expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2938   expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2939   expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2940
2941   auto outputType = cast<ShapedType>(getOutput().getType());
2942   ArrayRef<int64_t> outputShape = outputType.getShape();
2943   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2944     return emitOpError("the output shape is not expected");
2945   }
2946   return success();
2947 }
2948
2949 SmallVector<Range>
2950 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2951   Location loc = getLoc();
2952   IntegerAttr zeroAttr = builder.getIndexAttr(0);
2953   IntegerAttr oneAttr = builder.getIndexAttr(1);
2954   Value filter = getFilter();
2955   int64_t filterRank = getFilterOperandRank();
2956   SmallVector<Range> loopBounds(filterRank);
2957   for (unsigned dim = 0; dim < filterRank; ++dim) {
2958     loopBounds[dim].offset = zeroAttr;
2959     loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2960     loopBounds[dim].stride = oneAttr;
2961   }
2962   return loopBounds;
2963 }
2964
2965 SmallVector<utils::IteratorType>
2966 WinogradFilterTransformOp::getLoopIteratorTypes() {
2967   int64_t filterRank = getFilterOperandRank();
2968   SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2969                                                  utils::IteratorType::parallel);
2970   return iteratorTypes;
2971 }
2972
2973 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2974     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2975     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2976     SmallVector<OpFoldResult> &resultSizes) {
2977   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2978   ShapedType filterType = getFilterOperandType();
2979   ArrayRef<int64_t> filterShape = filterType.getShape();
2980   int64_t filterH = filterShape[getFilterHDim()];
2981   int64_t filterW = filterShape[getFilterWDim()];
2982   int64_t m = getM();
2983   int64_t r = getR();
2984   int64_t alpha = m + r - 1;
2985   int64_t alphaH = filterH != 1 ? alpha : 1;
2986   int64_t alphaW = filterW != 1 ?
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}

/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t tileSize = m + r - 1;

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not as expected");
  }
  return success();
}

SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, C
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];

  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;

  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

  return success();
}

/// Implement tiling for winograd_input_transform.
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one
/// tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  int64_t m = getM();
  int64_t r = getR();

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
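  // For example (hypothetical tile, m = 4, r = 3): a tile offset of 2 maps to
  // input offset 2 * 4 = 8, and a tile size of 5 maps to an input extent of
  // 5 * 4 + (3 - 1) = 22, covering the halo needed by the transform.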
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  int64_t m = getM();
  int64_t r = getR();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height to equal the input tile size");
    expectedOutputShape[getOutputHDim()] =
        (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width to equal the input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
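    // For example (hypothetical sizes): with m = 4 and valueTileW = 8, the
    // expected output width is 4 * 8 = 32 when the right transform is applied.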
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not as expected");
  }
  return success();
}

SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, F
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t m = getM();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  return success();
}
/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
/// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
/// specify the tile sizes of tileH, tileW, N, and F.
/// `offsets` are the values for the offsets of tileH, tileW, N, F for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, F for one
/// tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// Returns true if the result AffineExprs of \p explicitMap are a subset of
/// the result AffineExprs of \p defaultMap, i.e. the explicit map uses no dim
/// expression that the default map does not.
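/// For example (illustrative only): with the default matmul LHS map (d0, d2),
/// an explicit map (d2, d0) is accepted (a transpose reuses the same dim
/// expressions) and so is the broadcast map (d2), while (d0, d1) is rejected
/// because d1 does not appear in the default map's results.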
static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}

/// Returns true if the \p explicitMap is broadcasted with respect to the
/// \p defaultMap.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}

/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map of \p matmulOp for the operand specified by \p opIndex.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
    return success();
  }
  return success();
}

namespace mlir {
namespace linalg {

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//

/// Returns a list of AffineMaps with the typical matmul indexing
/// characteristics.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Check if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
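/// For reference, the default maps produced by getDefaultIndexingMaps are:
///   affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS
///   affine_map<(d0, d1, d2) -> (d2, d1)>  // RHS
///   affine_map<(d0, d1, d2) -> (d0, d1)>  // output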
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // The map is invalid if it does not access the common (reduction) dimension
  // of the matmul.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
    return {nullptr}; // Success in case indexing_maps was not provided.
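  // Note: a null ArrayAttr wrapped in success() signals that the optional
  // 'indexing_maps' attribute was simply absent, so callers such as
  // MatmulOp::parse can fall back to the default maps; failure() below is
  // reserved for a malformed attribute.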

  ArrayAttr arrayAttr;
  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
    return failure();

  if (llvm::any_of(arrayAttr,
                   [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";

  return arrayAttr;
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}

void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}

/// Verify the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
  // domains are all the same, and each implements a projected permutation.
  // Each iteration space dim must occur for at least one operand and either
  // takes part in a contraction/reduction or else has parallel iteration type.
  // We have that a dim is a contraction/reduction dim if and only if the dim
  // does not occur for the output operand.
  // We use this fact for fast inference:
  // NB: In case we allow dims to occur solely for one input, the above still
  //     holds: per the einsum semantics, these are reduction dims as well.
  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  SmallVector<utils::IteratorType> iteratorTypes;
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
                                              : utils::IteratorType::reduction);

  return iteratorTypes;
}

unsigned ContractOp::getNumRegionArgs() { return 3; }

/// Implement block region builder, which is called by
/// 'fillStructuredOpRegion'.
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                               ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }

  // TODO: Support fields with operators besides mult & add.
  Type outType = block.getArgument(2).getType();
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType =
      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType);
  helper.yieldOutputs({result});
}

ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);

  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
                                regionBuilder);
}

void ContractOp::print(OpAsmPrinter &p) {
  p << " indexing_maps = [";
  llvm::interleaveComma(getIndexingMaps(), p,
                        [&](Attribute attr) { p.printAttribute(attr); });
  p << "]";
  printNamedStructuredOp(
      p, getOperation(), getInputs(), getOutputs(),
      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
}

LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
  // access an input operand (so occurrence count can be at most 2) and
  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // A helper so that for each operand's affine_map and type we check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims;
       dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // We have a contracting dim if and only if ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: We disallow a dim which occurs for only one input operand and not
    //     for the output. In terms of einsum semantics such dims have a
    //     sensible meaning - namely an additional reduction for each such dim.
    //     By contrast, the ContractionOpInterface does not know about this
    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
    //     while vector.contract's verifier accepts dims of this kind, many of
    //     its lowerings give up on encountering these dims.
    // TODO: Remove the following once we have comprehensive support for
    //       input-only reduction dims, at both the linalg- and vector-dialect
    //       levels.
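    // For example (hypothetical maps), with
    //   affine_map<(d0, d1, d2) -> (d0, d1)>,  // LHS
    //   affine_map<(d0, d1, d2) -> (d1, d2)>,  // RHS
    //   affine_map<(d0, d1, d2) -> (d0)>       // output
    // d2 occurs for only one input and not for the output, so the check below
    // rejects it even though einsum semantics could reduce over it.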
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}

LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability ContractOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir