xref: /llvm-project/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp (revision 4c48f016effde67d500fc95290096aec9f3bdb70)
1a70aa7bbSRiver Riddle //===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle //
9a70aa7bbSRiver Riddle // This file implements a pass to parametrically map scf.for loops to virtual
10a70aa7bbSRiver Riddle // processing element dimensions.
11a70aa7bbSRiver Riddle //
12a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
13a70aa7bbSRiver Riddle 
14a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
17a70aa7bbSRiver Riddle #include "mlir/IR/Builders.h"
18a70aa7bbSRiver Riddle #include "mlir/Pass/Pass.h"
19a70aa7bbSRiver Riddle 
20a70aa7bbSRiver Riddle using namespace mlir;
21*4c48f016SMatthias Springer using namespace mlir::affine;
22a70aa7bbSRiver Riddle 
23a70aa7bbSRiver Riddle namespace {
245e50dd04SRiver Riddle struct TestLoopMappingPass
2587d6bf37SRiver Riddle     : public PassWrapper<TestLoopMappingPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anoncee15cae0111::TestLoopMappingPass265e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopMappingPass)
275e50dd04SRiver Riddle 
28a70aa7bbSRiver Riddle   StringRef getArgument() const final {
29a70aa7bbSRiver Riddle     return "test-mapping-to-processing-elements";
30a70aa7bbSRiver Riddle   }
getDescription__anoncee15cae0111::TestLoopMappingPass31a70aa7bbSRiver Riddle   StringRef getDescription() const final {
32a70aa7bbSRiver Riddle     return "test mapping a single loop on a virtual processor grid";
33a70aa7bbSRiver Riddle   }
34a70aa7bbSRiver Riddle   explicit TestLoopMappingPass() = default;
35a70aa7bbSRiver Riddle 
getDependentDialects__anoncee15cae0111::TestLoopMappingPass36a70aa7bbSRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
37*4c48f016SMatthias Springer     registry.insert<affine::AffineDialect, scf::SCFDialect>();
38a70aa7bbSRiver Riddle   }
39a70aa7bbSRiver Riddle 
runOnOperation__anoncee15cae0111::TestLoopMappingPass40a70aa7bbSRiver Riddle   void runOnOperation() override {
41a70aa7bbSRiver Riddle     // SSA values for the transformation are created out of thin air by
42a70aa7bbSRiver Riddle     // unregistered "new_processor_id_and_range" operations. This is enough to
43a70aa7bbSRiver Riddle     // emulate mapping conditions.
44a70aa7bbSRiver Riddle     SmallVector<Value, 8> processorIds, numProcessors;
4587d6bf37SRiver Riddle     getOperation()->walk([&processorIds, &numProcessors](Operation *op) {
46a70aa7bbSRiver Riddle       if (op->getName().getStringRef() != "new_processor_id_and_range")
47a70aa7bbSRiver Riddle         return;
48a70aa7bbSRiver Riddle       processorIds.push_back(op->getResult(0));
49a70aa7bbSRiver Riddle       numProcessors.push_back(op->getResult(1));
50a70aa7bbSRiver Riddle     });
51a70aa7bbSRiver Riddle 
5287d6bf37SRiver Riddle     getOperation()->walk([&processorIds, &numProcessors](scf::ForOp op) {
53a70aa7bbSRiver Riddle       // Ignore nested loops.
54a70aa7bbSRiver Riddle       if (op->getParentRegion()->getParentOfType<scf::ForOp>())
55a70aa7bbSRiver Riddle         return;
56a70aa7bbSRiver Riddle       mapLoopToProcessorIds(op, processorIds, numProcessors);
57a70aa7bbSRiver Riddle     });
58a70aa7bbSRiver Riddle   }
59a70aa7bbSRiver Riddle };
60a70aa7bbSRiver Riddle } // namespace
61a70aa7bbSRiver Riddle 
62a70aa7bbSRiver Riddle namespace mlir {
63a70aa7bbSRiver Riddle namespace test {
registerTestLoopMappingPass()64a70aa7bbSRiver Riddle void registerTestLoopMappingPass() { PassRegistration<TestLoopMappingPass>(); }
65a70aa7bbSRiver Riddle } // namespace test
66a70aa7bbSRiver Riddle } // namespace mlir
67