Skip to content

Commit 4e9a8fb

Browse files
authored
Fix torch complier error (#127)
* fix temp in demo * update * fix compile
1 parent b0486c8 commit 4e9a8fb

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

aria/inference/model_cuda.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def _att_block(
109109
freqs_cis: torch.Tensor,
110110
mask: torch.Tensor,
111111
):
112-
113112
q, k, v = self.mixed_qkv(x).split(
114113
[self.d_model, self.d_model, self.d_model], dim=-1
115114
)
@@ -166,7 +165,7 @@ def fill_condition_kv(self, emb: torch.Tensor):
166165
assert self.model_config.emb_size is not None
167166

168167
input_pos = torch.tensor([0], device=emb.device)
169-
mask = self.causal_mask[None, None, input_pos]
168+
mask = self.causal_mask[input_pos].unsqueeze(0).unsqueeze(0)
170169
freqs_cis = self.freqs_cis[input_pos]
171170

172171
x = emb.unsqueeze(dim=1)
@@ -182,7 +181,7 @@ def forward(
182181
):
183182
assert self.freqs_cis is not None, "Caches must be initialized first"
184183

185-
mask = self.causal_mask[None, None, input_pos]
184+
mask = self.causal_mask[input_pos].unsqueeze(0).unsqueeze(0)
186185

187186
if pad_idxs is not None:
188187
mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1))

0 commit comments

Comments
 (0)