Skip to content

Commit 9c06e05

Browse files
committed
feat(aero_realtime): make audio tower causality configurable via is_causal
1 parent 695d0bc commit 9c06e05

3 files changed

Lines changed: 10 additions & 2 deletions

File tree

src/lmms_engine/models/aero_realtime/aero_realtime_audio_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def _varlen_attention_forward(
102102
cu_seqlens_k=cu_seq_lens,
103103
max_seqlen_q=max_seqlen,
104104
max_seqlen_k=max_seqlen,
105-
causal=True,
105+
causal=bool(getattr(self.config, "is_causal", False)),
106106
softmax_scale=self.scaling,
107107
window_size=window_size,
108108
dropout_p=0.0 if not self.training else self.attention_dropout,

src/lmms_engine/models/aero_realtime/configuration_aero_realtime.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ def __init__(
7373
sliding_window: int | None = 750,
7474
attention_window_left: int | None = None,
7575
attention_window_right: int = 0,
76+
is_causal: bool = False,
7677
norm_type: str = "rms_norm",
7778
mlp_type: str = "swiglu",
7879
conv_padding: str = "causal",
@@ -100,6 +101,7 @@ def __init__(
100101
attention_window_left = (sliding_window - 1) if sliding_window is not None else -1
101102
self.attention_window_left = attention_window_left
102103
self.attention_window_right = attention_window_right
104+
self.is_causal = bool(is_causal)
103105

104106
if norm_type not in ("rms_norm", "layer_norm"):
105107
raise ValueError(f"norm_type must be 'rms_norm' or 'layer_norm', got {norm_type}")

src/lmms_engine/models/aero_realtime/modeling_aero_realtime.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -952,6 +952,9 @@ def _aero_audio_eager_forward(
952952
key_rep = repeat_kv(key, module.num_key_value_groups)
953953
value_rep = repeat_kv(value, module.num_key_value_groups)
954954

955+
if bool(getattr(module.config, "is_causal", False)):
956+
window_right = 0 if window_right < 0 else min(window_right, 0)
957+
955958
local_mask = _aero_audio_build_local_additive_mask(
956959
q_len=q_len,
957960
k_len=k_len,
@@ -989,6 +992,9 @@ def _aero_audio_sdpa_forward(
989992
key_rep = repeat_kv(key, module.num_key_value_groups)
990993
value_rep = repeat_kv(value, module.num_key_value_groups)
991994

995+
if bool(getattr(module.config, "is_causal", False)):
996+
window_right = 0 if window_right < 0 else min(window_right, 0)
997+
992998
local_mask = _aero_audio_build_local_additive_mask(
993999
q_len=q_len,
9941000
k_len=k_len,
@@ -1033,7 +1039,7 @@ def _aero_audio_fa_forward(
10331039
v = value.transpose(1, 2)
10341040
bsz, q_len, _, _ = q.shape
10351041

1036-
causal = q_len == k.shape[1]
1042+
causal = bool(getattr(module.config, "is_causal", False)) or (q_len == k.shape[1])
10371043
window = (window_left, window_right)
10381044

10391045
if attention_mask is None or attention_mask.dim() != 2:

0 commit comments

Comments
 (0)