NEURON
loop_unroll.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <catch2/catch_test_macros.hpp>

#include "ast/program.hpp"
#include "parser/nmodl_driver.hpp"
#include "utils/test_utils.hpp"
#include "visitors/checkparent_visitor.hpp"
#include "visitors/constant_folder_visitor.hpp"
#include "visitors/loop_unroll_visitor.hpp"
#include "visitors/symtab_visitor.hpp"
#include "visitors/visitor_utils.hpp"
19 
20 using namespace nmodl;
21 using namespace visitor;
22 using namespace test;
23 using namespace test_utils;
24 
25 using ast::AstNodeType;
27 
28 //=============================================================================
29 // Loop unroll tests
30 //=============================================================================
31 
32 std::string run_loop_unroll_visitor(const std::string& text) {
34  const auto& ast = driver.parse_string(text);
35 
40 
41  // check that, after visitor rearrangement, parents are still up-to-date
43 
44  return to_nmodl(ast, {AstNodeType::DEFINE});
45 }
46 
// Text-to-text tests for LoopUnrollVisitor. Each GIVEN passes an NMODL snippet
// through run_loop_unroll_visitor() and compares the result with the expected
// NMODL after both sides are normalized with reindent_text().
//
// What the expected outputs pin down:
//  - FROM bounds are inclusive (FROM i=0 TO N with N=2 yields i = 0, 1, 2);
//  - an unrolled loop becomes a block that declares `LOCAL <var>` and, for each
//    iteration, assigns `<var> = <value>` before a copy of the loop body;
//  - array subscripts that use the loop variable are substituted and, when
//    fully constant, folded (x[i+1] -> x[2], x[(i+0)] -> x[1]);
//  - DEFINE statements do not appear in the output.
47 SCENARIO("Perform loop unrolling of FROM construct", "[visitor][unroll]") {
// Constant bounds are fully unrolled, including bound expressions such as
// (0+(0+1)) and (N+2-1). Reaction statements inside a KINETIC block are
// unrolled the same way, and their rate expressions are substituted as well
// (frat[i+1] -> frat[2]).
48  GIVEN("A loop with known iteration space") {
49  std::string input_nmodl = R"(
50  DEFINE N 2
51  PROCEDURE rates() {
52  LOCAL x[N]
53  FROM i=0 TO N {
54  x[i] = x[i] + 11
55  }
56  FROM i=(0+(0+1)) TO (N+2-1) {
57  x[(i+0)] = x[i+1] + 11
58  }
59  }
60  KINETIC state {
61  FROM i=1 TO N+1 {
62  ~ ca[i] <-> ca[i+1] (DFree*frat[i+1]*1(um), DFree*frat[i+1]*1(um))
63  }
64  }
65  )";
66  std::string output_nmodl = R"(
67  PROCEDURE rates() {
68  LOCAL x[N]
69  {
70  LOCAL i
71  i = 0
72  x[0] = x[0]+11
73  i = 1
74  x[1] = x[1]+11
75  i = 2
76  x[2] = x[2]+11
77  }
78  {
79  LOCAL i
80  i = 1
81  x[1] = x[2]+11
82  i = 2
83  x[2] = x[3]+11
84  i = 3
85  x[3] = x[4]+11
86  }
87  }
88 
89  KINETIC state {
90  {
91  LOCAL i
92  i = 1
93  ~ ca[1] <-> ca[2] (DFree*frat[2]*1(um), DFree*frat[2]*1(um))
94  i = 2
95  ~ ca[2] <-> ca[3] (DFree*frat[3]*1(um), DFree*frat[3]*1(um))
96  i = 3
97  ~ ca[3] <-> ca[4] (DFree*frat[4]*1(um), DFree*frat[4]*1(um))
98  }
99  }
100  )";
101  THEN("Loop body gets correctly unrolled") {
102  auto result = run_loop_unroll_visitor(input_nmodl);
103  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
104  }
105  }
106 
// Nested loops: the inner loop is unrolled inside every copy of the outer
// loop's body, and each inner copy gets its own `LOCAL j` block.
107  GIVEN("A nested loop") {
108  std::string input_nmodl = R"(
109  DEFINE N 1
110  PROCEDURE rates() {
111  LOCAL x[N]
112  FROM i=0 TO N {
113  FROM j=1 TO N+1 {
114  x[i] = x[i+j] + 1
115  }
116  }
117  }
118  )";
119  std::string output_nmodl = R"(
120  PROCEDURE rates() {
121  LOCAL x[N]
122  {
123  LOCAL i
124  i = 0
125  {
126  LOCAL j
127  j = 1
128  x[0] = x[1]+1
129  j = 2
130  x[0] = x[2]+1
131  }
132  i = 1
133  {
134  LOCAL j
135  j = 1
136  x[1] = x[2]+1
137  j = 2
138  x[1] = x[3]+1
139  }
140  }
141  }
142  )";
143  THEN("Loop get unrolled recursively") {
144  auto result = run_loop_unroll_visitor(input_nmodl);
145  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
146  }
147  }
148 
149 
// Partial unrolling:
//  - The outer loop has constant bounds and is unrolled.
//  - The inner loop's upper bound `k` is not a known constant, so that loop
//    stays a FROM loop. Its subscripts are still substituted with the outer
//    index but cannot be folded (x[i+k] -> x[0+k]).
//  - The loop whose body contains a VERBATIM block is left as a FROM loop,
//    even though its bounds are constant.
150  GIVEN("Loop with verbatim and unknown iteration space") {
151  std::string input_nmodl = R"(
152  DEFINE N 1
153  PROCEDURE rates() {
154  LOCAL x[N]
155  FROM i=((0+0)) TO (((N+0))) {
156  FROM j=1 TO k {
157  x[i] = x[i+k] + 1
158  }
159  }
160  FROM i=0 TO N {
161  VERBATIM ENDVERBATIM
162  }
163  }
164  )";
165  std::string output_nmodl = R"(
166  PROCEDURE rates() {
167  LOCAL x[N]
168  {
169  LOCAL i
170  i = 0
171  FROM j = 1 TO k {
172  x[0] = x[0+k]+1
173  }
174  i = 1
175  FROM j = 1 TO k {
176  x[1] = x[1+k]+1
177  }
178  }
179  FROM i = 0 TO N {
180  VERBATIM ENDVERBATIM
181  }
182  }
183  )";
184  THEN("Only some loops get unrolled") {
185  auto result = run_loop_unroll_visitor(input_nmodl);
186  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
187  }
188  }
189 
// Standalone loop variable:
//  - Only array subscripts are rewritten. Other uses of `i` (the IF condition
//    and the right-hand side of `x[i] = i`) stay as the variable `i`.
//  - Those uses remain correct because each body copy is preceded by its
//    `i = <value>` assignment.
//  - The input's lowercase `if` is printed back as `IF`.
//  - The remaining loops repeat the first GIVEN.
190  GIVEN("A loop with a standalone loop variable") {
191  std::string input_nmodl = R"(
192  DEFINE N 2
193  PROCEDURE rates() {
194  LOCAL x[N]
195  FROM i=0 TO N {
196  x[i] = x[i] + 11
197  if(i == 0) {
198  x[i] = i
199  }
200  }
201  FROM i=(0+(0+1)) TO (N+2-1) {
202  x[(i+0)] = x[i+1] + 11
203  }
204  }
205  KINETIC state {
206  FROM i=1 TO N+1 {
207  ~ ca[i] <-> ca[i+1] (DFree*frat[i+1]*1(um), DFree*frat[i+1]*1(um))
208  }
209  }
210  )";
211  std::string output_nmodl = R"(
212  PROCEDURE rates() {
213  LOCAL x[N]
214  {
215  LOCAL i
216  i = 0
217  x[0] = x[0]+11
218  IF (i == 0) {
219  x[0] = i
220  }
221  i = 1
222  x[1] = x[1]+11
223  IF (i == 0) {
224  x[1] = i
225  }
226  i = 2
227  x[2] = x[2]+11
228  IF (i == 0) {
229  x[2] = i
230  }
231  }
232  {
233  LOCAL i
234  i = 1
235  x[1] = x[2]+11
236  i = 2
237  x[2] = x[3]+11
238  i = 3
239  x[3] = x[4]+11
240  }
241  }
242 
243  KINETIC state {
244  {
245  LOCAL i
246  i = 1
247  ~ ca[1] <-> ca[2] (DFree*frat[2]*1(um), DFree*frat[2]*1(um))
248  i = 2
249  ~ ca[2] <-> ca[3] (DFree*frat[3]*1(um), DFree*frat[3]*1(um))
250  i = 3
251  ~ ca[3] <-> ca[4] (DFree*frat[4]*1(um), DFree*frat[4]*1(um))
252  }
253  }
254  )";
255  THEN("Loop body gets correctly unrolled") {
256  auto result = run_loop_unroll_visitor(input_nmodl);
257  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
258  }
259  }
260 }
Visitor for checking parents of ast nodes
Class that binds all pieces together for parsing nmodl file.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Perform constant folding of integer/float/double expressions.
Concrete visitor for constructing symbol table from AST.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for checking parents of ast nodes
int check_ast(const ast::Ast &node)
A small wrapper to have a nicer call in parser.cpp.
Perform constant folding of integer/float/double expressions.
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
bool parse_string(const std::string &input)
parser Units provided as string (used for testing)
Definition: unit_driver.cpp:40
std::string run_loop_unroll_visitor(const std::string &text)
Definition: loop_unroll.cpp:32
SCENARIO("Perform loop unrolling of FROM construct", "[visitor][unroll]")
Definition: loop_unroll.cpp:47
Unroll for loop in the AST.
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
Definition: test_utils.cpp:55
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
#define text
Definition: plot.cpp:60
Auto generated AST classes declaration.
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
Utility functions for visitors implementation.