xref: /llvm-project/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp (revision 9170fa58082e368f8b6ed8e0e6ef88fad8dd4633)
//===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8a9787577SChristian Ulmann 
9889a1178SChristian Ulmann #include "LoopAnnotationTranslation.h"
1062d7d94cSChristian Ulmann #include "llvm/IR/DebugInfoMetadata.h"
11889a1178SChristian Ulmann 
12889a1178SChristian Ulmann using namespace mlir;
13889a1178SChristian Ulmann using namespace mlir::LLVM;
14889a1178SChristian Ulmann using namespace mlir::LLVM::detail;
15889a1178SChristian Ulmann 
16889a1178SChristian Ulmann namespace {
17889a1178SChristian Ulmann /// Helper class that keeps the state of one attribute to metadata conversion.
18889a1178SChristian Ulmann struct LoopAnnotationConversion {
LoopAnnotationConversion__anon8f9a159b0111::LoopAnnotationConversion1987a04795SChristian Ulmann   LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
2087a04795SChristian Ulmann                            LoopAnnotationTranslation &loopAnnotationTranslation,
2187a04795SChristian Ulmann                            llvm::LLVMContext &ctx)
2287a04795SChristian Ulmann       : attr(attr), op(op),
2387a04795SChristian Ulmann         loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
24889a1178SChristian Ulmann 
25889a1178SChristian Ulmann   /// Converts this struct's loop annotation into a corresponding LLVMIR
26889a1178SChristian Ulmann   /// metadata representation.
27889a1178SChristian Ulmann   llvm::MDNode *convert();
28889a1178SChristian Ulmann 
29889a1178SChristian Ulmann   /// Conversion functions for different payload attribute kinds.
30889a1178SChristian Ulmann   void addUnitNode(StringRef name);
31889a1178SChristian Ulmann   void addUnitNode(StringRef name, BoolAttr attr);
324d7c879dSChristian Ulmann   void addI32NodeWithVal(StringRef name, uint32_t val);
33889a1178SChristian Ulmann   void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false);
34889a1178SChristian Ulmann   void convertI32Node(StringRef name, IntegerAttr attr);
35889a1178SChristian Ulmann   void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
3662d7d94cSChristian Ulmann   void convertLocation(FusedLoc attr);
37889a1178SChristian Ulmann 
38889a1178SChristian Ulmann   /// Conversion functions for each for each loop annotation sub-attribute.
39889a1178SChristian Ulmann   void convertLoopOptions(LoopVectorizeAttr options);
40889a1178SChristian Ulmann   void convertLoopOptions(LoopInterleaveAttr options);
41889a1178SChristian Ulmann   void convertLoopOptions(LoopUnrollAttr options);
42889a1178SChristian Ulmann   void convertLoopOptions(LoopUnrollAndJamAttr options);
43889a1178SChristian Ulmann   void convertLoopOptions(LoopLICMAttr options);
44889a1178SChristian Ulmann   void convertLoopOptions(LoopDistributeAttr options);
45889a1178SChristian Ulmann   void convertLoopOptions(LoopPipelineAttr options);
467f249e45SChristian Ulmann   void convertLoopOptions(LoopPeeledAttr options);
477f249e45SChristian Ulmann   void convertLoopOptions(LoopUnswitchAttr options);
48889a1178SChristian Ulmann 
49889a1178SChristian Ulmann   LoopAnnotationAttr attr;
50889a1178SChristian Ulmann   Operation *op;
51889a1178SChristian Ulmann   LoopAnnotationTranslation &loopAnnotationTranslation;
52889a1178SChristian Ulmann   llvm::LLVMContext &ctx;
53889a1178SChristian Ulmann   llvm::SmallVector<llvm::Metadata *> metadataNodes;
54889a1178SChristian Ulmann };
55889a1178SChristian Ulmann } // namespace
56889a1178SChristian Ulmann 
addUnitNode(StringRef name)57889a1178SChristian Ulmann void LoopAnnotationConversion::addUnitNode(StringRef name) {
58889a1178SChristian Ulmann   metadataNodes.push_back(
59889a1178SChristian Ulmann       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name)}));
60889a1178SChristian Ulmann }
61889a1178SChristian Ulmann 
addUnitNode(StringRef name,BoolAttr attr)62889a1178SChristian Ulmann void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
63889a1178SChristian Ulmann   if (attr && attr.getValue())
64889a1178SChristian Ulmann     addUnitNode(name);
65889a1178SChristian Ulmann }
66889a1178SChristian Ulmann 
addI32NodeWithVal(StringRef name,uint32_t val)674d7c879dSChristian Ulmann void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
684d7c879dSChristian Ulmann   llvm::Constant *cstValue = llvm::ConstantInt::get(
694d7c879dSChristian Ulmann       llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false);
704d7c879dSChristian Ulmann   metadataNodes.push_back(
714d7c879dSChristian Ulmann       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
724d7c879dSChristian Ulmann                               llvm::ConstantAsMetadata::get(cstValue)}));
734d7c879dSChristian Ulmann }
744d7c879dSChristian Ulmann 
convertBoolNode(StringRef name,BoolAttr attr,bool negated)75889a1178SChristian Ulmann void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
76889a1178SChristian Ulmann                                                bool negated) {
77889a1178SChristian Ulmann   if (!attr)
78889a1178SChristian Ulmann     return;
79889a1178SChristian Ulmann   bool val = negated ^ attr.getValue();
80889a1178SChristian Ulmann   llvm::Constant *cstValue = llvm::ConstantInt::getBool(ctx, val);
81889a1178SChristian Ulmann   metadataNodes.push_back(
82889a1178SChristian Ulmann       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
83889a1178SChristian Ulmann                               llvm::ConstantAsMetadata::get(cstValue)}));
84889a1178SChristian Ulmann }
85889a1178SChristian Ulmann 
convertI32Node(StringRef name,IntegerAttr attr)86889a1178SChristian Ulmann void LoopAnnotationConversion::convertI32Node(StringRef name,
87889a1178SChristian Ulmann                                               IntegerAttr attr) {
88889a1178SChristian Ulmann   if (!attr)
89889a1178SChristian Ulmann     return;
904d7c879dSChristian Ulmann   addI32NodeWithVal(name, attr.getInt());
91889a1178SChristian Ulmann }
92889a1178SChristian Ulmann 
convertFollowupNode(StringRef name,LoopAnnotationAttr attr)93889a1178SChristian Ulmann void LoopAnnotationConversion::convertFollowupNode(StringRef name,
94889a1178SChristian Ulmann                                                    LoopAnnotationAttr attr) {
95889a1178SChristian Ulmann   if (!attr)
96889a1178SChristian Ulmann     return;
97889a1178SChristian Ulmann 
9887a04795SChristian Ulmann   llvm::MDNode *node =
9987a04795SChristian Ulmann       loopAnnotationTranslation.translateLoopAnnotation(attr, op);
100889a1178SChristian Ulmann 
101889a1178SChristian Ulmann   metadataNodes.push_back(
102889a1178SChristian Ulmann       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
103889a1178SChristian Ulmann }
104889a1178SChristian Ulmann 
convertLoopOptions(LoopVectorizeAttr options)105889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) {
106889a1178SChristian Ulmann   convertBoolNode("llvm.loop.vectorize.enable", options.getDisable(), true);
107889a1178SChristian Ulmann   convertBoolNode("llvm.loop.vectorize.predicate.enable",
108889a1178SChristian Ulmann                   options.getPredicateEnable());
109889a1178SChristian Ulmann   convertBoolNode("llvm.loop.vectorize.scalable.enable",
110889a1178SChristian Ulmann                   options.getScalableEnable());
111889a1178SChristian Ulmann   convertI32Node("llvm.loop.vectorize.width", options.getWidth());
112889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.vectorize.followup_vectorized",
113889a1178SChristian Ulmann                       options.getFollowupVectorized());
114889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.vectorize.followup_epilogue",
115889a1178SChristian Ulmann                       options.getFollowupEpilogue());
116889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.vectorize.followup_all",
117889a1178SChristian Ulmann                       options.getFollowupAll());
118889a1178SChristian Ulmann }
119889a1178SChristian Ulmann 
convertLoopOptions(LoopInterleaveAttr options)120889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) {
121889a1178SChristian Ulmann   convertI32Node("llvm.loop.interleave.count", options.getCount());
122889a1178SChristian Ulmann }
123889a1178SChristian Ulmann 
convertLoopOptions(LoopUnrollAttr options)124889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) {
125889a1178SChristian Ulmann   if (auto disable = options.getDisable())
126889a1178SChristian Ulmann     addUnitNode(disable.getValue() ? "llvm.loop.unroll.disable"
127889a1178SChristian Ulmann                                    : "llvm.loop.unroll.enable");
128889a1178SChristian Ulmann   convertI32Node("llvm.loop.unroll.count", options.getCount());
129889a1178SChristian Ulmann   convertBoolNode("llvm.loop.unroll.runtime.disable",
130889a1178SChristian Ulmann                   options.getRuntimeDisable());
131889a1178SChristian Ulmann   addUnitNode("llvm.loop.unroll.full", options.getFull());
1324d7c879dSChristian Ulmann   convertFollowupNode("llvm.loop.unroll.followup_unrolled",
1334d7c879dSChristian Ulmann                       options.getFollowupUnrolled());
134889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.unroll.followup_remainder",
135889a1178SChristian Ulmann                       options.getFollowupRemainder());
1364d7c879dSChristian Ulmann   convertFollowupNode("llvm.loop.unroll.followup_all",
1374d7c879dSChristian Ulmann                       options.getFollowupAll());
138889a1178SChristian Ulmann }
139889a1178SChristian Ulmann 
convertLoopOptions(LoopUnrollAndJamAttr options)140889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(
141889a1178SChristian Ulmann     LoopUnrollAndJamAttr options) {
142889a1178SChristian Ulmann   if (auto disable = options.getDisable())
143889a1178SChristian Ulmann     addUnitNode(disable.getValue() ? "llvm.loop.unroll_and_jam.disable"
144889a1178SChristian Ulmann                                    : "llvm.loop.unroll_and_jam.enable");
145889a1178SChristian Ulmann   convertI32Node("llvm.loop.unroll_and_jam.count", options.getCount());
146889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer",
147889a1178SChristian Ulmann                       options.getFollowupOuter());
148889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner",
149889a1178SChristian Ulmann                       options.getFollowupInner());
150889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer",
151889a1178SChristian Ulmann                       options.getFollowupRemainderOuter());
152889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner",
153889a1178SChristian Ulmann                       options.getFollowupRemainderInner());
154889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.unroll_and_jam.followup_all",
155889a1178SChristian Ulmann                       options.getFollowupAll());
156889a1178SChristian Ulmann }
157889a1178SChristian Ulmann 
convertLoopOptions(LoopLICMAttr options)158889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) {
159889a1178SChristian Ulmann   addUnitNode("llvm.licm.disable", options.getDisable());
160889a1178SChristian Ulmann   addUnitNode("llvm.loop.licm_versioning.disable",
161889a1178SChristian Ulmann               options.getVersioningDisable());
162889a1178SChristian Ulmann }
163889a1178SChristian Ulmann 
convertLoopOptions(LoopDistributeAttr options)164889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) {
165889a1178SChristian Ulmann   convertBoolNode("llvm.loop.distribute.enable", options.getDisable(), true);
166889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.distribute.followup_coincident",
167889a1178SChristian Ulmann                       options.getFollowupCoincident());
168889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.distribute.followup_sequential",
169889a1178SChristian Ulmann                       options.getFollowupSequential());
170889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.distribute.followup_fallback",
171889a1178SChristian Ulmann                       options.getFollowupFallback());
172889a1178SChristian Ulmann   convertFollowupNode("llvm.loop.distribute.followup_all",
173889a1178SChristian Ulmann                       options.getFollowupAll());
174889a1178SChristian Ulmann }
175889a1178SChristian Ulmann 
convertLoopOptions(LoopPipelineAttr options)176889a1178SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) {
177889a1178SChristian Ulmann   convertBoolNode("llvm.loop.pipeline.disable", options.getDisable());
178889a1178SChristian Ulmann   convertI32Node("llvm.loop.pipeline.initiationinterval",
179889a1178SChristian Ulmann                  options.getInitiationinterval());
180889a1178SChristian Ulmann }
181889a1178SChristian Ulmann 
convertLoopOptions(LoopPeeledAttr options)1827f249e45SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) {
1837f249e45SChristian Ulmann   convertI32Node("llvm.loop.peeled.count", options.getCount());
1847f249e45SChristian Ulmann }
1857f249e45SChristian Ulmann 
convertLoopOptions(LoopUnswitchAttr options)1867f249e45SChristian Ulmann void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) {
1877f249e45SChristian Ulmann   addUnitNode("llvm.loop.unswitch.partial.disable",
1887f249e45SChristian Ulmann               options.getPartialDisable());
1897f249e45SChristian Ulmann }
1907f249e45SChristian Ulmann 
convertLocation(FusedLoc location)19162d7d94cSChristian Ulmann void LoopAnnotationConversion::convertLocation(FusedLoc location) {
19262d7d94cSChristian Ulmann   auto localScopeAttr =
1935550c821STres Popp       dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
19462d7d94cSChristian Ulmann   if (!localScopeAttr)
19562d7d94cSChristian Ulmann     return;
19662d7d94cSChristian Ulmann   auto *localScope = dyn_cast<llvm::DILocalScope>(
19762d7d94cSChristian Ulmann       loopAnnotationTranslation.moduleTranslation.translateDebugInfo(
19862d7d94cSChristian Ulmann           localScopeAttr));
19962d7d94cSChristian Ulmann   if (!localScope)
20062d7d94cSChristian Ulmann     return;
201794b58b4SChristian Ulmann   llvm::Metadata *loc =
20262d7d94cSChristian Ulmann       loopAnnotationTranslation.moduleTranslation.translateLoc(location,
20362d7d94cSChristian Ulmann                                                                localScope);
204794b58b4SChristian Ulmann   metadataNodes.push_back(loc);
20562d7d94cSChristian Ulmann }
206889a1178SChristian Ulmann 
convert()20762d7d94cSChristian Ulmann llvm::MDNode *LoopAnnotationConversion::convert() {
208889a1178SChristian Ulmann   // Reserve operand 0 for loop id self reference.
209889a1178SChristian Ulmann   auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
210889a1178SChristian Ulmann   metadataNodes.push_back(dummy.get());
211889a1178SChristian Ulmann 
21262d7d94cSChristian Ulmann   if (FusedLoc startLoc = attr.getStartLoc())
21362d7d94cSChristian Ulmann     convertLocation(startLoc);
21462d7d94cSChristian Ulmann 
21562d7d94cSChristian Ulmann   if (FusedLoc endLoc = attr.getEndLoc())
21662d7d94cSChristian Ulmann     convertLocation(endLoc);
21762d7d94cSChristian Ulmann 
218889a1178SChristian Ulmann   addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced());
219889a1178SChristian Ulmann   addUnitNode("llvm.loop.mustprogress", attr.getMustProgress());
2204d7c879dSChristian Ulmann   // "isvectorized" is encoded as an i32 value.
2214d7c879dSChristian Ulmann   if (BoolAttr isVectorized = attr.getIsVectorized())
2224d7c879dSChristian Ulmann     addI32NodeWithVal("llvm.loop.isvectorized", isVectorized.getValue());
223889a1178SChristian Ulmann 
224889a1178SChristian Ulmann   if (auto options = attr.getVectorize())
225889a1178SChristian Ulmann     convertLoopOptions(options);
226889a1178SChristian Ulmann   if (auto options = attr.getInterleave())
227889a1178SChristian Ulmann     convertLoopOptions(options);
228889a1178SChristian Ulmann   if (auto options = attr.getUnroll())
229889a1178SChristian Ulmann     convertLoopOptions(options);
230889a1178SChristian Ulmann   if (auto options = attr.getUnrollAndJam())
231889a1178SChristian Ulmann     convertLoopOptions(options);
232889a1178SChristian Ulmann   if (auto options = attr.getLicm())
233889a1178SChristian Ulmann     convertLoopOptions(options);
234889a1178SChristian Ulmann   if (auto options = attr.getDistribute())
235889a1178SChristian Ulmann     convertLoopOptions(options);
236889a1178SChristian Ulmann   if (auto options = attr.getPipeline())
237889a1178SChristian Ulmann     convertLoopOptions(options);
2387f249e45SChristian Ulmann   if (auto options = attr.getPeeled())
2397f249e45SChristian Ulmann     convertLoopOptions(options);
2407f249e45SChristian Ulmann   if (auto options = attr.getUnswitch())
2417f249e45SChristian Ulmann     convertLoopOptions(options);
242889a1178SChristian Ulmann 
243*9170fa58SMarkus Böck   ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses();
244889a1178SChristian Ulmann   if (!parallelAccessGroups.empty()) {
245889a1178SChristian Ulmann     SmallVector<llvm::Metadata *> parallelAccess;
246889a1178SChristian Ulmann     parallelAccess.push_back(
247889a1178SChristian Ulmann         llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
248*9170fa58SMarkus Böck     for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
249889a1178SChristian Ulmann       parallelAccess.push_back(
250*9170fa58SMarkus Böck           loopAnnotationTranslation.getAccessGroup(accessGroupAttr));
251889a1178SChristian Ulmann     metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
252889a1178SChristian Ulmann   }
253889a1178SChristian Ulmann 
254889a1178SChristian Ulmann   // Create loop options and set the first operand to itself.
255889a1178SChristian Ulmann   llvm::MDNode *loopMD = llvm::MDNode::get(ctx, metadataNodes);
256889a1178SChristian Ulmann   loopMD->replaceOperandWith(0, loopMD);
257889a1178SChristian Ulmann 
258889a1178SChristian Ulmann   return loopMD;
259889a1178SChristian Ulmann }
260889a1178SChristian Ulmann 
26187a04795SChristian Ulmann llvm::MDNode *
translateLoopAnnotation(LoopAnnotationAttr attr,Operation * op)26287a04795SChristian Ulmann LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
263889a1178SChristian Ulmann                                                    Operation *op) {
264889a1178SChristian Ulmann   if (!attr)
265889a1178SChristian Ulmann     return nullptr;
266889a1178SChristian Ulmann 
267889a1178SChristian Ulmann   llvm::MDNode *loopMD = lookupLoopMetadata(attr);
268889a1178SChristian Ulmann   if (loopMD)
269889a1178SChristian Ulmann     return loopMD;
270889a1178SChristian Ulmann 
271889a1178SChristian Ulmann   loopMD =
27287a04795SChristian Ulmann       LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
27387a04795SChristian Ulmann           .convert();
274889a1178SChristian Ulmann   // Store a map from this Attribute to the LLVM metadata in case we
275889a1178SChristian Ulmann   // encounter it again.
276889a1178SChristian Ulmann   mapLoopMetadata(attr, loopMD);
277889a1178SChristian Ulmann   return loopMD;
278889a1178SChristian Ulmann }
27987a04795SChristian Ulmann 
280*9170fa58SMarkus Böck llvm::MDNode *
getAccessGroup(AccessGroupAttr accessGroupAttr)281*9170fa58SMarkus Böck LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) {
282*9170fa58SMarkus Böck   auto [result, inserted] =
283*9170fa58SMarkus Böck       accessGroupMetadataMapping.insert({accessGroupAttr, nullptr});
284*9170fa58SMarkus Böck   if (inserted)
285*9170fa58SMarkus Böck     result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
286*9170fa58SMarkus Böck   return result->second;
28787a04795SChristian Ulmann }
28887a04795SChristian Ulmann 
28987a04795SChristian Ulmann llvm::MDNode *
getAccessGroups(AccessGroupOpInterface op)290*9170fa58SMarkus Böck LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) {
291*9170fa58SMarkus Böck   ArrayAttr accessGroups = op.getAccessGroupsOrNull();
292*9170fa58SMarkus Böck   if (!accessGroups || accessGroups.empty())
29387a04795SChristian Ulmann     return nullptr;
29487a04795SChristian Ulmann 
29587a04795SChristian Ulmann   SmallVector<llvm::Metadata *> groupMDs;
296*9170fa58SMarkus Böck   for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>())
297*9170fa58SMarkus Böck     groupMDs.push_back(getAccessGroup(group));
29887a04795SChristian Ulmann   if (groupMDs.size() == 1)
29987a04795SChristian Ulmann     return llvm::cast<llvm::MDNode>(groupMDs.front());
30087a04795SChristian Ulmann   return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
30187a04795SChristian Ulmann }
302