@@ -281,7 +281,7 @@ def link(self, parent_feature) -> None:
281281 Allow two features to share tensor values
282282 """
283283
284- valid_features = (Probability , Weight , Bias , Intensity )
284+ valid_features = (Probability , Delay , Weight , Bias , Intensity )
285285
286286 assert isinstance (self , valid_features ), f"A { self } cannot use feature linking"
287287 assert isinstance (
@@ -464,6 +464,198 @@ def assert_valid_range(self):
464464 ), f"Invalid range for feature { self .name } : the min value must be of type torch.Tensor, float, or int"
465465
466466
467+ class Delay (AbstractFeature ):
468+ def __init__ (
469+ self ,
470+ name : str ,
471+ value : Union [torch .Tensor , float , int ] = None ,
472+ range : Optional [Sequence [float ]] = None ,
473+ norm : Optional [Union [torch .Tensor , float , int ]] = None ,
474+ learning_rule : Optional [bindsnet .learning .LearningRule ] = None ,
475+ nu : Optional [Union [list , tuple ]] = None ,
476+ reduction : Optional [callable ] = None ,
477+ decay : float = 0.0 ,
478+ max_delay : Optional [int ] = 32 ,
479+ delay_decay : Optional [float ] = 0 , # TODO: Make this global + lambda
480+ drop_late_spikes : Optional [bool ] = False ,
481+ refractory : Optional [bool ] = False , # TODO: Change this name
482+ normalize_delays : Optional [
483+ bool
484+ ] = False , # force normalize delays instead of clipping
485+ ) -> None :
486+ # language=rst
487+ """
488+ Delays outgoing signals based on the values of :code:`value` and :code:`max_delay`. Delays are calculated as
489+ being :code:`value` * :code:`max_delay`, where :code: `value` is in range [0, 1]
490+ :param name: Name of the feature
491+ :param value: Unscaled delays. Unscaled implies that these values are in [0, 1], and will be multiplied by :code:`max_delay` to determine delay time
492+ :param range: Range of acceptable values for the :code:`value` parameter
493+ :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
494+ and after the value has been updated by the learning rule (if there is one)
495+ :param learning_rule: Rule which will modify the :code:`value` after each sample
496+ :param nu: Learning rate for the learning rule
497+ :param reduction: Method for reducing parameter updates along the minibatch
498+ dimension
499+ :param decay: Constant multiple to decay weights by on each iteration
500+ :param max_delay: Maximum possible delay
501+ :param delay_decay: Decay :code:`value` by this amount every time step
502+ :param drop_late_spikes: Surpress spikes when delay is at maximum
503+ :param refractory: Block spikes in synapse until earlier ones pass
504+ :param normalize: Force normalize delay every run instead of clipping values
505+ """
506+
507+ ### Assertions ###
508+ super ().__init__ (
509+ name = name ,
510+ value = value ,
511+ range = [0 , 1 ], # note: Value isn't used, not 'None' to avoid errors
512+ norm = norm ,
513+ learning_rule = learning_rule ,
514+ nu = nu ,
515+ reduction = reduction ,
516+ decay = decay ,
517+ )
518+ self .max_delay = max_delay
519+ self .delay_decay = delay_decay
520+ self .drop_late_spikes = drop_late_spikes
521+ self .refractory = refractory
522+ self .normalize_delays = normalize_delays
523+
524+ def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
525+ value = self .value .clone ().detach ().flatten ()
526+ if self .normalize_delays :
527+ # force normalize delay values
528+ tmp_min , tmp_max = torch .min (value ), torch .max (value )
529+ if tmp_max > 1 or tmp_min < 0 :
530+ value = (value - tmp_min ) / (tmp_max - tmp_min )
531+ else :
532+ # force clip delay values
533+ value = torch .clamp (value , 0 , 1 )
534+
535+ # Generate new delays for insertion into buffers
536+ delays = ((1 - value ) * (self .max_delay - 1 )).long ()
537+
538+ if self .refractory :
539+ # TODO: Is there a reason why this is in here?
540+ if self .drop_late_spikes :
541+ conn_spikes [delays == self .max_delay ] = 0
542+
543+ # Prevent additional spikes if one is already on the synapse
544+ conn_spikes &= self .refrac_count <= 0
545+ self .refrac_count -= 1
546+ bool_spikes = conn_spikes .bool ()
547+ self .refrac_count [bool_spikes ] = delays [bool_spikes ]
548+
549+ # add circular time index to delays
550+ # TODO: Dead spikes of delay = self.dmax don't properly die if self.time_idx > 0
551+ delays = (delays + self .time_idx ) % self .max_delay
552+
553+ # Fill the delay buffer, according to connection delays
554+ # |delay_buffer| = [source.n * target.n, max_delay]
555+ # TODO: Can we remove .float() for performance? (Change delay buffer type?)
556+ flattened_conn_spikes = conn_spikes .flatten ().float ()
557+ self .delay_buffer [self .delays_idx , delays ] = flattened_conn_spikes # .bool()
558+
559+ # Outgoing signal is spikes scheduled to fire at time_idx
560+ # TODO: Detach + Clone likely not efficient as passing reference to buffer at current time index; efficiency
561+ out_signal = (
562+ self .delay_buffer [:, self .time_idx ]
563+ .view (self .source_n , self .target_n )
564+ .detach ()
565+ .clone ()
566+ )
567+
568+ # Clear transmitted spikes
569+ self .delay_buffer [:, self .time_idx ] = 0.0
570+
571+ # Suppress max delays
572+ if self .drop_late_spikes and not self .refractory :
573+ late_spikes_time = (self .time_idx - 1 ) % self .max_delay
574+ self .delay_buffer [:, late_spikes_time ] = 0.0
575+
576+ # Increment circular time pointer
577+ self .time_idx = (self .time_idx + 1 ) % self .max_delay
578+
579+ # TODO: Remember to move this to global
580+ # Decay
581+ if self .delay_decay :
582+ self .delay_buffer = self .delay_buffer - self .delay_decay .to ("cuda" )
583+ self .delay_buffer [
584+ self .delay_decay < 0
585+ ] = 0 # TODO: Determine if this is faster than clamp(min=0)
586+
587+ return out_signal
588+
589+ def reset_state_variables (self ) -> None :
590+ super ().reset_state_variables ()
591+
592+ # Reset time index and empty buffer
593+ self .time_idx = 0
594+ self .delay_buffer .zero_ ()
595+
596+ def prime_feature (self , connection , device , ** kwargs ) -> None :
597+ #### Initialize value ####
598+ if self .value is None :
599+ self .initialize_value = lambda : torch .clamp (
600+ torch .rand (
601+ (connection .source .n , connection .target .n ),
602+ dtype = torch .float32 ,
603+ device = device ,
604+ ),
605+ self .range [0 ],
606+ self .range [1 ],
607+ )
608+ else :
609+ self .value = self .value .to (torch .float32 ).to (device )
610+
611+ super ().prime_feature (connection , device , ** kwargs )
612+
613+ #### Initialize additional class variables ####
614+ self .delays_idx = torch .arange (
615+ 0 , connection .source .n * connection .target .n , dtype = torch .int32
616+ ).to (device )
617+ self .delay_buffer = torch .zeros (
618+ connection .source .n * connection .target .n ,
619+ self .max_delay ,
620+ dtype = torch .float32 ,
621+ ).to (device )
622+ self .time_idx = 0
623+ self .source_n = connection .source .n
624+ self .target_n = connection .target .n
625+
626+ # Tensor necessary for interaction with delay buffer
627+ if self .delay_decay :
628+ self .delay_decay = torch .tensor ([self .delay_decay ])
629+
630+ if self .refractory :
631+ self .refrac_count = torch .zeros (
632+ connection .source .n * connection .target .n ,
633+ dtype = torch .long ,
634+ device = device ,
635+ )
636+
637+ def assert_valid_range (self ):
638+ super ().assert_valid_range ()
639+
640+ r = self .range
641+
642+ ## Check min greater than 0 ##
643+ if isinstance (r [0 ], torch .Tensor ):
644+ assert (
645+ r [0 ] >= 0
646+ ).all (), (
647+ f"Invalid range for feature { self .name } : a min value is less than 0"
648+ )
649+ elif isinstance (r [0 ], (float , int )):
650+ assert (
651+ r [0 ] >= 0
652+ ), f"Invalid range for feature { self .name } : the min value is less than 0"
653+ else :
654+ assert (
655+ False
656+ ), f"Invalid range for feature { self .name } : the min value must be of type torch.Tensor, float, or int"
657+
658+
467659class Mask (AbstractFeature ):
468660 def __init__ (
469661 self ,
0 commit comments