================
@@ -14175,27 +14222,350 @@ bool SemaOpenMP::checkTransformableLoopNest(
         return false;
       },
       [&OriginalInits](OMPLoopBasedDirective *Transform) {
-        Stmt *DependentPreInits;
-        if (auto *Dir = dyn_cast<OMPTileDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else
-          llvm_unreachable("Unhandled loop transformation");
-
-        appendFlattenedStmtList(OriginalInits.back(), DependentPreInits);
+        updatePreInits(Transform, OriginalInits);
       });
   assert(OriginalInits.back().empty() && "No preinit after innermost loop");
   OriginalInits.pop_back();
   return Result;
 }
 
+// Counts the total number of nested loops, including the outermost loop (the
+// original loop). PRECONDITION: this visitor must be invoked on the original
+// loop to be analyzed.
+//
+// Only CompoundStmts and loop statements (ForStmt, CXXForRangeStmt) are
+// descended into. Expressions and declarations are never entered, because
+// they may contain loops of their own (lambda bodies, member functions of
+// local classes, GNU statement expressions, ...) that are not part of the
+// syntactic loop nest and must not be counted. Any other statement kind
+// (IfStmt, SwitchStmt, WhileStmt, ...) also stops the descent.
+//
+// Example AST structure for the code:
+//
+// int main() {
+//     #pragma omp fuse
+//     {
+//         for (int i = 0; i < 100; i++) {    <-- Outer loop
+//             []() {
+//                 for (int j = 0; j < 100; j++) {}  <-- NOT A LOOP
+//             };
+//             for (int j = 0; j < 5; ++j) {}    <-- Inner loop
+//         }
+//         for (int r = 0; r < 100; r++) {    <-- Outer loop
+//             struct LocalClass {
+//                 void bar() {
+//                     for (int j = 0; j < 100; j++) {}  <-- NOT A LOOP
+//                 }
+//             };
+//             for (int k = 0; k < 10; ++k) {}    <-- Inner loop
+//             ({ int x = 5; for (int m = 0; m < 10; ++m) x += m; x; }); <-- NOT A LOOP
+//         }
+//     }
+// }
+// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops
+class NestedLoopCounterVisitor : public DynamicRecursiveASTVisitor {
+private:
+  unsigned NestedLoopCount = 0;
+
+public:
+  /// Returns the number of ForStmt / CXXForRangeStmt nodes visited so far.
+  unsigned getNestedLoopCount() const { return NestedLoopCount; }
+
+  bool VisitForStmt(ForStmt *) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool VisitCXXForRangeStmt(CXXForRangeStmt *) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool TraverseStmt(Stmt *S) override {
+    // Walk only into blocks and loops. Everything else is skipped, which
+    // covers:
+    //  - all expressions (LambdaExpr, StmtExpr, BlockExpr, RequiresExpr, ...),
+    //    whose inner statements are not part of the loop nest;
+    //  - control-flow statements (IfStmt, SwitchStmt, ...), which break
+    //    perfect loop nesting.
+    if (isa_and_nonnull<CompoundStmt, ForStmt, CXXForRangeStmt>(S))
+      return DynamicRecursiveASTVisitor::TraverseStmt(S);
+    return true;
+  }
+
+  bool TraverseDecl(Decl *) override {
+    // Declarations (CXXRecordDecl, RecordDecl, FunctionDecl, ...) may contain
+    // loops, but those are not part of the loop nest: never descend.
+    return true;
+  }
+};
+
+bool SemaOpenMP::analyzeLoopSequence(
+    Stmt *LoopSeqStmt, unsigned &LoopSeqSize, unsigned &NumLoops,
+    SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+    SmallVectorImpl<Stmt *> &ForStmts,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &TransformsPreInits,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &LoopSequencePreInits,
+    SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context,
+    OpenMPDirectiveKind Kind) {
+
+  VarsWithInheritedDSAType TmpDSA;
+  QualType BaseInductionVarType;
+  // Helper Lambda to handle storing initialization and body statements for 
both
+  // ForStmt and CXXForRangeStmt and checks for any possible mismatch between
+  // induction variables types
+  auto storeLoopStatements = [&OriginalInits, &ForStmts, &BaseInductionVarType,
+                              this, &Context](Stmt *LoopStmt) {
+    if (auto *For = dyn_cast<ForStmt>(LoopStmt)) {
+      OriginalInits.back().push_back(For->getInit());
+      ForStmts.push_back(For);
+      // Extract induction variable
+      if (auto *InitStmt = dyn_cast_or_null<DeclStmt>(For->getInit())) {
+        if (auto *InitDecl = dyn_cast<VarDecl>(InitStmt->getSingleDecl())) {
+          QualType InductionVarType = InitDecl->getType().getCanonicalType();
+
+          // Compare with first loop type
+          if (BaseInductionVarType.isNull()) {
+            BaseInductionVarType = InductionVarType;
+          } else if (!Context.hasSameType(BaseInductionVarType,
+                                          InductionVarType)) {
+            Diag(InitDecl->getBeginLoc(),
+                 diag::warn_omp_different_loop_ind_var_types)
+                << getOpenMPDirectiveName(OMPD_fuse) << BaseInductionVarType
+                << InductionVarType;
+          }
+        }
+      }
+    } else {
+      auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt);
+      OriginalInits.back().push_back(CXXFor->getBeginStmt());
+      ForStmts.push_back(CXXFor);
+    }
+  };
+
+  // Helper lambda functions to encapsulate the processing of different
+  // derivations of the canonical loop sequence grammar
+  //
+  // Modularized code for handling loop generation and transformations
+  auto analyzeLoopGeneration = [&storeLoopStatements, &LoopHelpers,
+                                &OriginalInits, &TransformsPreInits,
+                                &LoopCategories, &LoopSeqSize, &NumLoops, Kind,
+                                &TmpDSA, &ForStmts, &Context,
+                                &LoopSequencePreInits, this](Stmt *Child) {
----------------
alexey-bataev wrote:

```suggestion
  auto AnalyzeLoopGeneration = [&](Stmt *Child) {
```

https://github.com/llvm/llvm-project/pull/139293
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to