xref: /llvm-project/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp (revision 9170fa58082e368f8b6ed8e0e6ef88fad8dd4633)
1 //===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LoopAnnotationTranslation.h"
10 #include "llvm/IR/DebugInfoMetadata.h"
11 
12 using namespace mlir;
13 using namespace mlir::LLVM;
14 using namespace mlir::LLVM::detail;
15 
16 namespace {
17 /// Helper class that keeps the state of one attribute to metadata conversion.
18 struct LoopAnnotationConversion {
LoopAnnotationConversion__anon8f9a159b0111::LoopAnnotationConversion19   LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
20                            LoopAnnotationTranslation &loopAnnotationTranslation,
21                            llvm::LLVMContext &ctx)
22       : attr(attr), op(op),
23         loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
24 
25   /// Converts this struct's loop annotation into a corresponding LLVMIR
26   /// metadata representation.
27   llvm::MDNode *convert();
28 
29   /// Conversion functions for different payload attribute kinds.
30   void addUnitNode(StringRef name);
31   void addUnitNode(StringRef name, BoolAttr attr);
32   void addI32NodeWithVal(StringRef name, uint32_t val);
33   void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false);
34   void convertI32Node(StringRef name, IntegerAttr attr);
35   void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
36   void convertLocation(FusedLoc attr);
37 
38   /// Conversion functions for each for each loop annotation sub-attribute.
39   void convertLoopOptions(LoopVectorizeAttr options);
40   void convertLoopOptions(LoopInterleaveAttr options);
41   void convertLoopOptions(LoopUnrollAttr options);
42   void convertLoopOptions(LoopUnrollAndJamAttr options);
43   void convertLoopOptions(LoopLICMAttr options);
44   void convertLoopOptions(LoopDistributeAttr options);
45   void convertLoopOptions(LoopPipelineAttr options);
46   void convertLoopOptions(LoopPeeledAttr options);
47   void convertLoopOptions(LoopUnswitchAttr options);
48 
49   LoopAnnotationAttr attr;
50   Operation *op;
51   LoopAnnotationTranslation &loopAnnotationTranslation;
52   llvm::LLVMContext &ctx;
53   llvm::SmallVector<llvm::Metadata *> metadataNodes;
54 };
55 } // namespace
56 
addUnitNode(StringRef name)57 void LoopAnnotationConversion::addUnitNode(StringRef name) {
58   metadataNodes.push_back(
59       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name)}));
60 }
61 
addUnitNode(StringRef name,BoolAttr attr)62 void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
63   if (attr && attr.getValue())
64     addUnitNode(name);
65 }
66 
addI32NodeWithVal(StringRef name,uint32_t val)67 void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
68   llvm::Constant *cstValue = llvm::ConstantInt::get(
69       llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false);
70   metadataNodes.push_back(
71       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
72                               llvm::ConstantAsMetadata::get(cstValue)}));
73 }
74 
convertBoolNode(StringRef name,BoolAttr attr,bool negated)75 void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
76                                                bool negated) {
77   if (!attr)
78     return;
79   bool val = negated ^ attr.getValue();
80   llvm::Constant *cstValue = llvm::ConstantInt::getBool(ctx, val);
81   metadataNodes.push_back(
82       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
83                               llvm::ConstantAsMetadata::get(cstValue)}));
84 }
85 
convertI32Node(StringRef name,IntegerAttr attr)86 void LoopAnnotationConversion::convertI32Node(StringRef name,
87                                               IntegerAttr attr) {
88   if (!attr)
89     return;
90   addI32NodeWithVal(name, attr.getInt());
91 }
92 
convertFollowupNode(StringRef name,LoopAnnotationAttr attr)93 void LoopAnnotationConversion::convertFollowupNode(StringRef name,
94                                                    LoopAnnotationAttr attr) {
95   if (!attr)
96     return;
97 
98   llvm::MDNode *node =
99       loopAnnotationTranslation.translateLoopAnnotation(attr, op);
100 
101   metadataNodes.push_back(
102       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
103 }
104 
convertLoopOptions(LoopVectorizeAttr options)105 void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) {
106   convertBoolNode("llvm.loop.vectorize.enable", options.getDisable(), true);
107   convertBoolNode("llvm.loop.vectorize.predicate.enable",
108                   options.getPredicateEnable());
109   convertBoolNode("llvm.loop.vectorize.scalable.enable",
110                   options.getScalableEnable());
111   convertI32Node("llvm.loop.vectorize.width", options.getWidth());
112   convertFollowupNode("llvm.loop.vectorize.followup_vectorized",
113                       options.getFollowupVectorized());
114   convertFollowupNode("llvm.loop.vectorize.followup_epilogue",
115                       options.getFollowupEpilogue());
116   convertFollowupNode("llvm.loop.vectorize.followup_all",
117                       options.getFollowupAll());
118 }
119 
convertLoopOptions(LoopInterleaveAttr options)120 void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) {
121   convertI32Node("llvm.loop.interleave.count", options.getCount());
122 }
123 
convertLoopOptions(LoopUnrollAttr options)124 void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) {
125   if (auto disable = options.getDisable())
126     addUnitNode(disable.getValue() ? "llvm.loop.unroll.disable"
127                                    : "llvm.loop.unroll.enable");
128   convertI32Node("llvm.loop.unroll.count", options.getCount());
129   convertBoolNode("llvm.loop.unroll.runtime.disable",
130                   options.getRuntimeDisable());
131   addUnitNode("llvm.loop.unroll.full", options.getFull());
132   convertFollowupNode("llvm.loop.unroll.followup_unrolled",
133                       options.getFollowupUnrolled());
134   convertFollowupNode("llvm.loop.unroll.followup_remainder",
135                       options.getFollowupRemainder());
136   convertFollowupNode("llvm.loop.unroll.followup_all",
137                       options.getFollowupAll());
138 }
139 
convertLoopOptions(LoopUnrollAndJamAttr options)140 void LoopAnnotationConversion::convertLoopOptions(
141     LoopUnrollAndJamAttr options) {
142   if (auto disable = options.getDisable())
143     addUnitNode(disable.getValue() ? "llvm.loop.unroll_and_jam.disable"
144                                    : "llvm.loop.unroll_and_jam.enable");
145   convertI32Node("llvm.loop.unroll_and_jam.count", options.getCount());
146   convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer",
147                       options.getFollowupOuter());
148   convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner",
149                       options.getFollowupInner());
150   convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer",
151                       options.getFollowupRemainderOuter());
152   convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner",
153                       options.getFollowupRemainderInner());
154   convertFollowupNode("llvm.loop.unroll_and_jam.followup_all",
155                       options.getFollowupAll());
156 }
157 
convertLoopOptions(LoopLICMAttr options)158 void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) {
159   addUnitNode("llvm.licm.disable", options.getDisable());
160   addUnitNode("llvm.loop.licm_versioning.disable",
161               options.getVersioningDisable());
162 }
163 
convertLoopOptions(LoopDistributeAttr options)164 void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) {
165   convertBoolNode("llvm.loop.distribute.enable", options.getDisable(), true);
166   convertFollowupNode("llvm.loop.distribute.followup_coincident",
167                       options.getFollowupCoincident());
168   convertFollowupNode("llvm.loop.distribute.followup_sequential",
169                       options.getFollowupSequential());
170   convertFollowupNode("llvm.loop.distribute.followup_fallback",
171                       options.getFollowupFallback());
172   convertFollowupNode("llvm.loop.distribute.followup_all",
173                       options.getFollowupAll());
174 }
175 
convertLoopOptions(LoopPipelineAttr options)176 void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) {
177   convertBoolNode("llvm.loop.pipeline.disable", options.getDisable());
178   convertI32Node("llvm.loop.pipeline.initiationinterval",
179                  options.getInitiationinterval());
180 }
181 
convertLoopOptions(LoopPeeledAttr options)182 void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) {
183   convertI32Node("llvm.loop.peeled.count", options.getCount());
184 }
185 
convertLoopOptions(LoopUnswitchAttr options)186 void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) {
187   addUnitNode("llvm.loop.unswitch.partial.disable",
188               options.getPartialDisable());
189 }
190 
convertLocation(FusedLoc location)191 void LoopAnnotationConversion::convertLocation(FusedLoc location) {
192   auto localScopeAttr =
193       dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
194   if (!localScopeAttr)
195     return;
196   auto *localScope = dyn_cast<llvm::DILocalScope>(
197       loopAnnotationTranslation.moduleTranslation.translateDebugInfo(
198           localScopeAttr));
199   if (!localScope)
200     return;
201   llvm::Metadata *loc =
202       loopAnnotationTranslation.moduleTranslation.translateLoc(location,
203                                                                localScope);
204   metadataNodes.push_back(loc);
205 }
206 
convert()207 llvm::MDNode *LoopAnnotationConversion::convert() {
208   // Reserve operand 0 for loop id self reference.
209   auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
210   metadataNodes.push_back(dummy.get());
211 
212   if (FusedLoc startLoc = attr.getStartLoc())
213     convertLocation(startLoc);
214 
215   if (FusedLoc endLoc = attr.getEndLoc())
216     convertLocation(endLoc);
217 
218   addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced());
219   addUnitNode("llvm.loop.mustprogress", attr.getMustProgress());
220   // "isvectorized" is encoded as an i32 value.
221   if (BoolAttr isVectorized = attr.getIsVectorized())
222     addI32NodeWithVal("llvm.loop.isvectorized", isVectorized.getValue());
223 
224   if (auto options = attr.getVectorize())
225     convertLoopOptions(options);
226   if (auto options = attr.getInterleave())
227     convertLoopOptions(options);
228   if (auto options = attr.getUnroll())
229     convertLoopOptions(options);
230   if (auto options = attr.getUnrollAndJam())
231     convertLoopOptions(options);
232   if (auto options = attr.getLicm())
233     convertLoopOptions(options);
234   if (auto options = attr.getDistribute())
235     convertLoopOptions(options);
236   if (auto options = attr.getPipeline())
237     convertLoopOptions(options);
238   if (auto options = attr.getPeeled())
239     convertLoopOptions(options);
240   if (auto options = attr.getUnswitch())
241     convertLoopOptions(options);
242 
243   ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses();
244   if (!parallelAccessGroups.empty()) {
245     SmallVector<llvm::Metadata *> parallelAccess;
246     parallelAccess.push_back(
247         llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
248     for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
249       parallelAccess.push_back(
250           loopAnnotationTranslation.getAccessGroup(accessGroupAttr));
251     metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
252   }
253 
254   // Create loop options and set the first operand to itself.
255   llvm::MDNode *loopMD = llvm::MDNode::get(ctx, metadataNodes);
256   loopMD->replaceOperandWith(0, loopMD);
257 
258   return loopMD;
259 }
260 
261 llvm::MDNode *
translateLoopAnnotation(LoopAnnotationAttr attr,Operation * op)262 LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
263                                                    Operation *op) {
264   if (!attr)
265     return nullptr;
266 
267   llvm::MDNode *loopMD = lookupLoopMetadata(attr);
268   if (loopMD)
269     return loopMD;
270 
271   loopMD =
272       LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
273           .convert();
274   // Store a map from this Attribute to the LLVM metadata in case we
275   // encounter it again.
276   mapLoopMetadata(attr, loopMD);
277   return loopMD;
278 }
279 
280 llvm::MDNode *
getAccessGroup(AccessGroupAttr accessGroupAttr)281 LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) {
282   auto [result, inserted] =
283       accessGroupMetadataMapping.insert({accessGroupAttr, nullptr});
284   if (inserted)
285     result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
286   return result->second;
287 }
288 
289 llvm::MDNode *
getAccessGroups(AccessGroupOpInterface op)290 LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) {
291   ArrayAttr accessGroups = op.getAccessGroupsOrNull();
292   if (!accessGroups || accessGroups.empty())
293     return nullptr;
294 
295   SmallVector<llvm::Metadata *> groupMDs;
296   for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>())
297     groupMDs.push_back(getAccessGroup(group));
298   if (groupMDs.size() == 1)
299     return llvm::cast<llvm::MDNode>(groupMDs.front());
300   return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
301 }
302