xref: /llvm-project/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp (revision adbf21f12b3069b2554efb39f2e92c6cf6f24940)
//===- MeshShardingExtensions.cpp - Mesh sharding for func ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8*adbf21f1SBoian Petkantchin 
9*adbf21f1SBoian Petkantchin #include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
10*adbf21f1SBoian Petkantchin #include "mlir/Dialect/Func/IR/FuncOps.h"
11*adbf21f1SBoian Petkantchin #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
12*adbf21f1SBoian Petkantchin #include "mlir/IR/MLIRContext.h"
13*adbf21f1SBoian Petkantchin 
14*adbf21f1SBoian Petkantchin namespace mlir::func {
15*adbf21f1SBoian Petkantchin 
registerShardingInterfaceExternalModels(DialectRegistry & registry)16*adbf21f1SBoian Petkantchin void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
17*adbf21f1SBoian Petkantchin   registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) {
18*adbf21f1SBoian Petkantchin     ReturnOp::attachInterface<
19*adbf21f1SBoian Petkantchin         mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
20*adbf21f1SBoian Petkantchin         *ctx);
21*adbf21f1SBoian Petkantchin   });
22*adbf21f1SBoian Petkantchin }
23*adbf21f1SBoian Petkantchin 
24*adbf21f1SBoian Petkantchin } // namespace mlir::func
25