File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -109,7 +109,6 @@ def _att_block(
109109 freqs_cis : torch .Tensor ,
110110 mask : torch .Tensor ,
111111 ):
112-
113112 q , k , v = self .mixed_qkv (x ).split (
114113 [self .d_model , self .d_model , self .d_model ], dim = - 1
115114 )
@@ -166,7 +165,7 @@ def fill_condition_kv(self, emb: torch.Tensor):
166165 assert self .model_config .emb_size is not None
167166
168167 input_pos = torch .tensor ([0 ], device = emb .device )
169- mask = self .causal_mask [None , None , input_pos ]
168+ mask = self .causal_mask [input_pos ]. unsqueeze ( 0 ). unsqueeze ( 0 )
170169 freqs_cis = self .freqs_cis [input_pos ]
171170
172171 x = emb .unsqueeze (dim = 1 )
@@ -182,7 +181,7 @@ def forward(
182181 ):
183182 assert self .freqs_cis is not None , "Caches must be initialized first"
184183
185- mask = self .causal_mask [None , None , input_pos ]
184+ mask = self .causal_mask [input_pos ]. unsqueeze ( 0 ). unsqueeze ( 0 )
186185
187186 if pad_idxs is not None :
188187 mask = mask & ~ (pad_idxs .unsqueeze (1 ).unsqueeze (1 ))
You can’t perform that action at this time.
0 commit comments