Skip to content

Commit 3968293

Browse files
authored
Allow combination of continuous-valued third factor and spike-based third factor (#1264)
1 parent 6358911 commit 3968293

19 files changed

Lines changed: 619 additions & 27 deletions

doc/running/running_nest.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ Additionally, if the synapse requires it, specify the ``"post_ports"`` entry to
282282
"post_ports": ["post_spikes",
283283
["I_post_dend", "I_dend"]]}]})
284284
285-
This specifies that the neuron ``iaf_psc_exp_dend`` has to be generated paired with the synapse ``third_factor_stdp``, and that the input ports ``post_spikes`` and ``I_post_dend`` in the synapse are to be connected to the postsynaptic partner. For the ``I_post_dend`` input port, the corresponding variable in the (postsynaptic) neuron is called ``I_dend``.
285+
This specifies that the neuron ``iaf_psc_exp_dend`` has to be generated paired with the synapse ``third_factor_stdp``, and that the input ports ``post_spikes`` and ``I_post_dend`` in the synapse are to be connected to the postsynaptic partner. For the ``I_post_dend`` input port, the corresponding variable in the (postsynaptic) neuron is called ``I_dend``. Note that inline expressions can also be used; in this example in case ``I_dend`` had been an inline expression in the postsynaptic neuron.
286286

