NEURON
merge_top_level_blocks_visitor.hpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #pragma once
9 
10 /**
11  * \file
12  * \brief \copybrief nmodl::visitor::MergeTopLevelBlocksVisitor
13  */
14 
15 #include "visitors/ast_visitor.hpp"
17 #include "ast/all.hpp"
18 
19 namespace nmodl {
20 namespace visitor {
21 
22 /**
23  * \addtogroup visitor_classes
24  * \{
25  */
26 
27 /**
28  * \class MergeTopLevelBlocksVisitor
29  * \brief Visitor which merges given top-level blocks into one
30  *
31  * This template takes two arguments which both describe the type of top-level
32  * block. The arguments must match and refer to the same type of block!
33  *
34  * The first argument is a concrete (non-abstract) subclass of nmodl::ast::Block
35  *
36  * The second argument is the nmodl::ast::AstNodeType value that corresponds to that class
37  */
38 template <typename ast_class,
39  ast::AstNodeType ast_type,
40  // only enable it for descendants of `ast::Block`, and only if it's not an abstract class
41  typename = std::enable_if_t<std::is_base_of_v<ast::Block, ast_class> &&
42  !std::is_abstract_v<ast_class>>>
44  public:
46 
47  void visit_program(ast::Program& node) override {
48  // check if there is > 1 block in total (including nested includes)
49  if (collect_nodes(node, {ast_type}).size() <= 1) {
50  return;
51  }
52 
53  // collect all statements from all blocks (including nested includes)
54  ast::StatementVector statements;
55  std::unordered_set<ast::Node*> blocks_to_delete;
56  // since ast::Program::erase_node can only delete top-level nodes, we
57  // need to keep track of the found includes, as well as all of the
58  // _other_ statements (since ast::Include does not provide an
59  // erase_node function, but only set_blocks)
60  std::unordered_map<std::shared_ptr<ast::Include>, ast::NodeVector> include_blocks_to_keep;
61 
62  const auto& blocks = node.get_blocks();
64  statements,
65  blocks_to_delete,
66  include_blocks_to_keep);
67 
68  // insert new top-level block which has all the collected statements
69  auto statement_block = ast::StatementBlock(statements);
70  auto toplevel_block = ast_class(statement_block.clone());
71  node.emplace_back_node(toplevel_block.clone());
72 
73  // delete all of the previously-found top-level blocks
74  node.erase_node(blocks_to_delete);
75 
76  // also delete all of the blocks from INCLUDE blocks
77  for (const auto& [include_block, blocks_to_keep]: include_blocks_to_keep) {
78  include_block->set_blocks(blocks_to_keep);
79  }
80  }
81 
82  private:
83  // Helper function to collect all top-level blocks in an INCLUDE except `ast_class` ones
86  for (const auto& block: node.get_blocks()) {
87  // only insert if it's not an instance of `ast_class`
88  if (std::dynamic_pointer_cast<ast_class>(block) == nullptr) {
89  result.push_back(block);
90  }
91  }
92  return result;
93  }
94 
95  // Helper function to collect statements from a NodeVector (including nested includes)
97  const ast::NodeVector& blocks,
98  ast::StatementVector& statements,
99  std::unordered_set<ast::Node*>& blocks_to_delete,
100  std::unordered_map<std::shared_ptr<ast::Include>, ast::NodeVector>&
101  include_blocks_to_keep) {
102  for (auto& block: blocks) {
103  auto include_block = std::dynamic_pointer_cast<ast::Include>(block);
104  if (include_block) {
105  // Recursively process nested includes
106  const auto& included_blocks = include_block->get_blocks();
107  include_blocks_to_keep[include_block] = collect_include_except(*include_block);
108  collect_statements_from_vector(included_blocks,
109  statements,
110  blocks_to_delete,
111  include_blocks_to_keep);
112  } else {
113  auto temp_block = std::dynamic_pointer_cast<ast_class>(block);
114  // check if it's the correct type
115  if (temp_block) {
116  auto statement_block = temp_block->get_statement_block();
117  // if block is not empty, copy statement-block into vector
118  if (statement_block && !statement_block->get_statements().empty()) {
119  statements.push_back(
120  std::make_shared<ast::ExpressionStatement>(statement_block));
121  }
122  blocks_to_delete.insert(block.get());
123  }
124  }
125  }
126  }
127 };
128 
129 /** \} */ // end of visitor_classes
130 
131 } // namespace visitor
132 } // namespace nmodl
Auto generated AST classes declaration.
Concrete visitor for all AST classes.
Represents an INCLUDE statement in NMODL.
Definition: include.hpp:39
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents block encapsulating list of statements.
Concrete visitor for all AST classes.
Definition: ast_visitor.hpp:37
Visitor which merges given top-level blocks into one.
ast::NodeVector collect_include_except(const ast::Include &node) const
void collect_statements_from_vector(const ast::NodeVector &blocks, ast::StatementVector &statements, std::unordered_set< ast::Node * > &blocks_to_delete, std::unordered_map< std::shared_ptr< ast::Include >, ast::NodeVector > &include_blocks_to_keep)
void visit_program(ast::Program &node) override
visit node of type ast::Program
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:302
std::vector< std::shared_ptr< Node > > NodeVector
Definition: ast_decl.hpp:301
constexpr const char * ast_class()
Definition: docstrings.hpp:33
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
static Node * node(Object *)
Definition: netcvode.cpp:291
Utility functions for visitors implementation.