Lines Matching +full:row +full:- +full:stride

1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// This pass is always enabled and it skips when it is not -O0 and has no
11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx
13 /// point which post-dominate all the shape and dominate all amx intrinsics.
17 //===----------------------------------------------------------------------===//
45 #define DEBUG_TYPE "lower-amx-intrinsics"
50 return FVT->getNumElements() == 256 &&
51 FVT->getElementType()->isIntegerTy(32);
57 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
77 IRBuilderBase &B, Value *Row, Value *Col,
78 Value *Ptr, Value *Stride, Value *Tile);
87 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
107 LLVMContext &Ctx = Preheader->getContext();
109 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
111 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
113 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
119 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
120 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
126 IV->addIncoming(Inc, Latch);
128 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
129 BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
130 PreheaderBr->setSuccessor(0, Header);
140 L->addBasicBlockToLoop(Header, *LI);
141 L->addBasicBlockToLoop(Body, *LI);
142 L->addBasicBlockToLoop(Latch, *LI);
149 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
150 Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
155 RowLoop = LI->AllocateLoop();
156 ColLoop = LI->AllocateLoop();
157 RowLoop->addChildLoop(ColLoop);
158 if (Loop *ParentL = LI->getLoopFor(Start))
159 ParentL->addChildLoop(RowLoop);
161 LI->addTopLevelLoop(RowLoop);
164 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
166 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
171 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
172 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
173 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
174 Value *CurrentRow = &*RowLoopHeader->begin();
175 Value *CurrentCol = &*ColLoopHeader->begin();
182 B.SetInsertPoint(ColBody->getTerminator());
183 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
184 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
186 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
191 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
193 B.SetInsertPoint(RowLoopHeader->getTerminator());
195 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
196 VecCPhiRowLoop->addIncoming(VecZero, Start);
199 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
201 B.SetInsertPoint(ColLoopHeader->getTerminator());
203 VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
210 B.SetInsertPoint(ColBody->getTerminator());
213 VecPhi->addIncoming(ResVec, ColLoopLatch);
214 VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
219 Value *Vec = BitCast->getOperand(0);
220 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
222 // %mul = mul i16 %row.iv, i16 16
226 B.SetInsertPoint(ColBody->getTerminator());
242 IRBuilderBase &B, Value *Row,
267 RowLoop = LI->AllocateLoop();
268 ColLoop = LI->AllocateLoop();
269 InnerLoop = LI->AllocateLoop();
270 ColLoop->addChildLoop(InnerLoop);
271 RowLoop->addChildLoop(ColLoop);
272 if (Loop *ParentL = LI->getLoopFor(Start))
273 ParentL->addChildLoop(RowLoop);
275 LI->addTopLevelLoop(RowLoop);
278 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
280 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
285 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
287 B.SetInsertPoint(ColBody->getTerminator());
292 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
293 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
294 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
295 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
296 Value *CurrentRow = &*RowLoopHeader->begin();
297 Value *CurrentCol = &*ColLoopHeader->begin();
298 Value *CurrentInner = &*InnerLoopHeader->begin();
302 Value *VecC = BitCastAcc->getOperand(0);
303 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
306 // to vector. However with -O0, it doesn't happen.
308 Value *VecA = BitCastLHS->getOperand(0);
309 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
311 Value *VecB = BitCastRHS->getOperand(0);
312 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
315 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
318 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
320 B.SetInsertPoint(RowLoopHeader->getTerminator());
321 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
322 VecCPhiRowLoop->addIncoming(VecC, Start);
324 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
325 VecDPhiRowLoop->addIncoming(VecZero, Start);
328 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
333 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
337 B.SetInsertPoint(ColLoopHeader->getTerminator());
339 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
341 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
350 B.SetInsertPoint(InnerLoopHeader->getTerminator());
352 VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
354 B.SetInsertPoint(InnerBody->getTerminator());
455 B.SetInsertPoint(ColLoopLatch->getTerminator());
459 VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
460 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
461 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
462 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
463 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
487 BasicBlock *Start = InsertI->getParent();
489 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
495 Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
498 // Delete TileDP intrinsic and do some clean-up.
499 for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
503 I->replaceAllUsesWith(ResVec);
504 I->eraseFromParent();
507 TileDP->replaceAllUsesWith(ResAMX);
508 TileDP->eraseFromParent();
514 Value *M, *N, *Ptr, *Stride, *Tile;
518 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
522 m_Value(Stride), m_Value(Tile)));
528 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
529 BasicBlock *Start = InsertI->getParent();
531 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
539 Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
542 // Delete tileloadd6 intrinsic and do some clean-up
543 for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
547 I->replaceAllUsesWith(ResVec);
548 I->eraseFromParent();
551 TileLoadStore->replaceAllUsesWith(ResAMX);
553 TileLoadStore->eraseFromParent();
561 for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
565 I->replaceAllUsesWith(VecZero);
566 I->eraseFromParent();
569 TileZero->eraseFromParent();
577 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
579 switch (Inst->getIntrinsicID()) {
598 switch (Inst->getIntrinsicID()) {
646 TM->getOptLevel() != CodeGenOptLevel::None)
650 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
652 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;