Skip to content

Commit a45f22b

Browse files
committed
Fix type error
1 parent e1ab751 commit a45f22b

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

rectools/models/nn/item_net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
486486
@property
487487
def out_dim(self) -> int:
488488
"""Return item net constructor output dimension."""
489-
return self.item_net_blocks[0].out_dim
489+
return self.item_net_blocks[0].out_dim # type: ignore[return-value]

rectools/models/nn/transformers/lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _calc_bce_loss(cls, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor)
193193
return loss
194194

195195
def _calc_gbce_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
196-
n_actual_items = self.torch_model.item_model.n_items - len(self.item_extra_tokens)
196+
n_actual_items = tp.cast(int, self.torch_model.item_model.n_items) - len(self.item_extra_tokens)
197197
logits = self._get_reduced_overconfidence_logits(logits, n_actual_items)
198198
loss = self._calc_bce_loss(logits, y, w)
199199
return loss

0 commit comments

Comments
 (0)