287287
To prevent the NESTML code generator from moving specific variables from synapse into postsynaptic neuron, the code generation option ``strictly_synaptic_vars`` may be used (see https://nestml.readthedocs.io/en/latest/pynestml.transformers.html#pynestml.transformers.synapse_post_neuron_transformer.SynapsePostNeuronTransformer).
288288

pynestml/codegeneration/nest_code_generator.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,14 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:
593593
namespace["state_vars_that_need_continuous_buffering_transformed"] = [xfrm.get_neuron_var_name_from_syn_port_name(port_name, removesuffix(synapse.paired_neuron.unpaired_name, FrontendConfiguration.suffix), removesuffix(synapse.paired_neuron.paired_synapse.get_name().split("__with_")[0], FrontendConfiguration.suffix)) for port_name in synapse.paired_neuron.state_vars_that_need_continuous_buffering]
594594
namespace["state_vars_that_need_continuous_buffering_transformed_iv"] = {}
595595
for var_name, var_name_transformed in zip(namespace["state_vars_that_need_continuous_buffering"], namespace["state_vars_that_need_continuous_buffering_transformed"]):
596-
namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = self._nest_printer.print(synapse.paired_neuron.get_initial_value(var_name_transformed))
596+
if synapse.paired_neuron.get_initial_value(var_name_transformed) is None:
597+
if var_name_transformed in [sym.name for sym in synapse.paired_neuron.get_inline_expression_symbols()]:
598+
# the postsynaptic variable is actually an inline expression: initial value is 0
599+
namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = "0"
600+
else:
601+
raise Exception("State variable \"" + str(var_name_transformed) + "\" was not found in the neuron model \"" + synapse.paired_neuron.name + "\"")
602+
else:
603+
namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = self._nest_printer.print(synapse.paired_neuron.get_initial_value(var_name_transformed))
597604

598605
namespace["continuous_post_ports"] = []
599606
if "neuron_synapse_pairs" in FrontendConfiguration.get_codegen_opts().keys():
@@ -711,9 +718,20 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
711718
codegen_and_builder_opts = FrontendConfiguration.get_codegen_opts()
712719
xfrm = SynapsePostNeuronTransformer(codegen_and_builder_opts)
713720
namespace["state_vars_that_need_continuous_buffering_transformed"] = [xfrm.get_neuron_var_name_from_syn_port_name(port_name, removesuffix(neuron.unpaired_name, FrontendConfiguration.suffix), removesuffix(neuron.paired_synapse.get_name().split("__with_")[0], FrontendConfiguration.suffix)) for port_name in neuron.state_vars_that_need_continuous_buffering]
721+
for i, item in enumerate(namespace["state_vars_that_need_continuous_buffering_transformed"]):
722+
if item is None:
723+
raise Exception("State variable \"" + str(neuron.state_vars_that_need_continuous_buffering[i]) + "\" was not found in the neuron model \"" + neuron.name + "\"")
724+
714725
namespace["state_vars_that_need_continuous_buffering_transformed_iv"] = {}
715726
for var_name, var_name_transformed in zip(namespace["state_vars_that_need_continuous_buffering"], namespace["state_vars_that_need_continuous_buffering_transformed"]):
716-
namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = self._nest_printer.print(neuron.get_initial_value(var_name_transformed))
727+
if neuron.get_initial_value(var_name_transformed) is None:
728+
if var_name_transformed in [sym.name for sym in neuron.get_inline_expression_symbols()]:
729+
# the postsynaptic variable is actually an inline expression: initial value is 0
730+
namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = "0"
731+
else:
732+
raise Exception("State variable \"" + str(var_name_transformed) + "\" was not found in the neuron model \"" + neuron.name + "\"")
733+
else:
734+
namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = self._nest_printer.print(neuron.get_initial_value(var_name_transformed))
717735
else:
718736
namespace["state_vars_that_need_continuous_buffering"] = []
719737
if "extra_on_emit_spike_stmts_from_synapse" in dir(neuron):

pynestml/codegeneration/resources_genn/directives_cpp/BufferDeclarationValue.jinja2

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../resources_nest/point_neuron/directives_cpp/MemberVariableGetter.jinja2

pynestml/codegeneration/resources_genn/directives_cpp/MemberVariableGetterSetter.jinja2

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../resources_nest/point_neuron/directives_cpp/MemberVariableSetter.jinja2

pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,8 @@ public:
412412
{%- for variable_symbol in neuron.get_state_symbols() %}
413413
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
414414
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
415-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
415+
{%- include "directives_cpp/MemberVariableGetter.jinja2" %}
416+
{%- include "directives_cpp/MemberVariableSetter.jinja2" %}
416417
{% endif %}
417418
{% endfor %}
418419
{%- endfilter %}
@@ -426,7 +427,8 @@ public:
426427
{% filter indent(2, True) -%}
427428
{%- for variable_symbol in neuron.get_parameter_symbols() %}
428429
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
429-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
430+
{%- include "directives_cpp/MemberVariableGetter.jinja2" %}
431+
{%- include "directives_cpp/MemberVariableSetter.jinja2" %}
430432

431433
{% endfor %}
432434
{%- endfilter %}
@@ -440,7 +442,8 @@ public:
440442
{% filter indent(2, True) -%}
441443
{%- for variable_symbol in neuron.get_internal_symbols() %}
442444
{%- with variable = utils.get_internal_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
443-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
445+
{%- include "directives_cpp/MemberVariableGetter.jinja2" %}
446+
{%- include "directives_cpp/MemberVariableSetter.jinja2" %}
444447
{%- endwith %}
445448
{% endfor %}
446449
{%- endfilter %}
@@ -878,15 +881,30 @@ private:
878881
};
879882

880883
// -------------------------------------------------------------------------
881-
// Getters/setters for inline expressions
884+
// Getters for inline expressions
882885
// -------------------------------------------------------------------------
886+
public:
887+
{% filter indent(2, True) -%}
888+
{%- for equations_block in neuron.get_equations_blocks() %}
889+
{%- for inline_expr in equations_block.get_inline_expressions() %}
890+
{%- set variable = ast_node_factory.create_ast_variable(inline_expr.get_variable_name(), differential_order=0, scope=inline_expr.scope) %}
891+
{%- set variable_symbol = equations_block.get_scope().resolve_to_symbol(inline_expr.get_variable_name(), SymbolKind.VARIABLE) %}
892+
{%- include "directives_cpp/MemberVariableGetter.jinja2" %}
883893

894+
{% endfor %}
895+
{%- endfor %}
896+
{%- endfilter %}
897+
898+
// -------------------------------------------------------------------------
899+
// Setters for inline expressions (this is allowed for expressions containing a convolve() call)
900+
// -------------------------------------------------------------------------
901+
private:
884902
{% filter indent(2, True) -%}
885903
{%- for equations_block in neuron.get_equations_blocks() %}
886904
{%- for inline_expr in equations_block.get_inline_expressions() %}
887905
{%- set variable = ast_node_factory.create_ast_variable(inline_expr.get_variable_name(), differential_order=0, scope=inline_expr.scope) %}
888906
{%- set variable_symbol = equations_block.get_scope().resolve_to_symbol(inline_expr.get_variable_name(), SymbolKind.VARIABLE) %}
889-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
907+
{%- include "directives_cpp/MemberVariableSetter.jinja2" %}
890908

891909
{% endfor %}
892910
{%- endfor %}

pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ private:
428428
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
429429
{%- if not isHomogeneous %}
430430
{%- if variable.get_name() != nest_codegen_opt_delay_variable and variable.get_name() != synapse_weight_variable %}
431-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
431+
{%- include "directives_cpp/MemberVariableGetter.jinja2" %}
432+
{%- include "directives_cpp/MemberVariableSetter.jinja2" %}
432433
{% elif variable.get_name() == synapse_weight_variable and variable.get_name() != "weight" %}
433434
{# weight is its own special case in NEST #}
434435
inline {{ declarations.print_variable_type(variable_symbol) }} get_{{ variable.get_name() }}() const
@@ -454,7 +455,8 @@ inline void set_{{ variable.get_name() }}(const {{ declarations.print_variable_t
454455
{%- for inline_expr in equations_block.get_inline_expressions() %}
455456
{%- set variable = ast_node_factory.create_ast_variable(inline_expr.get_variable_name(), differential_order=0, scope=inline_expr.scope) %}
456457
{%- set variable_symbol = equations_block.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
457-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
458+
{%- include "directives_cpp/MemberVariableGetter.jinja2" %}
459+
{%- include "directives_cpp/MemberVariableSetter.jinja2" %}
458460
{%- endfor %}
459461
{%- endfor %}
460462
{%- endfilter %}
@@ -1082,9 +1084,11 @@ void
10821084

10831085
{%- filter indent(4, True) %}
10841086
{%- set dynamics = synapse.get_on_receive_block(vt_port) %}
1085-
{%- with ast = dynamics.get_stmts_body() %}
1086-
{%- include "directives_cpp/StmtsBody.jinja2" %}
1087-
{%- endwith %}
1087+
{%- if dynamics is not none %}
1088+
{%- with ast = dynamics.get_stmts_body() %}
1089+
{%- include "directives_cpp/StmtsBody.jinja2" %}
1090+
{%- endwith %}
1091+
{%- endif %}
10881092
{%- endfilter %}
10891093
// process remaining dopa spikes in (t0, t1]
10901094
double cd;
@@ -1102,9 +1106,11 @@ void
11021106
**/
11031107
{%- filter indent(6, True) %}
11041108
{%- set dynamics = synapse.get_on_receive_block(vt_port) %}
1105-
{%- with ast = dynamics.get_stmts_body() %}
1106-
{%- include "directives_cpp/StmtsBody.jinja2" %}
1107-
{%- endwith %}
1109+
{%- if dynamics is not none %}
1110+
{%- with ast = dynamics.get_stmts_body() %}
1111+
{%- include "directives_cpp/StmtsBody.jinja2" %}
1112+
{%- endwith %}
1113+
{%- endif %}
11081114
{%- endfilter %}
11091115

11101116
/**
@@ -1546,6 +1552,20 @@ inline void
15461552
std::cout << "[synapse " << this << "] {{ synapseName }}::trigger_update_weight(): t = " << t_trig << std::endl;
15471553
#endif
15481554

1555+
const size_t tid = kernel().vp_manager.get_thread_id();
1556+
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
1557+
{{ paired_neuron_name }}* __target = static_cast< {{ paired_neuron_name }}* >(get_target(tid));
1558+
assert(__target);
1559+
{%- else %}
1560+
Node* __target = get_target( tid );
1561+
{%- endif %}
1562+
1563+
auto get_t = [t_trig](){ return t_trig; }; // do not remove, this is in case the predefined time variable ``t`` is used in the NESTML model
1564+
1565+
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 and paired_neuron.state_vars_that_need_continuous_buffering | length > 0 and continuous_state_buffering_method == "continuous_time_buffer" %}
1566+
#error "For third-factor plasticity combining spiking and continuous-time inputs, \"continuous_time_buffer\" is not supported yet!"
1567+
{%- endif %}
1568+
15491569
// purely dendritic delay
15501570
double dendritic_delay = get_delay();
15511571

@@ -1562,7 +1582,7 @@ inline void
15621582
while ( start != finish )
15631583
{
15641584
{%- for vt_port in vt_ports %}
1565-
{%- set vt_port = vt_ports[0] %}
1585+
{%- set vt_port = vt_ports[0] %}
15661586
process_{{vt_port}}_spikes_( vt_spikes, t0, start->t_ + dendritic_delay, cp );
15671587
{%- endfor %}
15681588

@@ -1574,6 +1594,20 @@ inline void
15741594
* update synapse internal state from `t_last_update_` to `start->t_ + dendritic_delay`
15751595
**/
15761596

1597+
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 and paired_neuron.state_vars_that_need_continuous_buffering | length > 0 %}
1598+
/**
1599+
* grab state variables from the postsynaptic neuron at the time of the post spike
1600+
**/
1601+
1602+
{#- post spike based: grab the entry from the post spiking history buffer #}
1603+
{%- if continuous_state_buffering_method == "post_spike_based" %}
1604+
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
1605+
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
1606+
const double __{{ var_name }} = start->{{ var_name }}_;
1607+
{%- endfor %}
1608+
{%- endif %}
1609+
{%- endif %}
1610+
15771611
update_internal_state_(t_last_update_,
15781612
(start->t_ + dendritic_delay) - t_last_update_,
15791613
cp);
@@ -1609,7 +1643,7 @@ inline void
16091643
**/
16101644

16111645
{%- for vt_port in vt_ports %}
1612-
{%- set vt_port = vt_ports[0] %}
1646+
{%- set vt_port = vt_ports[0] %}
16131647
process_{{vt_port}}_spikes_( vt_spikes, t_lastspike_, t_trig, cp );
16141648
{%- endfor %}
16151649
#ifdef DEBUG

pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/MemberVariableGetterSetter.jinja2 renamed to pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/MemberVariableGetter.jinja2

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,4 @@ inline {{ declarations.print_variable_type(variable_symbol) }} get_{{ printer_no
88
{
99
return {{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }};
1010
}
11-
12-
inline void set_{{ printer_no_origin.print(variable) }}(const {{ declarations.print_variable_type(variable_symbol) }} __v)
13-
{
14-
{{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }} = __v;
15-
}
1611
{%- endif %}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{% if not (variable_symbol.is_inline_expression and not utils.contains_convolve_call(variable_symbol)) -%}
2+
inline void set_{{ printer_no_origin.print(variable) }}(const {{ declarations.print_variable_type(variable_symbol) }} __v)
3+
{
4+
{{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }} = __v;
5+
}
6+
{%- endif %}

0 commit comments

Comments
 (0)