1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an ODS (and C++) generator from a YAML form 10 // derived from the mathematical expression of linalg named ops. Typically a 11 // math oriented DSL will be used to export the essential representation to 12 // this form, and maintaining the SOT at the math level (versus recreating it 13 // in MLIR) is deemed to have systemic value. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/AsmParser/AsmParser.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/MLIRContext.h" 21 #include "mlir/Support/FileUtilities.h" 22 #include "mlir/Support/LLVM.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/Support/CommandLine.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/ToolOutputFile.h" 28 #include "llvm/Support/YAMLTraits.h" 29 #include <optional> 30 31 using namespace mlir; 32 33 using llvm::yaml::Input; 34 35 #define DEBUG_TYPE "linalg-ods-gen" 36 37 //===----------------------------------------------------------------------===// 38 // Mapping structs (correspond to data types in the YAML description). 39 // TODO: Since this is a schema/part of the contract, it should be moved to 40 // a real header. 
41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 45 struct LinalgYAMLContext { 46 MLIRContext *mlirContext; 47 }; 48 49 struct LinalgOpMetadata { 50 std::string name; 51 std::string cppClassName; 52 std::optional<std::string> doc; 53 SmallVector<std::string> implements; 54 SmallVector<std::string> defines; 55 }; 56 57 struct SerializedAffineMap { 58 AffineMapAttr affineMapAttr; 59 60 AffineMap affineMap() { return affineMapAttr.getValue(); } 61 }; 62 63 enum class LinalgOperandDefKind { 64 InputTensor, 65 Scalar, 66 OutputTensor, 67 IndexAttr, 68 UnaryFnAttr, 69 BinaryFnAttr, 70 TernaryFnAttr, 71 TypeFnAttr 72 }; 73 74 struct LinalgOperandDef { 75 std::string name; 76 LinalgOperandDefKind kind; 77 std::optional<std::string> typeVar; 78 std::optional<SerializedAffineMap> shapeMap; 79 std::optional<SerializedAffineMap> indexAttrMap; 80 std::optional<SmallVector<int64_t>> defaultIndices; 81 std::optional<std::string> defaultFn; 82 }; 83 84 enum class LinalgIteratorTypeDef { 85 parallel, 86 reduction, 87 }; 88 89 struct LinalgIndexingMapsConfig { 90 std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps; 91 }; 92 93 struct ScalarExpression; 94 95 enum class ScalarFnKind { Unary, Binary, Ternary, Type }; 96 97 struct ScalarFn { 98 ScalarFnKind kind; 99 std::optional<std::string> fnName; 100 std::optional<std::string> attrName; 101 std::optional<std::string> typeVar; 102 // NOTE: This must be of arity 1, but to break the self-referential cycle, 103 // we use a heap allocated vector. 
104 std::vector<ScalarExpression> operands; 105 }; 106 107 struct ScalarExpression { 108 std::optional<std::string> arg; 109 std::optional<std::string> constant; 110 std::optional<int64_t> index; 111 std::optional<ScalarFn> scalarFn; 112 }; 113 114 struct ScalarAssign { 115 std::string arg; 116 ScalarExpression value; 117 }; 118 119 struct LinalgStructuredOpConfig { 120 SmallVector<LinalgOperandDef> args; 121 LinalgIndexingMapsConfig indexingMaps; 122 SmallVector<LinalgIteratorTypeDef> iteratorTypes; 123 std::vector<ScalarAssign> assignments; 124 }; 125 126 struct LinalgOpConfig { 127 std::optional<LinalgOpMetadata> metadata; 128 std::optional<LinalgStructuredOpConfig> structuredOp; 129 }; 130 131 } // namespace 132 133 //===----------------------------------------------------------------------===// 134 // Mapping traits. 135 //===----------------------------------------------------------------------===// 136 137 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef) 138 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap) 139 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef) 140 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign) 141 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression) 142 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig) 143 144 namespace llvm { 145 namespace yaml { 146 147 /// Top-level type containing op metadata and one of a concrete op type. 148 /// Currently, the only defined op type is `structured_op` (maps to 149 /// `LinalgStructuredOpConfig`). 150 template <> 151 struct MappingTraits<LinalgOpConfig> { 152 static void mapping(IO &io, LinalgOpConfig &info) { 153 io.mapOptional("metadata", info.metadata); 154 io.mapOptional("structured_op", info.structuredOp); 155 } 156 }; 157 158 /// A structured op models (at most) a single contraction by modeling 159 /// - A list of named arguments (`LinalgOperandDef`), which can be inputs, 160 /// outputs, or index attributes. 161 /// - List of indexing maps (see `LinalgIndexingMaps`). 
162 /// - Iterator types (see `LinalgIteratorTypeDef`). 163 /// - List of scalar level assignment (see `ScalarAssign`). 164 template <> 165 struct MappingTraits<LinalgStructuredOpConfig> { 166 static void mapping(IO &io, LinalgStructuredOpConfig &info) { 167 io.mapRequired("args", info.args); 168 io.mapRequired("indexing_maps", info.indexingMaps); 169 io.mapRequired("iterator_types", info.iteratorTypes); 170 io.mapRequired("assignments", info.assignments); 171 } 172 }; 173 174 /// Maps a named tensor, scalar or attribute argument to an operation, 175 /// consisting of: 176 /// - `name`: Must be unique within the operation. 177 /// - `usage`: How the argument is used (input, output, attribute, etc). 178 /// - `type_var`: The symbolic type variable that binds to the element or self 179 /// type of the tensor or scalar argument, respectively. 180 /// - `shape_map`: An optional AffineMap from all op symbols to the shape of 181 /// the argument. Only tensor arguments have a `shape_map`. Each shape must 182 /// be normalized over the same list of symbols and have no dimension 183 /// inputs. 184 /// - `index_attr_map`: An optional AffineMap from all op symbols to the 185 /// index attribute symbols. During op creation these symbols are replaced 186 /// by the corresponding `name` index attribue values. Only index attribute 187 /// arguments have an `index_attr_map`. 188 /// - `default_indices`: An optional default initialization for index 189 /// attribute arguments. 190 /// - `default_fn`: An optional default initialization for function attribute 191 /// arguments. 
192 template <> 193 struct MappingTraits<LinalgOperandDef> { 194 static void mapping(IO &io, LinalgOperandDef &info) { 195 io.mapRequired("name", info.name); 196 io.mapRequired("kind", info.kind); 197 io.mapOptional("type_var", info.typeVar); 198 io.mapOptional("shape_map", info.shapeMap); 199 io.mapOptional("index_attr_map", info.indexAttrMap); 200 io.mapOptional("default_indices", info.defaultIndices); 201 io.mapOptional("default_fn", info.defaultFn); 202 } 203 }; 204 205 /// Usage enum for a named argument. 206 template <> 207 struct ScalarEnumerationTraits<LinalgOperandDefKind> { 208 static void enumeration(IO &io, LinalgOperandDefKind &value) { 209 io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor); 210 io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar); 211 io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor); 212 io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr); 213 io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr); 214 io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr); 215 io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr); 216 io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr); 217 } 218 }; 219 220 /// Iterator type enum. 221 template <> 222 struct ScalarEnumerationTraits<LinalgIteratorTypeDef> { 223 static void enumeration(IO &io, LinalgIteratorTypeDef &value) { 224 io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); 225 io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); 226 } 227 }; 228 229 /// Metadata about the op (name, C++ name, and documentation). 
230 template <> 231 struct MappingTraits<LinalgOpMetadata> { 232 static void mapping(IO &io, LinalgOpMetadata &info) { 233 io.mapRequired("name", info.name); 234 io.mapRequired("cpp_class_name", info.cppClassName); 235 io.mapOptional("doc", info.doc); 236 io.mapOptional("implements", info.implements); 237 io.mapOptional("defines", info.defines); 238 } 239 }; 240 241 /// How the ops indexing maps are produced. Must be one of: 242 /// - static_indexing_maps: A static list of AffineMaps, possibly with 243 /// some symbols that bind to attributes of the op. Each indexing map must 244 /// be normalized over the same list of dimensions, and its symbols must 245 /// match the symbols for argument shapes. 246 template <> 247 struct MappingTraits<LinalgIndexingMapsConfig> { 248 static void mapping(IO &io, LinalgIndexingMapsConfig &info) { 249 io.mapOptional("static_indexing_maps", info.staticIndexingMaps); 250 } 251 }; 252 253 /// Models an assignment to a named output. 254 /// - The `arg` name must match a named output. 255 /// - The `value` is a scalar expression for computing the value to 256 /// assign (see `ScalarExpression`). 257 template <> 258 struct MappingTraits<ScalarAssign> { 259 static void mapping(IO &io, ScalarAssign &info) { 260 io.mapRequired("arg", info.arg); 261 io.mapRequired("value", info.value); 262 } 263 }; 264 265 /// A scalar expression (RHS of an assignment). Must be one of: 266 /// - `scalar_arg`: An operation argument. 267 /// - `scalar_const`: A constant definition. 268 /// - `scalar_index`: An iteration index. 269 /// - `scalar_fn`: A named function (see `ScalarFn`). 270 template <> 271 struct MappingTraits<ScalarExpression> { 272 static void mapping(IO &io, ScalarExpression &info) { 273 io.mapOptional("scalar_arg", info.arg); 274 io.mapOptional("scalar_const", info.constant); 275 io.mapOptional("scalar_index", info.index); 276 io.mapOptional("scalar_fn", info.scalarFn); 277 } 278 }; 279 280 /// Scalar function kind enum. 
281 template <> 282 struct ScalarEnumerationTraits<ScalarFnKind> { 283 static void enumeration(IO &io, ScalarFnKind &value) { 284 io.enumCase(value, "unary", ScalarFnKind::Unary); 285 io.enumCase(value, "binary", ScalarFnKind::Binary); 286 io.enumCase(value, "ternary", ScalarFnKind::Ternary); 287 io.enumCase(value, "type", ScalarFnKind::Type); 288 } 289 }; 290 291 /// A scalar expression that evaluates a named function. 292 /// Functions are generally "math" level and type polymorphic. Builtin 293 /// functions include: 294 /// - `add(lhs, rhs)` 295 /// - `mul(lhs, rhs)` 296 template <> 297 struct MappingTraits<ScalarFn> { 298 static void mapping(IO &io, ScalarFn &info) { 299 io.mapRequired("kind", info.kind); 300 io.mapOptional("fn_name", info.fnName); 301 io.mapOptional("attr_name", info.attrName); 302 io.mapOptional("type_var", info.typeVar); 303 io.mapRequired("operands", info.operands); 304 } 305 }; 306 307 /// Helper mapping which accesses an AffineMapAttr as a serialized string of 308 /// the same. 
309 template <> 310 struct ScalarTraits<SerializedAffineMap> { 311 static void output(const SerializedAffineMap &value, void *rawYamlContext, 312 raw_ostream &out) { 313 assert(value.affineMapAttr); 314 value.affineMapAttr.print(out); 315 } 316 static StringRef input(StringRef scalar, void *rawYamlContext, 317 SerializedAffineMap &value) { 318 assert(rawYamlContext); 319 auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext); 320 if (auto attr = dyn_cast_or_null<AffineMapAttr>( 321 mlir::parseAttribute(scalar, yamlContext->mlirContext))) 322 value.affineMapAttr = attr; 323 else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr)) 324 return "could not parse as an affine map attribute"; 325 return StringRef(); 326 } 327 static QuotingType mustQuote(StringRef) { return QuotingType::None; } 328 }; 329 330 } // namespace yaml 331 } // namespace llvm 332 333 namespace { 334 335 //===----------------------------------------------------------------------===// 336 // Generation utilities 337 //===----------------------------------------------------------------------===// 338 339 class GenerationContext { 340 public: 341 GenerationContext(MLIRContext *context, raw_ostream *odsOut, 342 raw_ostream *defnOut) 343 : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut), 344 defnOut(defnOut) {} 345 346 MLIRContext *getContext() { return context; } 347 348 void setLoc(Location loc) { this->loc = loc; } 349 Location getLoc() { return loc; } 350 351 bool shouldGenerateOds() { return odsOut; } 352 bool shouldGenerateDefns() { return defnOut; } 353 354 raw_ostream &odss() { 355 assert(odsOut && "ODS stream not defined"); 356 return *odsOut; 357 } 358 359 raw_ostream &defns() { 360 assert(defnOut && "Definition stream not defined"); 361 return *defnOut; 362 } 363 364 private: 365 MLIRContext *context; 366 Location loc; 367 raw_ostream *odsOut; 368 raw_ostream *defnOut; 369 }; 370 371 } // namespace 372 373 static std::string 
generateCppExpression(SerializedAffineMap self, 374 StringRef contextName) { 375 std::string printedStr; 376 llvm::raw_string_ostream printedSs(printedStr); 377 self.affineMapAttr.print(printedSs); 378 379 static const char exprFormat[] = 380 R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT"; 381 return llvm::formatv(exprFormat, printedStr, contextName); 382 } 383 384 template <typename Container> 385 static std::string interleaveToString(Container &container, 386 StringRef separator) { 387 std::string result; 388 llvm::raw_string_ostream ss(result); 389 llvm::interleave(container, ss, separator); 390 return result; 391 } 392 393 static std::optional<int> 394 findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) { 395 for (const auto &it : llvm::enumerate(args)) { 396 if (it.value().name == name) 397 return it.index(); 398 } 399 return std::nullopt; 400 } 401 402 // Try to map the TypeVar to a predefined or an argument type. 403 static std::optional<std::string> 404 findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) { 405 // Handle all predefined types. 406 if (typeVar == "I32") 407 return std::string("helper.getIntegerType(32)"); 408 if (typeVar == "I64") 409 return std::string("helper.getIntegerType(64)"); 410 if (typeVar == "F32") 411 return std::string("helper.getFloat32Type()"); 412 if (typeVar == "F64") 413 return std::string("helper.getFloat64Type()"); 414 415 // Search all argument types. 
416 for (const auto &it : llvm::enumerate(args)) { 417 if (it.value().kind != LinalgOperandDefKind::InputTensor && 418 it.value().kind != LinalgOperandDefKind::Scalar && 419 it.value().kind != LinalgOperandDefKind::OutputTensor) 420 continue; 421 if (*it.value().typeVar == typeVar) 422 return llvm::formatv("block.getArgument({0}).getType()", it.index()) 423 .str(); 424 } 425 426 return std::nullopt; 427 } 428 429 static ScalarAssign *findAssignment(StringRef name, 430 std::vector<ScalarAssign> &assignments) { 431 for (auto &assign : assignments) { 432 if (assign.arg == name) 433 return &assign; 434 } 435 return nullptr; 436 } 437 438 // Return true if the operand is a function attribute. 439 static bool isFunctionAttribute(LinalgOperandDefKind kind) { 440 return kind == LinalgOperandDefKind::UnaryFnAttr || 441 kind == LinalgOperandDefKind::BinaryFnAttr || 442 kind == LinalgOperandDefKind::TernaryFnAttr || 443 kind == LinalgOperandDefKind::TypeFnAttr; 444 } 445 446 // Return true if the operand is an attribute. 447 static bool isAttribute(LinalgOperandDefKind kind) { 448 return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind); 449 } 450 451 // Get the enum name for the given operand kind. 452 std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) { 453 switch (kind) { 454 case LinalgOperandDefKind::UnaryFnAttr: 455 return std::string("UnaryFn"); 456 case LinalgOperandDefKind::BinaryFnAttr: 457 return std::string("BinaryFn"); 458 case LinalgOperandDefKind::TernaryFnAttr: 459 return std::string("TernaryFn"); 460 case LinalgOperandDefKind::TypeFnAttr: 461 return std::string("TypeFn"); 462 default: 463 break; 464 } 465 llvm_unreachable("unsupported function attribute kind"); 466 } 467 468 // Get the enum name for the given function kind. 
469 std::string convertFunctionKindToEnumName(ScalarFnKind kind) { 470 switch (kind) { 471 case ScalarFnKind::Unary: 472 return std::string("UnaryFn"); 473 case ScalarFnKind::Binary: 474 return std::string("BinaryFn"); 475 case ScalarFnKind::Ternary: 476 return std::string("TernaryFn"); 477 case ScalarFnKind::Type: 478 return std::string("TypeFn"); 479 } 480 llvm_unreachable("unsupported function kind"); 481 } 482 483 //===----------------------------------------------------------------------===// 484 // Templates 485 //===----------------------------------------------------------------------===// 486 487 // A single line banner format. Parameters: 488 // {0}: Single line comment 489 static const char bannerFormat[] = R"FMT( 490 //===----------------------------------------------------------------------===// 491 // {0} 492 //===----------------------------------------------------------------------===// 493 )FMT"; 494 495 //===----------------------------------------------------------------------===// 496 // Named generic op generation. 497 // These ops map at most a single contraction that complies with the limitations 498 // of a linalg.generic. 499 //===----------------------------------------------------------------------===// 500 501 // Template for Linalg named ops' ODS definitions. 
// Parameters of structuredOpOdsHeaderFormat:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional definitions (one `let <name> = 1;` line per `defines` entry)
// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
  /*extraInterfaces=*/[{2}])> {
    {3}
    let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs{4}
    );
    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
    let regions = (region AnyRegion:$region);

    let skipDefaultBuilders = 1;
    let builders = [
      OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
          inputs, outputs, attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>
      {5}
    ];
    let hasCustomAssemblyFormat = 1;
    let hasFolder = 1;
    {6}

    let extraClassDeclaration = structuredOpsBaseDecls # [{{
      // Auto-generated.
      SmallVector<utils::IteratorType> getIteratorTypesArray();
      ArrayAttr getIndexingMaps();
      static void regionBuilder(ImplicitLocOpBuilder &b,
                                Block &block, ArrayRef<NamedAttribute> attrs);
      static std::function<void(ImplicitLocOpBuilder &,
                                Block &, ArrayRef<NamedAttribute>)>
      getRegionBuilder() {{
        return regionBuilder;
      }

      ::mlir::MutableOperandRange getDpsInitsMutable() {{
        return getOutputsMutable();
      }

      // Generic methods.
      static unsigned getNumRegionArgs();
      std::string getLibraryCallName();
      {7}
    }];
}
)FMT";

