@@ -60,20 +60,9 @@ def causal_conv1d_fn(*args, **kwargs):
6060 from subquadratic_ops_torch .fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d
6161 from subquadratic_ops_torch .implicit_filter import implicit_filter
6262
63- def causal_conv1d (* args , ** kwargs ):
64- """Run guarded subquadratic causal_conv1d."""
65- ensure_subquadratic_causal_conv1d_supported ()
66- return _subq_causal_conv1d (* args , ** kwargs )
67-
68- def b2b_causal_conv1d (* args , ** kwargs ):
69- """Run guarded subquadratic b2b_causal_conv1d."""
70- ensure_subquadratic_b2b_causal_conv1d_supported ()
71- return _subq_b2b_causal_conv1d (* args , ** kwargs )
72-
73- def fft_causal_conv1d (* args , ** kwargs ):
74- """Run guarded subquadratic fft_causal_conv1d."""
75- ensure_subquadratic_fft_causal_conv1d_supported ()
76- return _subq_fft_causal_conv1d (* args , ** kwargs )
63+ causal_conv1d = _subq_causal_conv1d
64+ b2b_causal_conv1d = _subq_b2b_causal_conv1d
65+ fft_causal_conv1d = _subq_fft_causal_conv1d
7766except ImportError as e :
7867 msg_causal_conv1d = f"Problem importing subquadratic_ops: { e } . causal_conv1d is not available."
7968 msg_b2b_causal_conv1d = f"Problem importing subquadratic_ops: { e } . b2b_causal_conv1d is not available."
@@ -471,7 +460,17 @@ def hyena_no_weight_decay_cond_with_embeddings(name, param):
471460 return ("embedding" in name ) or hyena_no_weight_decay_cond (name , param )
472461
473462
474- def fftconv_func (u , k , D , dropout_mask , gelu = True , k_rev = None , bidirectional = False , use_subquadratic_ops = False ): # noqa: N803
463+ def fftconv_func (
464+ u ,
465+ k ,
466+ D , # noqa: N803
467+ dropout_mask ,
468+ gelu = True ,
469+ k_rev = None ,
470+ bidirectional = False ,
471+ use_subquadratic_ops = False ,
472+ check_subquadratic_ops = True ,
473+ ):
475474 """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D."""
476475 seqlen = u .shape [- 1 ]
477476 fft_size = 2 * seqlen
@@ -504,6 +503,8 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal
504503 # causal
505504 else :
506505 if use_subquadratic_ops :
506+ if check_subquadratic_ops and u .is_cuda :
507+ ensure_subquadratic_fft_causal_conv1d_supported ()
507508 y = fft_causal_conv1d (u , k .squeeze (0 ))
508509 else :
509510 fft_size = max (fft_size , 2 * k .shape [- 1 ])
@@ -902,6 +903,7 @@ def __init__(
902903 self .zigzag = zigzag
903904
904905 self .use_subquadratic_ops = transformer_config .use_subquadratic_ops
906+ self ._subquadratic_ops_checked = False
905907
906908 self .model_parallel_size = self .pg_collection .tp .size () if self .pg_collection .tp is not None else 1
907909 self .model_parallel_rank = self .pg_collection .tp .rank () if self .pg_collection .tp is not None else 0
@@ -984,6 +986,16 @@ def reset_parameters(self):
984986 bounds = math .sqrt (1 / self .kernel_size )
985987 torch .nn .init .uniform_ (self .conv_bias , a = - bounds , b = bounds )
986988
989+ def _ensure_subquadratic_ops_supported (self ):
990+ """Run expensive subquadratic-op CUDA self-tests once per operator instance."""
991+ if self ._subquadratic_ops_checked or not self .use_subquadratic_ops :
992+ return
993+ if self .operator_type == "hyena_medium_conv" and self .kernel_size < 128 :
994+ ensure_subquadratic_causal_conv1d_supported ()
995+ else :
996+ ensure_subquadratic_fft_causal_conv1d_supported ()
997+ self ._subquadratic_ops_checked = True
998+
987999 def forward_long (self , * , x1 , x2 , v , h , bias , inference_context ):
9881000 """Forward pass long."""
9891001 import bionemo .evo2 .models .megatron .hyena .engine as engine
@@ -1074,6 +1086,7 @@ def get_filter_state(filter_name):
10741086 fir_length = self .kernel_size , # self.short_filter_length,
10751087 compute_state = inference_context is not None ,
10761088 use_subquadratic_ops = self .use_subquadratic_ops ,
1089+ check_subquadratic_ops = False ,
10771090 )
10781091 y = rearrange (y , "b d l -> b l d" )
10791092 y = y * x1
@@ -1099,6 +1112,8 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
10991112 Input shapes: bs, (num_groups, group_size), seq_length
11001113 Output shapes: bs, (num_groups, group_size), seq_length
11011114 """
1115+ if x1 .is_cuda :
1116+ self ._ensure_subquadratic_ops_supported ()
11021117 B , GDG , L = x1 .shape # noqa: N806
11031118 x1 , x2 , v = x1 [..., :L ], x2 [..., :L ], v [..., :L ]
11041119
@@ -1189,6 +1204,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
11891204 gelu = False ,
11901205 bidirectional = self .bidirectional ,
11911206 use_subquadratic_ops = self .use_subquadratic_ops ,
1207+ check_subquadratic_ops = False ,
11921208 )
11931209 z = z .to (v .dtype )
11941210
@@ -1388,6 +1404,7 @@ def __init__(
13881404 self .num_groups = num_groups
13891405 self .transformer_config = transformer_config
13901406 self .use_subquadratic_ops = transformer_config .use_subquadratic_ops
1407+ self ._subquadratic_ops_checked = False
13911408 self .short_conv_L = hyena_config .short_conv_L
13921409 self .local_init = local_init
13931410 if pg_collection is None :
@@ -1543,6 +1560,7 @@ def __init__(
15431560 """
15441561 super ().__init__ ()
15451562 self .b2b_causal_conv1d_fn = b2b_causal_conv1d
1563+ self ._check_subquadratic_ops = b2b_causal_conv1d is globals ()["b2b_causal_conv1d" ]
15461564 if pg_collection is None :
15471565 pg_collection = ProcessGroupCollection .use_mpu_process_groups ()
15481566 self .pg_collection = pg_collection
@@ -1567,6 +1585,14 @@ def __init__(
15671585 raise ValueError (f"Operator type { operator_type } not supported" )
15681586
15691587 self .effective_pad_size = (self ._mixer_kernel_size - 1 ) + (self ._proj_conv_kernel_size - 1 )
1588+ self ._subquadratic_ops_checked = False
1589+
1590+ def _ensure_subquadratic_ops_supported (self ):
1591+ """Run the B2B CUDA self-test once per wrapper instance."""
1592+ if self ._subquadratic_ops_checked or not self ._check_subquadratic_ops :
1593+ return
1594+ ensure_subquadratic_b2b_causal_conv1d_supported ()
1595+ self ._subquadratic_ops_checked = True
15701596
15711597 def forward (self , x , _use_cp = True ):
15721598 """Forward pass for the B2BCausalConv1dModule.
@@ -1580,6 +1606,8 @@ def forward(self, x, _use_cp=True):
15801606 # Validate input dimensions
15811607 if x .dim () != 3 :
15821608 raise ValueError ("Input tensor must be 3D [batch_size, hidden_dim, seq_len]" )
1609+ if x .is_cuda :
1610+ self ._ensure_subquadratic_ops_supported ()
15831611
15841612 # Extract weights at runtime to avoid parameter registration
15851613 proj_weight = self ._proj_conv_module .short_conv_weight
@@ -1713,6 +1741,9 @@ def get_filter_state(filter_name):
17131741 L = u .shape [1 ] # noqa: N806
17141742 fir_state = get_filter_state ("fir" )
17151743 if fir_state is None :
1744+ if self .use_subquadratic_ops and u .is_cuda and not self ._subquadratic_ops_checked :
1745+ ensure_subquadratic_causal_conv1d_supported ()
1746+ self ._subquadratic_ops_checked = True
17161747 z_pre , fir_state = engine .parallel_fir (
17171748 u = u ,
17181749 weight = torch .tensor (weight ), # self.short_filter_weight,
@@ -1722,6 +1753,7 @@ def get_filter_state(filter_name):
17221753 fir_length = self .kernel_size , # self.short_filter_length,
17231754 compute_state = inference_context is not None ,
17241755 use_subquadratic_ops = self .use_subquadratic_ops ,
1756+ check_subquadratic_ops = False ,
17251757 )
17261758 else :
17271759 if len (u .shape ) > 2 :
0 commit comments