xref: /llvm-project/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp (revision 063e0bd52ac7a25b5d7073a9904f8be6a38220b3)
1 
2 //===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
11 
12 #include "mlir/Dialect/DLTI/DLTI.h"
13 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
14 #include "mlir/Dialect/Transform/Utils/Utils.h"
15 #include "mlir/Interfaces/DataLayoutInterfaces.h"
16 
17 using namespace mlir;
18 using namespace mlir::transform;
19 
20 #define DEBUG_TYPE "dlti-transforms"
21 
22 //===----------------------------------------------------------------------===//
23 // QueryOp
24 //===----------------------------------------------------------------------===//
25 
26 void transform::QueryOp::getEffects(
27     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
28   onlyReadsHandle(getTargetMutable(), effects);
29   producesHandle(getOperation()->getOpResults(), effects);
30   onlyReadsPayload(effects);
31 }
32 
33 DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
34     transform::TransformRewriter &rewriter, Operation *target,
35     transform::ApplyToEachResultList &results, TransformState &state) {
36   SmallVector<DataLayoutEntryKey> keys;
37   for (Attribute key : getKeys()) {
38     if (auto strKey = dyn_cast<StringAttr>(key))
39       keys.push_back(strKey);
40     else if (auto typeKey = dyn_cast<TypeAttr>(key))
41       keys.push_back(typeKey.getValue());
42     else
43       return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
44                                  "only StringAttr and TypeAttr are allowed");
45   }
46 
47   FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);
48 
49   if (failed(result))
50     return emitSilenceableFailure(getLoc(),
51                                   "'transform.dlti.query' op failed to apply");
52 
53   results.push_back(*result);
54   return DiagnosedSilenceableFailure::success();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // Transform op registration
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 class DLTITransformDialectExtension
63     : public transform::TransformDialectExtension<
64           DLTITransformDialectExtension> {
65 public:
66   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTITransformDialectExtension)
67 
68   using Base::Base;
69 
70   void init() {
71     registerTransformOps<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
74         >();
75   }
76 };
77 } // namespace
78 
79 #define GET_OP_CLASSES
80 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
81 
82 void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
83   registry.addExtensions<DLTITransformDialectExtension>();
84 }
85