// Extra builder that takes the op's attributes as explicit parameters. It is
// spliced in right after the last default builder, hence the leading comma.
// Parameters:
// {0}: Class name
// {1}: Comma-separated attribute parameters
// {2}: Statements that store the attributes on the operation state
static const char structuredOpBuilderFormat[] = R"FMT(
  , OpBuilder<
  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, {1},
       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
  [{{
    {2}
    buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
      attributes, {0}::getRegionBuilder());
  }]>
)FMT";

// getIteratorTypesArray() for ops with a fixed iteration domain. Parameters:
// {0}: Class name
// {1}: Comma-separated utils::IteratorType enumerators
static const char structuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";

// The getIteratorTypesArray() method for rank polymorphic structured ops.
// Parameters of rankPolyStructuredOpIteratorTypesFormat:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  int64_t rank = getRank(getDpsInitOperand(0));
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";

// The getIndexingMaps() method for structured ops with static indexing maps.
// The result is memoized in the `linalg.memoized_indexing_maps` attribute.
// Parameters:
// {0}: Class name
// {1}: Statements that append each (symbol-bound, simplified) indexing map to
//      the local `maps` vector
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
  if (cached)
    return cached;

  MLIRContext *context = getContext();
  auto symbolBindings = getSymbolBindings(*this);
  SmallVector<AffineMap> maps;
  {1}
  cached = Builder(context).getAffineMapArrayAttr(maps);
  getOperation()->setAttr(memoizeAttr, cached);
  return cached;
}
)FMT";

