Skip to content

Commit ed52fc5

Browse files
committed
Added new feature to topology features
1 parent 3ebc263 commit ed52fc5

File tree

1 file changed

+193
-1
lines changed

1 file changed

+193
-1
lines changed

bindsnet/network/topology_features.py

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def link(self, parent_feature) -> None:
281281
Allow two features to share tensor values
282282
"""
283283

284-
valid_features = (Probability, Weight, Bias, Intensity)
284+
valid_features = (Probability, Delay, Weight, Bias, Intensity)
285285

286286
assert isinstance(self, valid_features), f"A {self} cannot use feature linking"
287287
assert isinstance(
@@ -464,6 +464,198 @@ def assert_valid_range(self):
464464
), f"Invalid range for feature {self.name}: the min value must be of type torch.Tensor, float, or int"
465465

466466

467+
class Delay(AbstractFeature):
    def __init__(
        self,
        name: str,
        value: Union[torch.Tensor, float, int] = None,
        range: Optional[Sequence[float]] = None,
        norm: Optional[Union[torch.Tensor, float, int]] = None,
        learning_rule: Optional[bindsnet.learning.LearningRule] = None,
        nu: Optional[Union[list, tuple]] = None,
        reduction: Optional[callable] = None,
        decay: float = 0.0,
        max_delay: Optional[int] = 32,
        delay_decay: Optional[float] = 0,  # TODO: Make this global + lambda
        drop_late_spikes: Optional[bool] = False,
        refractory: Optional[bool] = False,  # TODO: Change this name
        normalize_delays: Optional[
            bool
        ] = False,  # force normalize delays instead of clipping
    ) -> None:
        # language=rst
        """
        Delays outgoing signals based on the values of :code:`value` and :code:`max_delay`. Delays are calculated as
        being :code:`value` * :code:`max_delay`, where :code:`value` is in range [0, 1]

        :param name: Name of the feature
        :param value: Unscaled delays. Unscaled implies that these values are in [0, 1], and will be multiplied by
            :code:`max_delay` to determine delay time
        :param range: Range of acceptable values for the :code:`value` parameter
        :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each
            sample and after the value has been updated by the learning rule (if there is one)
        :param learning_rule: Rule which will modify the :code:`value` after each sample
        :param nu: Learning rate for the learning rule
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension
        :param decay: Constant multiple to decay weights by on each iteration
        :param max_delay: Maximum possible delay
        :param delay_decay: Decay the delay buffer by this amount every time step
        :param drop_late_spikes: Suppress spikes when delay is at maximum
        :param refractory: Block spikes in synapse until earlier ones pass
        :param normalize_delays: Force normalize delay every run instead of clipping values
        """

        ### Assertions ###
        super().__init__(
            name=name,
            value=value,
            range=[0, 1],  # note: Value isn't used, not 'None' to avoid errors
            norm=norm,
            learning_rule=learning_rule,
            nu=nu,
            reduction=reduction,
            decay=decay,
        )
        self.max_delay = max_delay
        self.delay_decay = delay_decay
        self.drop_late_spikes = drop_late_spikes
        self.refractory = refractory
        self.normalize_delays = normalize_delays

    def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
        # language=rst
        """
        Schedule incoming spikes onto the circular delay buffer and emit the
        spikes whose delay expires at the current time step.

        :param conn_spikes: Incoming spikes for the connection; flattened to
            ``source.n * target.n`` entries internally.
        :return: Tensor of shape ``(source.n, target.n)`` containing the spikes
            scheduled to fire at the current time index.
        """
        value = self.value.clone().detach().flatten()
        if self.normalize_delays:
            # force normalize delay values
            tmp_min, tmp_max = torch.min(value), torch.max(value)
            if tmp_max > 1 or tmp_min < 0:
                value = (value - tmp_min) / (tmp_max - tmp_min)
        else:
            # force clip delay values
            value = torch.clamp(value, 0, 1)

        # Generate new delays for insertion into buffers
        delays = ((1 - value) * (self.max_delay - 1)).long()

        if self.refractory:
            # TODO: Is there a reason why this is in here?
            if self.drop_late_spikes:
                # NOTE(review): delays is at most max_delay - 1, so this mask
                # never fires — confirm whether the threshold should be
                # max_delay - 1 instead.
                conn_spikes[delays == self.max_delay] = 0

            # Prevent additional spikes if one is already on the synapse
            conn_spikes &= self.refrac_count <= 0
            self.refrac_count -= 1
            bool_spikes = conn_spikes.bool()
            self.refrac_count[bool_spikes] = delays[bool_spikes]

        # add circular time index to delays
        # TODO: Dead spikes of delay = self.dmax don't properly die if self.time_idx > 0
        delays = (delays + self.time_idx) % self.max_delay

        # Fill the delay buffer, according to connection delays
        # |delay_buffer| = [source.n * target.n, max_delay]
        # TODO: Can we remove .float() for performance? (Change delay buffer type?)
        flattened_conn_spikes = conn_spikes.flatten().float()
        self.delay_buffer[self.delays_idx, delays] = flattened_conn_spikes  # .bool()

        # Outgoing signal is spikes scheduled to fire at time_idx
        # TODO: Detach + Clone likely not efficient as passing reference to buffer at current time index; efficiency
        out_signal = (
            self.delay_buffer[:, self.time_idx]
            .view(self.source_n, self.target_n)
            .detach()
            .clone()
        )

        # Clear transmitted spikes
        self.delay_buffer[:, self.time_idx] = 0.0

        # Suppress max delays
        if self.drop_late_spikes and not self.refractory:
            late_spikes_time = (self.time_idx - 1) % self.max_delay
            self.delay_buffer[:, late_spikes_time] = 0.0

        # Increment circular time pointer
        self.time_idx = (self.time_idx + 1) % self.max_delay

        # TODO: Remember to move this to global
        # Decay
        if self.delay_decay:
            # delay_decay lives on the same device as the buffer (see
            # prime_feature) — no hard-coded device transfer.
            self.delay_buffer = self.delay_buffer - self.delay_decay
            # Clamp decayed entries at zero; mask the buffer itself, not the
            # scalar decay tensor.
            self.delay_buffer[
                self.delay_buffer < 0
            ] = 0  # TODO: Determine if this is faster than clamp(min=0)

        return out_signal

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Reset the circular time index, empty the delay buffer, and clear any
        per-synapse refractory counters.
        """
        super().reset_state_variables()

        # Reset time index and empty buffer
        self.time_idx = 0
        self.delay_buffer.zero_()

        # Clear stale refractory counters so a reset network starts cleanly
        if self.refractory:
            self.refrac_count.zero_()

    def prime_feature(self, connection, device, **kwargs) -> None:
        # language=rst
        """
        Initialize the delay values and the internal buffers on :code:`device`.

        :param connection: Parent connection; provides ``source.n`` and
            ``target.n`` for sizing the buffers.
        :param device: Device on which all tensors are allocated.
        """
        #### Initialize value ####
        if self.value is None:
            self.initialize_value = lambda: torch.clamp(
                torch.rand(
                    (connection.source.n, connection.target.n),
                    dtype=torch.float32,
                    device=device,
                ),
                self.range[0],
                self.range[1],
            )
        else:
            self.value = self.value.to(torch.float32).to(device)

        super().prime_feature(connection, device, **kwargs)

        #### Initialize additional class variables ####
        self.delays_idx = torch.arange(
            0, connection.source.n * connection.target.n, dtype=torch.int32
        ).to(device)
        self.delay_buffer = torch.zeros(
            connection.source.n * connection.target.n,
            self.max_delay,
            dtype=torch.float32,
        ).to(device)
        self.time_idx = 0
        self.source_n = connection.source.n
        self.target_n = connection.target.n

        # Tensor necessary for interaction with delay buffer; allocate on the
        # target device so compute() needs no per-step transfer.
        if self.delay_decay:
            self.delay_decay = torch.tensor([self.delay_decay], device=device)

        if self.refractory:
            self.refrac_count = torch.zeros(
                connection.source.n * connection.target.n,
                dtype=torch.long,
                device=device,
            )

    def assert_valid_range(self):
        # language=rst
        """
        Validate that the configured range has a non-negative minimum, since
        delays cannot be negative.
        """
        super().assert_valid_range()

        r = self.range

        ## Check min greater than 0 ##
        if isinstance(r[0], torch.Tensor):
            assert (
                r[0] >= 0
            ).all(), (
                f"Invalid range for feature {self.name}: a min value is less than 0"
            )
        elif isinstance(r[0], (float, int)):
            assert (
                r[0] >= 0
            ), f"Invalid range for feature {self.name}: the min value is less than 0"
        else:
            assert (
                False
            ), f"Invalid range for feature {self.name}: the min value must be of type torch.Tensor, float, or int"
657+
658+
467659
class Mask(AbstractFeature):
468660
def __init__(
469661
self,

0 commit comments

Comments
 (0)