xref: /llvm-project/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
13fef2d26SRiver Riddle //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle 
93fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
103fef2d26SRiver Riddle #include "mlir/IR/BuiltinTypes.h"
113fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
123fef2d26SRiver Riddle 
133fef2d26SRiver Riddle using namespace mlir;
143fef2d26SRiver Riddle 
153fef2d26SRiver Riddle namespace {
163fef2d26SRiver Riddle struct TestMemRefStrideCalculation
1787d6bf37SRiver Riddle     : public PassWrapper<TestMemRefStrideCalculation,
1887d6bf37SRiver Riddle                          InterfacePass<SymbolOpInterface>> {
195e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefStrideCalculation)
205e50dd04SRiver Riddle 
21b5e22e6dSMehdi Amini   StringRef getArgument() const final {
22b5e22e6dSMehdi Amini     return "test-memref-stride-calculation";
23b5e22e6dSMehdi Amini   }
24b5e22e6dSMehdi Amini   StringRef getDescription() const final {
25b5e22e6dSMehdi Amini     return "Test operation constant folding";
26b5e22e6dSMehdi Amini   }
2741574554SRiver Riddle   void runOnOperation() override;
283fef2d26SRiver Riddle };
29be0a7e9fSMehdi Amini } // namespace
303fef2d26SRiver Riddle 
313fef2d26SRiver Riddle /// Traverse AllocOp and compute strides of each MemRefType independently.
3241574554SRiver Riddle void TestMemRefStrideCalculation::runOnOperation() {
3341574554SRiver Riddle   llvm::outs() << "Testing: " << getOperation().getName() << "\n";
3441574554SRiver Riddle   getOperation().walk([&](memref::AllocOp allocOp) {
355550c821STres Popp     auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
363fef2d26SRiver Riddle     int64_t offset;
373fef2d26SRiver Riddle     SmallVector<int64_t, 4> strides;
38*6aaa8f25SMatthias Springer     if (failed(memrefType.getStridesAndOffset(strides, offset))) {
393fef2d26SRiver Riddle       llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
403fef2d26SRiver Riddle                    << "strided form\n";
413fef2d26SRiver Riddle       return;
423fef2d26SRiver Riddle     }
433fef2d26SRiver Riddle     llvm::outs() << "MemRefType offset: ";
44399638f9SAliia Khasanova     if (ShapedType::isDynamic(offset))
453fef2d26SRiver Riddle       llvm::outs() << "?";
463fef2d26SRiver Riddle     else
473fef2d26SRiver Riddle       llvm::outs() << offset;
483fef2d26SRiver Riddle     llvm::outs() << " strides: ";
493fef2d26SRiver Riddle     llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) {
50399638f9SAliia Khasanova       if (ShapedType::isDynamic(v))
513fef2d26SRiver Riddle         llvm::outs() << "?";
523fef2d26SRiver Riddle       else
533fef2d26SRiver Riddle         llvm::outs() << v;
543fef2d26SRiver Riddle     });
553fef2d26SRiver Riddle     llvm::outs() << "\n";
563fef2d26SRiver Riddle   });
573fef2d26SRiver Riddle   llvm::outs().flush();
583fef2d26SRiver Riddle }
593fef2d26SRiver Riddle 
603fef2d26SRiver Riddle namespace mlir {
613fef2d26SRiver Riddle namespace test {
623fef2d26SRiver Riddle void registerTestMemRefStrideCalculation() {
63b5e22e6dSMehdi Amini   PassRegistration<TestMemRefStrideCalculation>();
643fef2d26SRiver Riddle }
653fef2d26SRiver Riddle } // namespace test
663fef2d26SRiver Riddle } // namespace mlir
67