xref: /llvm-project/mlir/test/lib/Dialect/Test/TestDialect.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Bytecode/BytecodeImplementation.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/ExtensibleDialect.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/ODSSupport.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Verifier.h"
27 #include "mlir/Interfaces/CallInterfaces.h"
28 #include "mlir/Interfaces/FunctionImplementation.h"
29 #include "mlir/Interfaces/InferIntRangeInterface.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Transforms/FoldUtils.h"
32 #include "mlir/Transforms/InliningUtils.h"
33 #include "llvm/ADT/STLFunctionalExtras.h"
34 #include "llvm/ADT/SmallString.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/Support/Base64.h"
38 #include "llvm/Support/Casting.h"
39 
40 #include "mlir/Dialect/Arith/IR/Arith.h"
41 #include "mlir/Dialect/DLTI/DLTI.h"
42 #include "mlir/Interfaces/FoldInterfaces.h"
43 #include "mlir/Reducer/ReductionPatternInterface.h"
44 #include "mlir/Transforms/InliningUtils.h"
45 #include <cstdint>
46 #include <numeric>
47 #include <optional>
48 
49 // Include this before the using namespace lines below to test that we don't
50 // have namespace dependencies.
51 #include "TestOpsDialect.cpp.inc"
52 
53 using namespace mlir;
54 using namespace test;
55 
56 //===----------------------------------------------------------------------===//
57 // PropertiesWithCustomPrint
58 //===----------------------------------------------------------------------===//
59 
60 LogicalResult
setPropertiesFromAttribute(PropertiesWithCustomPrint & prop,Attribute attr,function_ref<InFlightDiagnostic ()> emitError)61 test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
62                                  Attribute attr,
63                                  function_ref<InFlightDiagnostic()> emitError) {
64   DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
65   if (!dict) {
66     emitError() << "expected DictionaryAttr to set TestProperties";
67     return failure();
68   }
69   auto label = dict.getAs<mlir::StringAttr>("label");
70   if (!label) {
71     emitError() << "expected StringAttr for key `label`";
72     return failure();
73   }
74   auto valueAttr = dict.getAs<IntegerAttr>("value");
75   if (!valueAttr) {
76     emitError() << "expected IntegerAttr for key `value`";
77     return failure();
78   }
79 
80   prop.label = std::make_shared<std::string>(label.getValue());
81   prop.value = valueAttr.getValue().getSExtValue();
82   return success();
83 }
84 
85 DictionaryAttr
getPropertiesAsAttribute(MLIRContext * ctx,const PropertiesWithCustomPrint & prop)86 test::getPropertiesAsAttribute(MLIRContext *ctx,
87                                const PropertiesWithCustomPrint &prop) {
88   SmallVector<NamedAttribute> attrs;
89   Builder b{ctx};
90   attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
91   attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
92   return b.getDictionaryAttr(attrs);
93 }
94 
computeHash(const PropertiesWithCustomPrint & prop)95 llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
96   return llvm::hash_combine(prop.value, StringRef(*prop.label));
97 }
98 
customPrintProperties(OpAsmPrinter & p,const PropertiesWithCustomPrint & prop)99 void test::customPrintProperties(OpAsmPrinter &p,
100                                  const PropertiesWithCustomPrint &prop) {
101   p.printKeywordOrString(*prop.label);
102   p << " is " << prop.value;
103 }
104 
customParseProperties(OpAsmParser & parser,PropertiesWithCustomPrint & prop)105 ParseResult test::customParseProperties(OpAsmParser &parser,
106                                         PropertiesWithCustomPrint &prop) {
107   std::string label;
108   if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
109       parser.parseInteger(prop.value))
110     return failure();
111   prop.label = std::make_shared<std::string>(std::move(label));
112   return success();
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // MyPropStruct
117 //===----------------------------------------------------------------------===//
118 
asAttribute(MLIRContext * ctx) const119 Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
120   return StringAttr::get(ctx, content);
121 }
122 
123 LogicalResult
setFromAttr(MyPropStruct & prop,Attribute attr,function_ref<InFlightDiagnostic ()> emitError)124 MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
125                           function_ref<InFlightDiagnostic()> emitError) {
126   StringAttr strAttr = dyn_cast<StringAttr>(attr);
127   if (!strAttr) {
128     emitError() << "Expect StringAttr but got " << attr;
129     return failure();
130   }
131   prop.content = strAttr.getValue();
132   return success();
133 }
134 
hash() const135 llvm::hash_code MyPropStruct::hash() const {
136   return hash_value(StringRef(content));
137 }
138 
readFromMlirBytecode(DialectBytecodeReader & reader,MyPropStruct & prop)139 LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
140                                          MyPropStruct &prop) {
141   StringRef str;
142   if (failed(reader.readString(str)))
143     return failure();
144   prop.content = str.str();
145   return success();
146 }
147 
writeToMlirBytecode(DialectBytecodeWriter & writer,MyPropStruct & prop)148 void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
149                                MyPropStruct &prop) {
150   writer.writeOwnedString(prop.content);
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // VersionedProperties
155 //===----------------------------------------------------------------------===//
156 
157 LogicalResult
setPropertiesFromAttribute(VersionedProperties & prop,Attribute attr,function_ref<InFlightDiagnostic ()> emitError)158 test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
159                                  function_ref<InFlightDiagnostic()> emitError) {
160   DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
161   if (!dict) {
162     emitError() << "expected DictionaryAttr to set VersionedProperties";
163     return failure();
164   }
165   auto value1Attr = dict.getAs<IntegerAttr>("value1");
166   if (!value1Attr) {
167     emitError() << "expected IntegerAttr for key `value1`";
168     return failure();
169   }
170   auto value2Attr = dict.getAs<IntegerAttr>("value2");
171   if (!value2Attr) {
172     emitError() << "expected IntegerAttr for key `value2`";
173     return failure();
174   }
175 
176   prop.value1 = value1Attr.getValue().getSExtValue();
177   prop.value2 = value2Attr.getValue().getSExtValue();
178   return success();
179 }
180 
getPropertiesAsAttribute(MLIRContext * ctx,const VersionedProperties & prop)181 DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
182                                               const VersionedProperties &prop) {
183   SmallVector<NamedAttribute> attrs;
184   Builder b{ctx};
185   attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
186   attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
187   return b.getDictionaryAttr(attrs);
188 }
189 
computeHash(const VersionedProperties & prop)190 llvm::hash_code test::computeHash(const VersionedProperties &prop) {
191   return llvm::hash_combine(prop.value1, prop.value2);
192 }
193 
customPrintProperties(OpAsmPrinter & p,const VersionedProperties & prop)194 void test::customPrintProperties(OpAsmPrinter &p,
195                                  const VersionedProperties &prop) {
196   p << prop.value1 << " | " << prop.value2;
197 }
198 
customParseProperties(OpAsmParser & parser,VersionedProperties & prop)199 ParseResult test::customParseProperties(OpAsmParser &parser,
200                                         VersionedProperties &prop) {
201   if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
202       parser.parseInteger(prop.value2))
203     return failure();
204   return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // Bytecode Support
209 //===----------------------------------------------------------------------===//
210 
readFromMlirBytecode(DialectBytecodeReader & reader,MutableArrayRef<int64_t> prop)211 LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
212                                          MutableArrayRef<int64_t> prop) {
213   uint64_t size;
214   if (failed(reader.readVarInt(size)))
215     return failure();
216   if (size != prop.size())
217     return reader.emitError("array size mismach when reading properties: ")
218            << size << " vs expected " << prop.size();
219   for (auto &elt : prop) {
220     uint64_t value;
221     if (failed(reader.readVarInt(value)))
222       return failure();
223     elt = value;
224   }
225   return success();
226 }
227 
writeToMlirBytecode(DialectBytecodeWriter & writer,ArrayRef<int64_t> prop)228 void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
229                                ArrayRef<int64_t> prop) {
230   writer.writeVarInt(prop.size());
231   for (auto elt : prop)
232     writer.writeVarInt(elt);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // Dynamic operations
237 //===----------------------------------------------------------------------===//
238 
getDynamicGenericOp(TestDialect * dialect)239 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
240   return DynamicOpDefinition::get(
241       "dynamic_generic", dialect, [](Operation *op) { return success(); },
242       [](Operation *op) { return success(); });
243 }
244 
245 std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect * dialect)246 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
247   return DynamicOpDefinition::get(
248       "dynamic_one_operand_two_results", dialect,
249       [](Operation *op) {
250         if (op->getNumOperands() != 1) {
251           op->emitOpError()
252               << "expected 1 operand, but had " << op->getNumOperands();
253           return failure();
254         }
255         if (op->getNumResults() != 2) {
256           op->emitOpError()
257               << "expected 2 results, but had " << op->getNumResults();
258           return failure();
259         }
260         return success();
261       },
262       [](Operation *op) { return success(); });
263 }
264 
265 std::unique_ptr<DynamicOpDefinition>
getDynamicCustomParserPrinterOp(TestDialect * dialect)266 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
267   auto verifier = [](Operation *op) {
268     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
269       return success();
270     op->emitError() << "operation should have no operands and no results";
271     return failure();
272   };
273   auto regionVerifier = [](Operation *op) { return success(); };
274 
275   auto parser = [](OpAsmParser &parser, OperationState &state) {
276     return parser.parseKeyword("custom_keyword");
277   };
278 
279   auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
280     printer << op->getName() << " custom_keyword";
281   };
282 
283   return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
284                                   verifier, regionVerifier, parser, printer);
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // TestDialect
289 //===----------------------------------------------------------------------===//
290 
registerTestDialect(DialectRegistry & registry)291 void test::registerTestDialect(DialectRegistry &registry) {
292   registry.insert<TestDialect>();
293 }
294 
testSideEffectOpGetEffect(Operation * op,SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> & effects)295 void test::testSideEffectOpGetEffect(
296     Operation *op,
297     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
298         &effects) {
299   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
300   if (!effectsAttr)
301     return;
302 
303   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
304 }
305 
306 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
307 struct TestOpEffectInterfaceFallback
308     : public TestEffectOpInterface::FallbackModel<
309           TestOpEffectInterfaceFallback> {
classofTestOpEffectInterfaceFallback310   static bool classof(Operation *op) {
311     bool isSupportedOp =
312         op->getName().getStringRef() == "test.unregistered_side_effect_op";
313     assert(isSupportedOp && "Unexpected dispatch");
314     return isSupportedOp;
315   }
316 
317   void
getEffectsTestOpEffectInterfaceFallback318   getEffects(Operation *op,
319              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
320                  &effects) const {
321     testSideEffectOpGetEffect(op, effects);
322   }
323 };
324 
initialize()325 void TestDialect::initialize() {
326   registerAttributes();
327   registerTypes();
328   registerOpsSyntax();
329   addOperations<ManualCppOpWithFold>();
330   registerTestDialectOperations(this);
331   registerDynamicOp(getDynamicGenericOp(this));
332   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
333   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
334   registerInterfaces();
335   allowUnknownOperations();
336 
337   // Instantiate our fallback op interface that we'll use on specific
338   // unregistered op.
339   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
340 }
341 
~TestDialect()342 TestDialect::~TestDialect() {
343   delete static_cast<TestOpEffectInterfaceFallback *>(
344       fallbackEffectOpInterfaces);
345 }
346 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)347 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
348                                             Type type, Location loc) {
349   return builder.create<TestOpConstant>(loc, type, value);
350 }
351 
getRegisteredInterfaceForOp(TypeID typeID,OperationName opName)352 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
353                                                OperationName opName) {
354   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
355       typeID == TypeID::get<TestEffectOpInterface>())
356     return fallbackEffectOpInterfaces;
357   return nullptr;
358 }
359 
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)360 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
361                                                     NamedAttribute namedAttr) {
362   if (namedAttr.getName() == "test.invalid_attr")
363     return op->emitError() << "invalid to use 'test.invalid_attr'";
364   return success();
365 }
366 
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)367 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
368                                                     unsigned regionIndex,
369                                                     unsigned argIndex,
370                                                     NamedAttribute namedAttr) {
371   if (namedAttr.getName() == "test.invalid_attr")
372     return op->emitError() << "invalid to use 'test.invalid_attr'";
373   return success();
374 }
375 
376 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)377 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
378                                          unsigned resultIndex,
379                                          NamedAttribute namedAttr) {
380   if (namedAttr.getName() == "test.invalid_attr")
381     return op->emitError() << "invalid to use 'test.invalid_attr'";
382   return success();
383 }
384 
385 std::optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const386 TestDialect::getParseOperationHook(StringRef opName) const {
387   if (opName == "test.dialect_custom_printer") {
388     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
389       return parser.parseKeyword("custom_format");
390     }};
391   }
392   if (opName == "test.dialect_custom_format_fallback") {
393     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
394       return parser.parseKeyword("custom_format_fallback");
395     }};
396   }
397   if (opName == "test.dialect_custom_printer.with.dot") {
398     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
399       return ParseResult::success();
400     }};
401   }
402   return std::nullopt;
403 }
404 
405 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const406 TestDialect::getOperationPrinter(Operation *op) const {
407   StringRef opName = op->getName().getStringRef();
408   if (opName == "test.dialect_custom_printer") {
409     return [](Operation *op, OpAsmPrinter &printer) {
410       printer.getStream() << " custom_format";
411     };
412   }
413   if (opName == "test.dialect_custom_format_fallback") {
414     return [](Operation *op, OpAsmPrinter &printer) {
415       printer.getStream() << " custom_format_fallback";
416     };
417   }
418   return {};
419 }
420 
421 static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,PatternRewriter & rewriter)422 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
423                                PatternRewriter &rewriter) {
424   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
425       op, rewriter.getI32IntegerAttr(42));
426   return success();
427 }
428 
getCanonicalizationPatterns(RewritePatternSet & results) const429 void TestDialect::getCanonicalizationPatterns(
430     RewritePatternSet &results) const {
431   results.add(&dialectCanonicalizationPattern);
432 }
433