14cb2ef4fSOleksandr "Alex" Zinenko //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// 24cb2ef4fSOleksandr "Alex" Zinenko // 34cb2ef4fSOleksandr "Alex" Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44cb2ef4fSOleksandr "Alex" Zinenko // See https://llvm.org/LICENSE.txt for license information. 54cb2ef4fSOleksandr "Alex" Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64cb2ef4fSOleksandr "Alex" Zinenko // 74cb2ef4fSOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 84cb2ef4fSOleksandr "Alex" Zinenko // 94cb2ef4fSOleksandr "Alex" Zinenko // This file defines Transform dialect extension operations used in the 104cb2ef4fSOleksandr "Alex" Zinenko // Chapter 4 of the Transform dialect tutorial. 114cb2ef4fSOleksandr "Alex" Zinenko // 124cb2ef4fSOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 134cb2ef4fSOleksandr "Alex" Zinenko 144cb2ef4fSOleksandr "Alex" Zinenko #include "MyExtension.h" 154cb2ef4fSOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 164cb2ef4fSOleksandr "Alex" Zinenko #include "llvm/Support/Debug.h" 174cb2ef4fSOleksandr "Alex" Zinenko 184cb2ef4fSOleksandr "Alex" Zinenko #define DEBUG_TYPE_MATCHER "transform-matcher" 194cb2ef4fSOleksandr "Alex" Zinenko #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") 204cb2ef4fSOleksandr "Alex" Zinenko #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) 214cb2ef4fSOleksandr "Alex" Zinenko 224cb2ef4fSOleksandr "Alex" Zinenko #define GET_OP_CLASSES 234cb2ef4fSOleksandr "Alex" Zinenko #include "MyExtension.cpp.inc" 244cb2ef4fSOleksandr "Alex" Zinenko 254cb2ef4fSOleksandr "Alex" Zinenko //===---------------------------------------------------------------------===// 264cb2ef4fSOleksandr "Alex" Zinenko // MyExtension 274cb2ef4fSOleksandr "Alex" Zinenko //===---------------------------------------------------------------------===// 284cb2ef4fSOleksandr "Alex" Zinenko 294cb2ef4fSOleksandr "Alex" Zinenko // Define a new transform dialect extension. This uses the CRTP idiom to 304cb2ef4fSOleksandr "Alex" Zinenko // identify extensions. 314cb2ef4fSOleksandr "Alex" Zinenko class MyExtension 324cb2ef4fSOleksandr "Alex" Zinenko : public ::mlir::transform::TransformDialectExtension<MyExtension> { 334cb2ef4fSOleksandr "Alex" Zinenko public: 34*84cc1865SNikhil Kalra // The TypeID of this extension. 35*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) 36*84cc1865SNikhil Kalra 374cb2ef4fSOleksandr "Alex" Zinenko // The extension must derive the base constructor. 384cb2ef4fSOleksandr "Alex" Zinenko using Base::Base; 394cb2ef4fSOleksandr "Alex" Zinenko 404cb2ef4fSOleksandr "Alex" Zinenko // This function initializes the extension, similarly to `initialize` in 414cb2ef4fSOleksandr "Alex" Zinenko // dialect definitions. List individual operations and dependent dialects 424cb2ef4fSOleksandr "Alex" Zinenko // here. 434cb2ef4fSOleksandr "Alex" Zinenko void init(); 444cb2ef4fSOleksandr "Alex" Zinenko }; 454cb2ef4fSOleksandr "Alex" Zinenko 464cb2ef4fSOleksandr "Alex" Zinenko void MyExtension::init() { 474cb2ef4fSOleksandr "Alex" Zinenko // Register the additional match operations with the dialect similarly to 484cb2ef4fSOleksandr "Alex" Zinenko // other transform operations. List all operations generated from ODS. This 494cb2ef4fSOleksandr "Alex" Zinenko // call will perform additional checks that the operations implement the 504cb2ef4fSOleksandr "Alex" Zinenko // transform and memory effect interfaces required by the dialect interpreter 514cb2ef4fSOleksandr "Alex" Zinenko // and assert if they do not. 524cb2ef4fSOleksandr "Alex" Zinenko registerTransformOps< 534cb2ef4fSOleksandr "Alex" Zinenko #define GET_OP_LIST 544cb2ef4fSOleksandr "Alex" Zinenko #include "MyExtension.cpp.inc" 554cb2ef4fSOleksandr "Alex" Zinenko >(); 564cb2ef4fSOleksandr "Alex" Zinenko } 574cb2ef4fSOleksandr "Alex" Zinenko 584cb2ef4fSOleksandr "Alex" Zinenko //===---------------------------------------------------------------------===// 594cb2ef4fSOleksandr "Alex" Zinenko // HasOperandSatisfyingOp 604cb2ef4fSOleksandr "Alex" Zinenko //===---------------------------------------------------------------------===// 614cb2ef4fSOleksandr "Alex" Zinenko 624cb2ef4fSOleksandr "Alex" Zinenko /// Returns `true` if both types implement one of the interfaces provided as 634cb2ef4fSOleksandr "Alex" Zinenko /// template parameters. 644cb2ef4fSOleksandr "Alex" Zinenko template <typename... Tys> 654cb2ef4fSOleksandr "Alex" Zinenko static bool implementSameInterface(mlir::Type t1, mlir::Type t2) { 664cb2ef4fSOleksandr "Alex" Zinenko return ((llvm::isa<Tys>(t1) && llvm::isa<Tys>(t2)) || ... || false); 674cb2ef4fSOleksandr "Alex" Zinenko } 684cb2ef4fSOleksandr "Alex" Zinenko 694cb2ef4fSOleksandr "Alex" Zinenko /// Returns `true` if both types implement one of the transform dialect 704cb2ef4fSOleksandr "Alex" Zinenko /// interfaces. 714cb2ef4fSOleksandr "Alex" Zinenko static bool implementSameTransformInterface(mlir::Type t1, mlir::Type t2) { 724cb2ef4fSOleksandr "Alex" Zinenko return implementSameInterface< 734cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::TransformHandleTypeInterface, 744cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::TransformParamTypeInterface, 754cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::TransformValueHandleTypeInterface>(t1, t2); 764cb2ef4fSOleksandr "Alex" Zinenko } 774cb2ef4fSOleksandr "Alex" Zinenko 784cb2ef4fSOleksandr "Alex" Zinenko // Matcher ops implement `apply` similarly to other transform ops. They are not 794cb2ef4fSOleksandr "Alex" Zinenko // expected to modify payload, but use the tri-state result to signal failure or 804cb2ef4fSOleksandr "Alex" Zinenko // success to match, as well as potential irrecoverable errors. 814cb2ef4fSOleksandr "Alex" Zinenko mlir::DiagnosedSilenceableFailure 824cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::HasOperandSatisfyingOp::apply( 834cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::TransformRewriter &rewriter, 844cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::TransformResults &results, 854cb2ef4fSOleksandr "Alex" Zinenko mlir::transform::TransformState &state) { 864cb2ef4fSOleksandr "Alex" Zinenko // For simplicity, only handle a single payload op. Actual implementations 874cb2ef4fSOleksandr "Alex" Zinenko // can use `SingleOpMatcher` trait to simplify implementation and document 884cb2ef4fSOleksandr "Alex" Zinenko // this expectation. 894cb2ef4fSOleksandr "Alex" Zinenko auto payloadOps = state.getPayloadOps(getOp()); 904cb2ef4fSOleksandr "Alex" Zinenko if (!llvm::hasSingleElement(payloadOps)) 914cb2ef4fSOleksandr "Alex" Zinenko return emitSilenceableError() << "expected single payload"; 924cb2ef4fSOleksandr "Alex" Zinenko 934cb2ef4fSOleksandr "Alex" Zinenko // Iterate over all operands of the payload op to see if they can be matched 944cb2ef4fSOleksandr "Alex" Zinenko // using the body of this op. 954cb2ef4fSOleksandr "Alex" Zinenko Operation *payload = *payloadOps.begin(); 964cb2ef4fSOleksandr "Alex" Zinenko for (OpOperand &operand : payload->getOpOperands()) { 974cb2ef4fSOleksandr "Alex" Zinenko // Create a scope for transform values defined in the body. This corresponds 984cb2ef4fSOleksandr "Alex" Zinenko // to the syntactic scope of the region attached to this op. Any values 994cb2ef4fSOleksandr "Alex" Zinenko // associated with payloads from now on will be automatically dissociated 1004cb2ef4fSOleksandr "Alex" Zinenko // when this object is destroyed, i.e. at the end of the iteration. 1014cb2ef4fSOleksandr "Alex" Zinenko // Associate the block argument handle with the operand. 1024cb2ef4fSOleksandr "Alex" Zinenko auto matchScope = state.make_region_scope(getBody()); 1034cb2ef4fSOleksandr "Alex" Zinenko if (failed(state.mapBlockArgument(getBody().getArgument(0), 1044cb2ef4fSOleksandr "Alex" Zinenko {operand.get()}))) { 1054cb2ef4fSOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 1064cb2ef4fSOleksandr "Alex" Zinenko } 1074cb2ef4fSOleksandr "Alex" Zinenko 1084cb2ef4fSOleksandr "Alex" Zinenko // Iterate over all nested matchers with the current mapping and see if they 1094cb2ef4fSOleksandr "Alex" Zinenko // succeed. 1104cb2ef4fSOleksandr "Alex" Zinenko bool matchSucceeded = true; 1114cb2ef4fSOleksandr "Alex" Zinenko for (Operation &matcher : getBody().front().without_terminator()) { 1124cb2ef4fSOleksandr "Alex" Zinenko // Matcher ops are applied similarly to any other transform op. 1134cb2ef4fSOleksandr "Alex" Zinenko DiagnosedSilenceableFailure diag = 1144cb2ef4fSOleksandr "Alex" Zinenko state.applyTransform(cast<TransformOpInterface>(matcher)); 1154cb2ef4fSOleksandr "Alex" Zinenko 1164cb2ef4fSOleksandr "Alex" Zinenko // Definite failures are immediately propagated as they are irrecoverable. 1174cb2ef4fSOleksandr "Alex" Zinenko if (diag.isDefiniteFailure()) 1184cb2ef4fSOleksandr "Alex" Zinenko return diag; 1194cb2ef4fSOleksandr "Alex" Zinenko 1204cb2ef4fSOleksandr "Alex" Zinenko // On success, keep checking the remaining conditions. 1214cb2ef4fSOleksandr "Alex" Zinenko if (diag.succeeded()) 1224cb2ef4fSOleksandr "Alex" Zinenko continue; 1234cb2ef4fSOleksandr "Alex" Zinenko 1244cb2ef4fSOleksandr "Alex" Zinenko // Report failure-to-match for debugging purposes and stop matching this 1254cb2ef4fSOleksandr "Alex" Zinenko // operand. 1264cb2ef4fSOleksandr "Alex" Zinenko assert(diag.isSilenceableFailure()); 1274cb2ef4fSOleksandr "Alex" Zinenko DEBUG_MATCHER(DBGS_MATCHER() 1284cb2ef4fSOleksandr "Alex" Zinenko << "failed to match operand #" << operand.getOperandNumber() 1294cb2ef4fSOleksandr "Alex" Zinenko << ": " << diag.getMessage()); 1304cb2ef4fSOleksandr "Alex" Zinenko (void)diag.silence(); 1314cb2ef4fSOleksandr "Alex" Zinenko matchSucceeded = false; 1324cb2ef4fSOleksandr "Alex" Zinenko break; 1334cb2ef4fSOleksandr "Alex" Zinenko } 1344cb2ef4fSOleksandr "Alex" Zinenko // If failed to match this operand, try other operands. 1354cb2ef4fSOleksandr "Alex" Zinenko if (!matchSucceeded) 1364cb2ef4fSOleksandr "Alex" Zinenko continue; 1374cb2ef4fSOleksandr "Alex" Zinenko 1384cb2ef4fSOleksandr "Alex" Zinenko // If we reached this point, the matching succeeded for the current operand. 1394cb2ef4fSOleksandr "Alex" Zinenko // Remap the values associated with terminator operands to be associated 1404cb2ef4fSOleksandr "Alex" Zinenko // with op results, and also map the parameter result to the operand's 1414cb2ef4fSOleksandr "Alex" Zinenko // position. Note that it is safe to do here despite the end of the scope 1424cb2ef4fSOleksandr "Alex" Zinenko // as `results` are integrated into `state` by the interpreter after `apply` 1434cb2ef4fSOleksandr "Alex" Zinenko // returns rather than immediately. 1444cb2ef4fSOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> yieldedMappings; 1454cb2ef4fSOleksandr "Alex" Zinenko transform::detail::prepareValueMappings( 1464cb2ef4fSOleksandr "Alex" Zinenko yieldedMappings, getBody().front().getTerminator()->getOperands(), 1474cb2ef4fSOleksandr "Alex" Zinenko state); 148a5757c5bSChristian Sigg results.setParams(cast<OpResult>(getPosition()), 1494cb2ef4fSOleksandr "Alex" Zinenko {rewriter.getI32IntegerAttr(operand.getOperandNumber())}); 1504cb2ef4fSOleksandr "Alex" Zinenko for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings)) 1514cb2ef4fSOleksandr "Alex" Zinenko results.setMappedValues(result, mapping); 1524cb2ef4fSOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 1534cb2ef4fSOleksandr "Alex" Zinenko } 1544cb2ef4fSOleksandr "Alex" Zinenko 1554cb2ef4fSOleksandr "Alex" Zinenko // If we reached this point, none of the operands succeeded the match. 1564cb2ef4fSOleksandr "Alex" Zinenko return emitSilenceableError() 1574cb2ef4fSOleksandr "Alex" Zinenko << "none of the operands satisfied the conditions"; 1584cb2ef4fSOleksandr "Alex" Zinenko } 1594cb2ef4fSOleksandr "Alex" Zinenko 1604cb2ef4fSOleksandr "Alex" Zinenko // By convention, operations implementing MatchOpInterface must not modify 1614cb2ef4fSOleksandr "Alex" Zinenko // payload IR and must therefore specify that they only read operand handles and 1624cb2ef4fSOleksandr "Alex" Zinenko // payload as their effects. 1634cb2ef4fSOleksandr "Alex" Zinenko void mlir::transform::HasOperandSatisfyingOp::getEffects( 1644cb2ef4fSOleksandr "Alex" Zinenko llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) { 1654cb2ef4fSOleksandr "Alex" Zinenko onlyReadsPayload(effects); 1662c1ae801Sdonald chen onlyReadsHandle(getOpMutable(), effects); 1672c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 1684cb2ef4fSOleksandr "Alex" Zinenko } 1694cb2ef4fSOleksandr "Alex" Zinenko 1704cb2ef4fSOleksandr "Alex" Zinenko // Verify well-formedness of the operation and emit diagnostics if it is 1714cb2ef4fSOleksandr "Alex" Zinenko // ill-formed. 172db791b27SRamkumar Ramachandra llvm::LogicalResult mlir::transform::HasOperandSatisfyingOp::verify() { 1734cb2ef4fSOleksandr "Alex" Zinenko mlir::Block &bodyBlock = getBody().front(); 1744cb2ef4fSOleksandr "Alex" Zinenko if (bodyBlock.getNumArguments() != 1 || 1754cb2ef4fSOleksandr "Alex" Zinenko !isa<TransformValueHandleTypeInterface>( 1764cb2ef4fSOleksandr "Alex" Zinenko bodyBlock.getArgument(0).getType())) { 1774cb2ef4fSOleksandr "Alex" Zinenko return emitOpError() 1784cb2ef4fSOleksandr "Alex" Zinenko << "expects the body to have one value handle argument"; 1794cb2ef4fSOleksandr "Alex" Zinenko } 1804cb2ef4fSOleksandr "Alex" Zinenko if (bodyBlock.getTerminator()->getNumOperands() != getNumResults() - 1) { 1814cb2ef4fSOleksandr "Alex" Zinenko return emitOpError() << "expects the body to yield " 1824cb2ef4fSOleksandr "Alex" Zinenko << (getNumResults() - 1) << " values, got " 1834cb2ef4fSOleksandr "Alex" Zinenko << bodyBlock.getTerminator()->getNumOperands(); 1844cb2ef4fSOleksandr "Alex" Zinenko } 1854cb2ef4fSOleksandr "Alex" Zinenko for (auto &&[i, operand, result] : 1864cb2ef4fSOleksandr "Alex" Zinenko llvm::enumerate(bodyBlock.getTerminator()->getOperands().getTypes(), 1874cb2ef4fSOleksandr "Alex" Zinenko getResults().getTypes())) { 1884cb2ef4fSOleksandr "Alex" Zinenko if (implementSameTransformInterface(operand, result)) 1894cb2ef4fSOleksandr "Alex" Zinenko continue; 1904cb2ef4fSOleksandr "Alex" Zinenko return emitOpError() << "expects terminator operand #" << i 1914cb2ef4fSOleksandr "Alex" Zinenko << " and result #" << (i + 1) 1924cb2ef4fSOleksandr "Alex" Zinenko << " to implement the same transform interface"; 1934cb2ef4fSOleksandr "Alex" Zinenko } 1944cb2ef4fSOleksandr "Alex" Zinenko 1954cb2ef4fSOleksandr "Alex" Zinenko for (Operation &op : bodyBlock.without_terminator()) { 1964cb2ef4fSOleksandr "Alex" Zinenko if (!isa<TransformOpInterface>(op) || !isa<MatchOpInterface>(op)) { 1974cb2ef4fSOleksandr "Alex" Zinenko InFlightDiagnostic diag = emitOpError() 1984cb2ef4fSOleksandr "Alex" Zinenko << "expects body to contain match ops"; 1994cb2ef4fSOleksandr "Alex" Zinenko diag.attachNote(op.getLoc()) << "non-match operation"; 2004cb2ef4fSOleksandr "Alex" Zinenko return diag; 2014cb2ef4fSOleksandr "Alex" Zinenko } 2024cb2ef4fSOleksandr "Alex" Zinenko } 2034cb2ef4fSOleksandr "Alex" Zinenko 2044cb2ef4fSOleksandr "Alex" Zinenko return success(); 2054cb2ef4fSOleksandr "Alex" Zinenko } 2064cb2ef4fSOleksandr "Alex" Zinenko 2074cb2ef4fSOleksandr "Alex" Zinenko void registerMyExtension(::mlir::DialectRegistry ®istry) { 2084cb2ef4fSOleksandr "Alex" Zinenko registry.addExtensions<MyExtension>(); 2094cb2ef4fSOleksandr "Alex" Zinenko } 210