Lines Matching full:env
248 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
250 for (OpOperand &t : env.op()->getOpOperands()) {
251 const TensorId tid = env.makeTensorId(t.getOperandNumber());
252 const auto map = env.op().getMatchingIndexingMap(&t);
258 assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
271 if (!findDepIdxSet(env.merger(), tid, l, a, lt))
274 if (!findAffine(env.merger(), tid, l, a, lt))
287 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288 linalg::GenericOp op = env.op();
296 env.emitter().initializeLoopEmit(
337 static Value genIndex(CodegenEnv &env, OpOperand *t) {
338 const auto map = env.op().getMatchingIndexingMap(t);
344 const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
345 return env.getLoopVar(idx);
349 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
351 const Location loc = env.op().getLoc();
352 const TensorId tid = env.makeTensorId(t->getOperandNumber());
353 const auto map = env.op().getMatchingIndexingMap(t);
357 const auto pos = env.emitter().getValPosits(tid);
361 if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
369 const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
373 return env.emitter().getValBuffer()[tid];
377 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
379 linalg::GenericOp op = env.op();
382 if (!env.isExpand()) {
387 Value index = genIndex(env, t);
388 return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
392 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
394 linalg::GenericOp op = env.op();
396 Value identity = env.getCustomRedId();
398 if (!env.isExpand())
401 Value values = env.getExpandValues();
402 Value filled = env.getExpandFilled();
403 Value index = genIndex(env, t);
426 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
428 linalg::GenericOp op = env.op();
431 if (!env.isExpand()) {
435 env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
436 Value chain = env.getInsertionChain();
437 if (env.isValidLexInsert()) {
445 Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
447 env.updateInsertionChain(out);
450 if (!hasAnySparseType(env.op().getInputs().getTypes())) {
459 env.updateInsertionChain(sparseOut);
469 Value values = env.getExpandValues();
470 Value filled = env.getExpandFilled();
471 Value added = env.getExpandAdded();
472 Value count = env.getExpandCount();
473 Value index = genIndex(env, t);
494 env.updateExpandCount(ifOp.getResult(0));
499 static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
501 Value val = env.exp(exp).val;
505 linalg::GenericOp op = env.op();
507 OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
513 if (env.isSparseOutput(t)) {
514 if (env.isCustomReduc())
515 return genInsertionLoadReduce(env, builder, t);
516 return genInsertionLoad(env, builder, t);
521 Value ptr = genSubscript(env, builder, t, args);
523 assert(env.options().sparseEmitStrategy ==
532 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
538 assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
539 env.exp(exp).kind == TensorExp::Kind::kBinary ||
540 env.exp(exp).kind == TensorExp::Kind::kReduce);
544 if (env.isReduc()) {
545 env.updateReduc(rhs);
549 linalg::GenericOp op = env.op();
552 if (!env.isSparseOutput(t)) {
554 Value ptr = genSubscript(env, builder, t, args);
559 if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
560 genInsertionStore(env, builder, t, rhs);
564 Value chain = env.getInsertionChain();
569 assert(env.exp(exp).val);
570 Value v0 = env.exp(exp).val;
571 genInsertionStore(env, builder, t, v0);
572 env.merger().clearExprValue(exp);
574 Value mchain = env.getInsertionChain();
580 env.updateInsertionChain(ifOp->getResult(0));
585 inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
586 return env.exp(exp).val;
594 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
600 linalg::GenericOp op = env.op();
602 const TensorId tid = env.makeTensorId(arg.getArgNumber());
606 Value ptr = genSubscript(env, rewriter, t, args);
612 return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
619 i, relinkBranch(env, rewriter, block, def->getOperand(i)));
628 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
632 linalg::GenericOp op = env.op();
634 const TensorExp &exp = env.exp(e);
637 return genTensorLoad(env, rewriter, e);
639 return genInvariantValue(env, e);
641 return env.getLoopVar(exp.loop);
644 env.startCustomReduc(e); // enter custom
650 env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
651 v1 = genExp(env, rewriter, exp.children.e1);
654 env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
655 v0 = genExp(env, rewriter, exp.children.e0);
658 v0 = genExp(env, rewriter, exp.children.e0);
659 v1 = genExp(env, rewriter, exp.children.e1);
666 ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
673 ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
678 env.endCustomReduc(); // exit custom
681 env.merger().setExprValue(e, v0); // Preserve value for later use.
687 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
691 if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
693 linalg::GenericOp op = env.op();
694 OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
716 if (env.isCustomReduc()) {
717 if (!env.isReduc())
718 env.startReduc(exp, env.getCustomRedId());
720 env.startReduc(exp, genTensorLoad(env, builder, exp));
722 if (env.hasSparseOutput())
723 env.startValidLexInsert(
724 constantI1(builder, env.op().getLoc(), false));
726 if (!env.isCustomReduc() || env.isReduc())
727 genTensorStore(env, builder, exp, env.endReduc());
728 if (env.hasSparseOutput())
729 env.endValidLexInsert();
734 env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
736 env.merger().clearExprValue(exp);
739 } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
740 env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
741 env.exp(exp).kind != TensorExp::Kind::kSynZero) {
745 if (env.exp(exp).kind == TensorExp::Kind::kReduce)
746 env.startCustomReduc(exp); // enter custom
747 const ExprId e0 = env.exp(exp).children.e0;
748 const ExprId e1 = env.exp(exp).children.e1;
749 genInvariants(env, builder, e0, curr, isStart);
750 genInvariants(env, builder, e1, curr, isStart);
751 if (env.exp(exp).kind == TensorExp::Kind::kReduce)
752 env.endCustomReduc(); // exit custom
757 static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
759 linalg::GenericOp op = env.op();
761 if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
763 assert(!env.isReduc());
780 env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
785 indices.push_back(env.emitter().getLoopIV(i));
786 Value values = env.getExpandValues();
787 Value filled = env.getExpandFilled();
788 Value added = env.getExpandAdded();
789 Value count = env.getExpandCount();
790 Value chain = env.getInsertionChain();
793 env.updateInsertionChain(compress);
794 env.endExpand();
801 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
803 if (env.hasSparseOutput())
806 if (env.isExpand())
809 switch (env.options().parallelizationStrategy) {
826 static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
828 linalg::GenericOp op = env.op();
830 bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
834 const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
837 return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
843 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
847 Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
849 return env.emitter().enterCoIterationOverTensorsAtLvls(
850 builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
859 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
862 bool tryParallel = shouldTryParallize(env, curr, tidLvls);
863 return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
868 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
870 Location loc = env.op().getLoc();
872 if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
882 if (env.isReduc()) {
883 yields.push_back(env.getReduc());
884 env.updateReduc(ifOp.getResult(y++));
885 if (env.isValidLexInsert()) {
886 yields.push_back(env.getValidLexInsert());
887 env.updateValidLexInsert(ifOp.getResult(y++));
890 if (env.isExpand()) {
891 yields.push_back(env.getExpandCount());
892 env.updateExpandCount(ifOp->getResult(y++));
894 if (env.getInsertionChain()) {
895 yields.push_back(env.getInsertionChain());
896 env.updateInsertionChain(ifOp->getResult(y++));
908 static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
912 assert(allCase == curCase || env.merger().latGT(allCase, curCase));
913 const BitVector &allCaseBits = env.merger().lat(allCase).simple;
914 const BitVector &curCaseBits = env.merger().lat(curCase).simple;
923 env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
928 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
930 Location loc = env.op().getLoc();
933 env.merger().foreachTensorLoopId(
943 auto stt = getSparseTensorType(env.op().getInputs()[tid]);
946 assert(curr == env.merger().loop(b));
950 const Value crd = env.emitter().getCoord(tid, *lvl);
951 const Value lvar = env.getLoopVar(curr);
960 if (env.isReduc()) {
961 types.push_back(env.getReduc().getType());
962 if (env.isValidLexInsert())
963 types.push_back(env.getValidLexInsert().getType());
965 if (env.isExpand())
967 if (env.getInsertionChain())
968 types.push_back(env.getInsertionChain().getType());
975 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
979 if (env.isReduc()) {
980 operands.push_back(env.getReduc());
981 env.updateReduc(redInput);
982 if (env.isValidLexInsert()) {
984 operands.push_back(constantI1(builder, env.op().getLoc(), true));
985 env.updateValidLexInsert(validIns);
988 if (env.isExpand()) {
989 operands.push_back(env.getExpandCount());
990 env.updateExpandCount(cntInput);
992 if (env.getInsertionChain()) {
993 operands.push_back(env.getInsertionChain());
994 env.updateInsertionChain(insInput);
997 builder.create<scf::YieldOp>(env.op().getLoc(), operands);
1006 CodegenEnv &env, LatPointId li, LoopId curr,
1008 const BitVector &simple = env.lat(li).simple;
1009 const TensorId outTid = env.merger().getOutTensorID();
1010 const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
1014 env.merger().foreachTensorLoopId(
1019 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1027 if (env.merger().getSynTensorID() == tid) {
1035 assert(curr == env.getCurrentDepth());
1043 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1046 callback(env.makeTensorLevel(tid, *lvl), nullptr);
1049 linalg::GenericOp op = env.op();
1074 assert(curr == env.getCurrentDepth());
1081 callback(env.makeTensorLevel(tid, l), exp);
1088 if (isDenseLT(env.lt(outTid, curr))) {
1089 auto stt = getSparseTensorType(env.op().getOutputs().front());
1092 // linearized env.
1096 callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
1103 callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
1111 (!hasNonUnique || env.options().sparseEmitStrategy ==
1117 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1119 assert(!env.getLoopVar(curr));
1121 genInvariants(env, builder, exp, curr, /*isStart=*/true);
1123 genExpand(env, builder, curr, /*isStart=*/true);
1125 const LatPointId l0 = env.set(lts)[0];
1128 getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1137 env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
1141 for (const LatPointId li : env.set(lts).drop_front())
1142 if (!env.merger().hasAnySparse(env.lat(li).simple))
1149 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1153 linalg::GenericOp op = env.op();
1160 const TensorId tid = env.makeTensorId(input->getOperandNumber());
1167 env.emitter().locateLvlAtAffineAddress(
1168 builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
1179 static void genInitConstantDenseAddress(CodegenEnv &env,
1181 for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1182 genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
1187 CodegenEnv &env, LatPointId li, LoopId curr,
1190 return getAllTidLvlsInLatPoints(env, li, curr,
1200 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1213 translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1216 Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1217 Location loc = env.op().getLoc();
1219 env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
1227 for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1228 if (tid != env.merger().getOutTensorID() &&
1229 tid != env.merger().getSynTensorID())
1230 genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1237 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1242 if (env.isReduc() && env.isValidLexInsert())
1243 env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1246 finalizeWhileOp(env, rewriter, needsUniv);
1250 env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
1251 env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1258 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1260 assert(!env.getLoopVar(at));
1261 env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
1263 genInvariants(env, builder, exp, at, /*isStart=*/false);
1265 genExpand(env, builder, at, /*isStart=*/false);
1271 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1273 assert(curr == env.getCurrentDepth());
1276 if (curr == env.getLoopNum()) {
1277 Value rhs = genExp(env, rewriter, exp);
1278 genTensorStore(env, rewriter, exp, rhs);
1284 env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1287 bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1291 const unsigned lsize = env.set(lts).size();
1292 if (env.generatingSparseIterator()) {
1294 const LatPointId li = env.set(lts)[0];
1296 startLoop(env, rewriter, curr, li, lsize, needsUniv);
1298 // We cannot change this to `for (const LatPointId li : env.set(lts))`
1302 const LatPointId lj = env.set(lts)[j];
1303 const ExprId ej = env.lat(lj).exp;
1306 env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307 genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
1308 genStmt(env, rewriter, ej, curr + 1);
1311 rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1314 // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1316 genStmt(env, rewriter, ej, curr + 1);
1320 needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1324 const LatPointId li = env.set(lts)[i];
1327 startLoop(env, rewriter, curr, li, lsize, needsUniv);
1331 Value redInput = env.getReduc();
1332 Value cntInput = env.getExpandCount();
1333 Value insInput = env.getInsertionChain();
1334 Value validIns = env.getValidLexInsert();
1335 // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1339 const LatPointId lj = env.set(lts)[j];
1340 const ExprId ej = env.lat(lj).exp;
1341 if (li == lj || env.merger().latGT(li, lj)) {
1344 scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345 genStmt(env, rewriter, ej, curr + 1);
1346 endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1348 genStmt(env, rewriter, ej, curr + 1);
1354 needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1359 endLoopSeq(env, rewriter, exp, curr);
1360 assert(curr == env.getCurrentDepth());
1364 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1365 linalg::GenericOp op = env.op();
1374 if (Value chain = env.getInsertionChain()) {
1382 Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1438 CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1439 if (!findSparseAnnotations(env, needIdxRed))
1462 if (failed(env.initTensorExp()))
1466 env.startEmit(options.sparseEmitStrategy);
1467 genBuffers(env, rewriter);
1471 genInitConstantDenseAddress(env, rewriter);
1472 genStmt(env, rewriter, env.getExprId(), 0);
1473 genResult(env, rewriter);