NEURON
state_discontinuity_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 
9 #include <utility>
10 
11 #include "ast/all.hpp"
13 #include "utils/logger.hpp"
16 
17 namespace nmodl {
18 namespace visitor {
19 
20 
21 // convert `state_discontinuity(A, B)` call to `A = B` statement
23  const auto& args = node.get_arguments();
24  const auto& lhs = args[0];
25  const auto& rhs = args[1];
26  return create_statement(fmt::format("{} = {}", to_nmodl(lhs), to_nmodl(rhs)));
27 }
28 
29 
30 // we only change the call to `state_discontinuity` if it is in the NET_RECEIVE block
32  in_net_receive_block = true;
33  node.visit_children(*this);
34  in_net_receive_block = false;
35 }
36 
39  // where we will store all of the new statements to insert into the block
40  ast::StatementVector new_statements;
41 
42  // collect all statements; if it's a call to `state_discontinuity`, replace it with `A = B`
43  const auto& statements = node.get_statements();
44  for (const auto& statement: statements) {
45  auto current_statement = statement;
46  const auto& hits = collect_nodes(*current_statement, {ast::AstNodeType::FUNCTION_CALL});
47  for (const auto& hit: hits) {
48  const auto& fn_call = std::dynamic_pointer_cast<ast::FunctionCall>(hit);
49  if (fn_call->get_name()->get_node_name() ==
51  current_statement = convert_state_discontinuity(*fn_call);
52  logger->info("Converting {} to {}",
53  to_nmodl(statement),
54  to_nmodl(current_statement));
55  }
56  }
57  new_statements.push_back(current_statement);
58  }
59  node.set_statements(new_statements);
60  }
61  node.visit_children(*this);
62 }
63 
64 
65 } // namespace visitor
66 } // namespace nmodl
Auto generated AST classes declaration.
Represents block encapsulating list of statements.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
void visit_net_receive_block(ast::NetReceiveBlock &node) override
visit node of type ast::NetReceiveBlock
bool in_net_receive_block
true if we are visiting a NET_RECEIVE block
@ FUNCTION_CALL
type of ast::FunctionCall
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:302
#define rhs
Definition: lineq.h:6
static constexpr char NRN_STATE_DISC_METHOD[]
Name of the state_discontinuity function in NMODL.
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert the given code statement (in string form) to the corresponding AST node.
static auto convert_state_discontinuity(const ast::FunctionCall &node)
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
Traverse the node recursively and collect all nodes of the given types.
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given an AST node, return its NMODL string representation.
logger_type logger
Definition: logger.cpp:34
static Node * node(Object *)
Definition: netcvode.cpp:291
Visitor used for replacing literal calls to state_discontinuity in a NET_RECEIVE block.
Utility functions for visitors implementation.