// The getIndexingMaps() method for rank polymorphic structured ops. Rank-0
// operands get a map with no results; every other operand gets the identity
// map over the parallel loops. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  MLIRContext *context = getContext();
  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
    getNumParallelLoops(), context);
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";

// Implementations of fold, getEffects and getSpeculatability.
655 // Parameters: 656 // {0}: Class name 657 const char structuredOpFoldersFormat[] = R"FMT( 658 LogicalResult {0}::fold(FoldAdaptor, 659 SmallVectorImpl<OpFoldResult> &) {{ 660 return memref::foldMemRefCast(*this); 661 } 662 void {0}::getEffects(SmallVectorImpl< 663 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{ 664 if (hasPureTensorSemantics()) return; 665 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation())); 666 } 667 Speculation::Speculatability {0}::getSpeculatability() {{ 668 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); 669 } 670 )FMT"; 671 672 // Implementation of parse/print. 673 // Parameters: 674 // {0}: Class name 675 static const char structuredOpParserFormat[] = R"FMT( 676 ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ 677 return ::parseNamedStructuredOp(parser, result, 678 {0}::getNumRegionArgs(), {0}::getRegionBuilder()); 679 } 680 void {0}::print(OpAsmPrinter &p) {{ 681 SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes", 682 "linalg.memoized_indexing_maps"}; 683 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), 684 elidedAttrs); 685 } 686 )FMT"; 687 688 static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, 689 GenerationContext &genContext) { 690 if (!genContext.shouldGenerateOds()) 691 return success(); 692 693 raw_ostream &os = genContext.odss(); 694 695 std::string interfaceNameList; 696 std::string attrList; 697 std::string attrMethods; 698 std::string attrBuilder; 699 700 std::string doc; 701 if (opConfig.metadata->doc) { 702 static const char structuredOpDocFmt[] = R"FMT( 703 let summary = [{{{0}}]; 704 let description = [{{{1}}]; 705 )FMT"; 706 StringRef summary, description; 707 std::tie(summary, description) = 708 StringRef(*opConfig.metadata->doc).trim().split("\n\n"); 709 710 doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim()); 711 } 712 713 interfaceNameList = 
interleaveToString(opConfig.metadata->implements, ", "); 714 715 std::string definitionList; 716 for (const std::string &definition : opConfig.metadata->defines) { 717 static const char definitionFmt[] = "let {0} = 1;\n"; 718 definitionList.append(llvm::formatv(definitionFmt, definition)); 719 } 720 721 if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { 722 return isAttribute(arg.kind); 723 })) { 724 SmallVector<std::string> attrDefs; 725 SmallVector<std::string> attrParams; 726 SmallVector<std::string> attrStmts; 727 for (LinalgOperandDef &arg : opConfig.structuredOp->args) { 728 static const char paramFmt[] = "\"Attribute\":${0}"; 729 static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; 730 // Add the type conversion attributes to the op definition and builders. 731 if (isFunctionAttribute(arg.kind)) { 732 assert(arg.defaultFn); 733 std::string enumName = convertOperandKindToEnumName(arg.kind); 734 static const char typeFmt[] = "{0}::{1}"; 735 static const char defFmt[] = 736 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}"; 737 attrDefs.push_back(llvm::formatv( 738 defFmt, llvm::formatv("{0}Attr", enumName), 739 llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name)); 740 attrParams.push_back(llvm::formatv(paramFmt, arg.name)); 741 attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); 742 } 743 // Add the index attributes to the op definition and builders. 
744 if (arg.kind == LinalgOperandDefKind::IndexAttr) { 745 assert(arg.indexAttrMap.has_value()); 746 assert(arg.defaultIndices.has_value()); 747 size_t size = arg.indexAttrMap->affineMap().getNumResults(); 748 assert(arg.defaultIndices->size() == size); 749 static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; 750 static const char defFmt[] = 751 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}"; 752 std::string defaultVals; 753 llvm::raw_string_ostream ss(defaultVals); 754 llvm::interleave( 755 *arg.defaultIndices, ss, 756 [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; }, 757 ", "); 758 attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), 759 ss.str(), arg.name)); 760 attrParams.push_back(llvm::formatv(paramFmt, arg.name)); 761 attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); 762 } 763 } 764 if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { 765 return arg.kind == LinalgOperandDefKind::IndexAttr; 766 })) { 767 attrMethods = R"( 768 bool hasDynamicIndexingMaps(); 769 LogicalResult verifyIndexingMapRequiredAttributes(); 770 )"; 771 } 772 attrList = ",\n" + llvm::join(attrDefs, ",\n"); 773 attrBuilder = llvm::formatv( 774 structuredOpBuilderFormat, opConfig.metadata->cppClassName, 775 llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); 776 } 777 778 os << llvm::formatv(structuredOpOdsHeaderFormat, 779 opConfig.metadata->cppClassName, opConfig.metadata->name, 780 interfaceNameList, doc, attrList, attrBuilder, 781 definitionList, attrMethods); 782 783 return success(); 784 } 785 786 static LogicalResult 787 generateNamedGenericOpDefns(LinalgOpConfig &opConfig, 788 GenerationContext &genContext) { 789 if (!genContext.shouldGenerateDefns()) 790 return success(); 791 792 raw_ostream &os = genContext.defns(); 793 StringRef className = opConfig.metadata->cppClassName; 794 795 // Implementation banner. 
796 std::string bannerComment = llvm::formatv("Implementation of {0}", className); 797 os << llvm::formatv(bannerFormat, bannerComment); 798 799 // Compute the number of scalar and tensor arguments. 800 int64_t numOfArgs = 801 llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { 802 return arg.kind == LinalgOperandDefKind::InputTensor || 803 arg.kind == LinalgOperandDefKind::Scalar || 804 arg.kind == LinalgOperandDefKind::OutputTensor; 805 }); 806 807 // An operation that accesses only scalars and scalar/rank zero tensors is 808 // rank polymorhpic. We implement rank polymorphism by generating different 809 // indexing maps and iterators that match the rank of the first output tensor. 810 // An operation is rank polymorphic if the iteration domain has rank zero. 811 bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty(); 812 813 // Generate the iterator_types() method. 814 if (!isRankPolymorphic) { 815 std::string iteratorsStr; 816 llvm::raw_string_ostream ss(iteratorsStr); 817 llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss, 818 [&](LinalgIteratorTypeDef it) { 819 switch (it) { 820 case LinalgIteratorTypeDef::parallel: 821 ss << "utils::IteratorType::parallel"; 822 break; 823 case LinalgIteratorTypeDef::reduction: 824 ss << "utils::IteratorType::reduction"; 825 break; 826 } 827 }); 828 os << llvm::formatv(structuredOpIteratorTypesFormat, className, 829 iteratorsStr); 830 } else { 831 os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className); 832 } 833 834 // Generating the getIndexingMaps() method. 835 if (auto &staticMaps = 836 opConfig.structuredOp->indexingMaps.staticIndexingMaps) { 837 if (staticMaps->empty()) 838 return emitError(genContext.getLoc()) << "op has no indexing maps"; 839 if (!isRankPolymorphic) { 840 AffineMap firstMap = staticMaps->front().affineMap(); 841 842 // Symbol bindings. 
843 { 844 // For each symbol, generate a declaration for it, either with an 845 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from 846 // an attribute). 847 // TODO: Possibly lift into a top-level method. 848 static const char structuredOpSymbolBindingsFormat[] = R"FMT( 849 static SmallVector<AffineExpr> getSymbolBindings({0} self) { 850 MLIRContext *context = self.getContext(); 851 SmallVector<AffineExpr> exprs; 852 {1} 853 return exprs; 854 } 855 )FMT"; 856 857 unsigned symbolCount = firstMap.getNumSymbols(); 858 SmallVector<std::string> symbolBindings; 859 for (unsigned i = 0; i < symbolCount; ++i) { 860 symbolBindings.push_back(llvm::formatv( 861 " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); 862 } 863 864 // Access an index attribute. Parameters: 865 // {0}: Attribute name 866 // {1}: Symbol position 867 // {2}: Attribute index 868 static const char structuredOpAccessAttrFormat[] = R"FMT( 869 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}]; 870 exprs.push_back(getAffineConstantExpr(cst{1}, context)); 871 )FMT"; 872 // Update all symbol bindings mapped to an attribute. 873 for (LinalgOperandDef &arg : opConfig.structuredOp->args) { 874 if (arg.kind != LinalgOperandDefKind::IndexAttr) 875 continue; 876 assert(arg.indexAttrMap); 877 for (auto [idx, result] : 878 llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) { 879 if (auto symbol = dyn_cast<AffineSymbolExpr>(result)) { 880 std::string argName = arg.name; 881 argName[0] = toupper(argName[0]); 882 symbolBindings[symbol.getPosition()] = 883 llvm::formatv(structuredOpAccessAttrFormat, argName, 884 symbol.getPosition(), idx); 885 } 886 } 887 } 888 889 std::string symbolBindingsStr; 890 llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); 891 llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); 892 893 os << llvm::formatv(structuredOpSymbolBindingsFormat, className, 894 symbolBindingsStr); 895 } 896 897 // Indexing maps. 
898 { 899 unsigned dimCount = firstMap.getNumDims(); 900 901 // Generate a comma-separated list of dim identifiers to be passed to 902 // bindDims, ensuring tht AffineExpr identifiers are bound in the right 903 // order to the proper AffineDimExpr. 904 // This results in vars in scope like: d0, d1, d2... 905 SmallVector<unsigned> dimIndices; 906 for (unsigned i = 0; i < dimCount; ++i) 907 dimIndices.push_back(i); 908 std::string dimIdentsStr; 909 llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); 910 llvm::interleaveComma(dimIndices, dimIdentsSs, 911 [&](unsigned i) { dimIdentsSs << "d" << i; }); 912 913 // Statements to add and simplify each affine map. 914 SmallVector<std::string> stmts; 915 for (auto &indexingMap : *staticMaps) { 916 // TODO: Assert that dim and symbol count match the first. 917 stmts.push_back( 918 llvm::formatv("maps.push_back({0});", 919 generateCppExpression(indexingMap, "context"))); 920 stmts.push_back(llvm::formatv( 921 "maps.back() = " 922 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " 923 "symbolBindings, {0}, 0));", 924 dimCount)); 925 } 926 927 // TODO: This needs to be memoized and/or converted to non-parser based 928 // C++ codegen prior to real use. 929 os << llvm::formatv(structuredOpIndexingMapsFormat, className, 930 interleaveToString(stmts, "\n ")); 931 } 932 } else { 933 os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className); 934 } 935 } else { 936 return emitError(genContext.getLoc()) 937 << "generating code for non static indexing maps not currently " 938 "supported"; 939 } 940 941 // getNumRegionArgs() 942 { 943 // Generates a getNumRegionArgs() method. 
Parameters: 944 // {0}: Class name 945 // {1}: Number of region args 946 static const char structuredOpGetNumRegionArgsFormat[] = R"FMT( 947 unsigned {0}::getNumRegionArgs() {{ return {1}; } 948 )FMT"; 949 os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className, 950 numOfArgs); 951 } 952 953 // getLibraryCallName() 954 { 955 // Generates a getLibraryCallName method. Parameters: 956 // {0}: Class name 957 static const char structuredOpGetLibraryCallFormat[] = R"FMT( 958 std::string {0}::getLibraryCallName() {{ 959 return generateLibraryCallName(getOperation()); 960 } 961 )FMT"; 962 os << llvm::formatv(structuredOpGetLibraryCallFormat, className); 963 } 964 965 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() 966 if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { 967 return arg.kind == LinalgOperandDefKind::IndexAttr; 968 })) { 969 std::vector<std::string> attrVerifications; 970 for (LinalgOperandDef &arg : opConfig.structuredOp->args) { 971 if (arg.kind != LinalgOperandDefKind::IndexAttr) 972 continue; 973 assert(arg.indexAttrMap); 974 // Verify index attribute. Paramters: 975 // {0}: Attribute name 976 // {1}: Attribute size 977 static const char attrFmt[] = R"FMT( 978 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{ 979 if (!attr.getType().getElementType().isInteger(64)) 980 return op->emitError("incorrect element type for index attribute '{0}'"); 981 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} }) 982 return op->emitError("incorrect shape for index attribute '{0}'"); 983 } 984 )FMT"; 985 attrVerifications.push_back(llvm::formatv( 986 attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults())); 987 } 988 989 // Generates the verifyIndexingMapRequiredAttributes method. 
Parameters: 990 // {0}: Class name 991 // {1}: Attribute verification 992 static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT( 993 bool {0}::hasDynamicIndexingMaps() {{ return true; } 994 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ 995 Operation *op = getOperation(); 996 {1} 997 return success(); 998 } 999 )FMT"; 1000 os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes, 1001 className, llvm::join(attrVerifications, "\n")); 1002 } 1003 1004 // regionBuilder() 1005 { 1006 // Generates a regionBuilder method. Parameters. 1007 // {0}: Class name 1008 // {1}: Number of args 1009 // {2}: Attributes 1010 // {3}: Statements 1011 static const char structuredOpRegionBuilderFormat[] = R"FMT( 1012 void {0}::regionBuilder(ImplicitLocOpBuilder &b, 1013 Block &block, ArrayRef<NamedAttribute> attrs) {{ 1014 assert({1} > 0 && block.getNumArguments() == {1} && 1015 "{0} regionBuilder expects {1} (>=0) args"); 1016 RegionBuilderHelper helper(b, block); 1017 SmallVector<Value> yields; 1018 {2} 1019 {3} 1020 helper.yieldOutputs(yields); 1021 } 1022 )FMT"; 1023 auto &args = opConfig.structuredOp->args; 1024 auto &assignments = opConfig.structuredOp->assignments; 1025 size_t generatedAssignmentCount = 0; 1026 int localCounter = 0; 1027 SmallVector<std::string> attrs; 1028 SmallVector<std::string> stmts; 1029 for (LinalgOperandDef &arg : args) { 1030 if (!isFunctionAttribute(arg.kind)) 1031 continue; 1032 // Obtain the type function attribute values. Parameters. 
1033 // {0}: enum name 1034 // {1}: attribute name 1035 // {2}: default type function name 1036 static const char attrDef[] = R"FMT( 1037 {0} {1}Val = {0}::{2}; 1038 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ 1039 return attr.getName() == "{1}"; }); 1040 if ({1}Iter != attrs.end()) {{ 1041 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue())) 1042 {1}Val = attr.getValue(); 1043 } 1044 )FMT"; 1045 std::string enumName = convertOperandKindToEnumName(arg.kind); 1046 attrs.push_back( 1047 llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn)); 1048 } 1049 for (LinalgOperandDef &arg : args) { 1050 if (arg.kind != LinalgOperandDefKind::OutputTensor) 1051 continue; 1052 1053 // Find the assignment that correlates with the argument. 1054 ScalarAssign *assignment = findAssignment(arg.name, assignments); 1055 if (!assignment) 1056 return emitError(genContext.getLoc()) 1057 << "no assignment found for output argument " << arg.name; 1058 ++generatedAssignmentCount; 1059 1060 // Recursively generate the expression. 1061 std::function<std::optional<std::string>(ScalarExpression &)> 1062 generateExpression = 1063 [&](ScalarExpression &expression) -> std::optional<std::string> { 1064 if (expression.arg) { 1065 // Argument reference. 1066 std::optional<int> argIndex = 1067 findTensorDefArgIndex(*expression.arg, args); 1068 if (!argIndex) { 1069 emitError(genContext.getLoc()) 1070 << "scalar argument not defined on the op: " << *expression.arg; 1071 return std::nullopt; 1072 } 1073 return std::string( 1074 llvm::formatv("block.getArgument({0})", *argIndex)); 1075 } 1076 if (expression.constant) { 1077 std::string cppIdent = llvm::formatv("value{0}", ++localCounter); 1078 stmts.push_back( 1079 llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT", 1080 cppIdent, expression.constant)); 1081 return cppIdent; 1082 } 1083 if (expression.index) { 1084 // Access an iteration index. 
1085 std::string cppIdent = llvm::formatv("value{0}", ++localCounter); 1086 stmts.push_back(llvm::formatv("Value {0} = helper.index({1});", 1087 cppIdent, *expression.index)); 1088 return cppIdent; 1089 } 1090 if (expression.scalarFn) { 1091 std::string enumName = 1092 convertFunctionKindToEnumName(expression.scalarFn->kind); 1093 1094 // Get the function or attribute name. 1095 assert(expression.scalarFn->fnName || expression.scalarFn->attrName); 1096 std::string funcType; 1097 if (expression.scalarFn->fnName) { 1098 funcType = llvm::formatv("{0}::{1}", enumName, 1099 *expression.scalarFn->fnName); 1100 } 1101 if (expression.scalarFn->attrName) { 1102 if (llvm::none_of(args, [&](LinalgOperandDef &arg) { 1103 return isFunctionAttribute(arg.kind) && 1104 arg.name == *expression.scalarFn->attrName; 1105 })) { 1106 emitError(genContext.getLoc()) << "missing function attribute " 1107 << *expression.scalarFn->attrName; 1108 } 1109 funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName); 1110 } 1111 assert(!funcType.empty()); 1112 1113 // Add the optional type parameter to the operands. 1114 SmallVector<std::string> operandCppValues; 1115 if (expression.scalarFn->kind == ScalarFnKind::Type) { 1116 assert(expression.scalarFn->typeVar.has_value()); 1117 std::optional<std::string> typeCppValue = 1118 findTypeValue(*expression.scalarFn->typeVar, args); 1119 if (!typeCppValue) { 1120 emitError(genContext.getLoc()) 1121 << "type variable " << *expression.scalarFn->typeVar 1122 << ", used in a type conversion, must map to a predefined or " 1123 << "an argument type but it does not"; 1124 return std::nullopt; 1125 } 1126 operandCppValues.push_back(*typeCppValue); 1127 } 1128 1129 // Collect the scalar operands. 
1130 for (ScalarExpression &operand : expression.scalarFn->operands) { 1131 auto operandCppValue = generateExpression(operand); 1132 if (!operandCppValue) 1133 return std::nullopt; 1134 operandCppValues.push_back(*operandCppValue); 1135 } 1136 1137 // Call the function builder. 1138 std::string cppIdent = llvm::formatv("value{0}", ++localCounter); 1139 stmts.push_back(llvm::formatv( 1140 "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName, 1141 funcType, interleaveToString(operandCppValues, ", "))); 1142 return cppIdent; 1143 } 1144 emitError(genContext.getLoc()) << "unknown ScalarExpression type"; 1145 return std::nullopt; 1146 }; 1147 std::optional<std::string> cppValue = 1148 generateExpression(assignment->value); 1149 if (!cppValue) 1150 return failure(); 1151 stmts.push_back(llvm::formatv("yields.push_back({0});", *cppValue)); 1152 } 1153 1154 if (generatedAssignmentCount != assignments.size()) 1155 return emitError(genContext.getLoc()) 1156 << "mismatched number of assignments vs output arguments"; 1157 1158 os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, 1159 interleaveToString(attrs, "\n "), 1160 interleaveToString(stmts, "\n ")); 1161 } 1162 1163 // Parser and printer. 1164 os << llvm::formatv(structuredOpParserFormat, className); 1165 1166 // Canonicalizers and folders. 1167 os << llvm::formatv(structuredOpFoldersFormat, className); 1168 1169 return success(); 1170 } 1171 1172 static LogicalResult generateOp(LinalgOpConfig &opConfig, 1173 GenerationContext &genContext) { 1174 // Switch on op type being generated. 
1175 if (opConfig.structuredOp) { 1176 return success( 1177 succeeded(generateNamedGenericOpOds(opConfig, genContext)) && 1178 succeeded(generateNamedGenericOpDefns(opConfig, genContext))); 1179 } 1180 return emitError(genContext.getLoc()) << "unsupported operation type"; 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // Command line options and main 1185 //===----------------------------------------------------------------------===// 1186 1187 static llvm::cl::opt<std::string> 1188 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), 1189 llvm::cl::init("-"), llvm::cl::value_desc("YAML filename")); 1190 1191 static llvm::cl::opt<std::string> 1192 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"), 1193 llvm::cl::value_desc("filename"), llvm::cl::init("")); 1194 1195 static llvm::cl::opt<std::string> 1196 outputCppImplFilename("o-impl", 1197 llvm::cl::desc("C++ implementation file name"), 1198 llvm::cl::value_desc("filename"), llvm::cl::init("")); 1199 1200 int main(int argc, char **argv) { 1201 llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML"); 1202 1203 // Set up the input file. 1204 std::string errorMessage; 1205 std::unique_ptr<llvm::MemoryBuffer> file = 1206 mlir::openInputFile(inputFilename, &errorMessage); 1207 if (!file) { 1208 llvm::errs() << errorMessage << "\n"; 1209 return 1; 1210 } 1211 1212 MLIRContext mlirContext; 1213 LinalgYAMLContext yamlContext{&mlirContext}; 1214 1215 std::vector<LinalgOpConfig> opConfigs; 1216 1217 // Parse input. 1218 Input yin(file->getBuffer(), &yamlContext); 1219 yin >> opConfigs; 1220 1221 if (yin.error()) 1222 return 1; 1223 1224 // Open output files. 
1225 std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl; 1226 if (!outputOdsDeclFilename.empty()) { 1227 outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage); 1228 if (!outputOdsDecl) { 1229 llvm::errs() << errorMessage << "\n"; 1230 return 1; 1231 } 1232 } 1233 1234 std::unique_ptr<llvm::ToolOutputFile> outputCppImpl; 1235 if (!outputCppImplFilename.empty()) { 1236 outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage); 1237 if (!outputCppImpl) { 1238 llvm::errs() << errorMessage << "\n"; 1239 return 1; 1240 } 1241 } 1242 1243 if (!outputOdsDecl && !outputCppImpl) { 1244 llvm::errs() << "error: No output files specified\n"; 1245 return 1; 1246 } 1247 1248 // Generate. 1249 GenerationContext genContext(&mlirContext, 1250 outputOdsDecl ? &outputOdsDecl->os() : nullptr, 1251 outputCppImpl ? &outputCppImpl->os() : nullptr); 1252 1253 for (auto &opConfig : opConfigs) { 1254 if (!opConfig.metadata) { 1255 emitError(genContext.getLoc()) 1256 << "missing operation metadata on subsequent op"; 1257 return 1; 1258 } 1259 1260 genContext.setLoc(NameLoc::get( 1261 StringAttr::get(&mlirContext, opConfig.metadata->cppClassName))); 1262 if (failed(generateOp(opConfig, genContext))) { 1263 return 1; 1264 } 1265 } 1266 1267 if (outputOdsDecl) 1268 outputOdsDecl->keep(); 1269 if (outputCppImpl) 1270 outputCppImpl->keep(); 1271 1272 return 0; 1273 } 1274