Skip to content

Commit 54bb866

Browse files
author
TOPAPEC
committed
Fix remaining review items: make_ffn docstring, N812 violation, padding test
- Add all parameters and Returns to make_ffn docstring (feldlime) - Replace import torch.nn.functional as F with direct imports in lightning.py to eliminate N812 violation; remove N812 from setup.cfg - Add padding-zeroed assertion to test_determinism_and_padding_masking so the test matches its docstring claim (copilot)
1 parent 3aaef75 commit 54bb866

4 files changed

Lines changed: 19 additions & 6 deletions

File tree

rectools/fast_transformers/unisrec/lightning.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytorch_lightning as pl
77
import torch
8-
import torch.nn.functional as F
8+
from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy
99
from torch.optim.lr_scheduler import LambdaLR
1010

1111
from .net import UniSRecNet
@@ -114,7 +114,7 @@ def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torc
114114

115115
targets = labels.clone()
116116
targets[targets == 0] = -100
117-
return F.cross_entropy(
117+
return cross_entropy(
118118
logits.view(-1, logits.size(-1)),
119119
targets.view(-1),
120120
ignore_index=-100,
@@ -125,7 +125,7 @@ def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> tor
125125
logits = logits.clone()
126126
logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
127127
targets = mask.long() # 1 where non-padding, 0 where padding
128-
return F.cross_entropy(
128+
return cross_entropy(
129129
logits.view(-1, logits.size(-1)),
130130
targets.view(-1),
131131
ignore_index=0,
@@ -134,7 +134,7 @@ def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> tor
134134
def _bce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
135135
target = torch.zeros_like(logits)
136136
target[:, :, 0] = 1.0
137-
loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
137+
loss = binary_cross_entropy_with_logits(logits, target, reduction="none")
138138
loss = loss.mean(-1) * mask
139139
return loss.sum() / mask.sum().clamp(min=1)
140140

rectools/fast_transformers/unisrec/net.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,19 @@ def make_ffn(n_factors: int, ffn_type: str, expansion: int, dropout: float) -> n
4444
4545
Parameters
4646
----------
47+
n_factors : int
48+
Input and output dimension.
4749
ffn_type : ``"conv1d"`` | ``"linear_gelu"`` | ``"linear_relu"``
48-
expansion : hidden-dim multiplier (e.g. 1 or 4).
50+
Type of feed-forward block.
51+
expansion : int
52+
Hidden-dimension multiplier (e.g. 1 or 4).
53+
dropout : float
54+
Dropout rate applied inside the block.
55+
56+
Returns
57+
-------
58+
nn.Module
59+
A feed-forward network module.
4960
"""
5061
if ffn_type == "conv1d":
5162
return FeedForwardConv1d(n_factors, dropout)

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ per-file-ignores =
5050
rectools/dataset/torch_datasets.py: D101,D102
5151
rectools/models/implicit_als.py: N806
5252
rectools/fast_transformers/net.py: N806
53-
rectools/fast_transformers/unisrec/lightning.py: N812
5453
rectools/fast_transformers/unisrec/net.py: N806
5554

5655
[mypy]

tests/fast_transformers/test_unisrec_net.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,7 @@ def test_determinism_and_padding_masking(self, net: UniSRecNet) -> None:
9292
with torch.no_grad():
9393
e_a = net.encode_last(x_a)
9494
e_b = net.encode_last(x_b)
95+
h_a = net(x_a)
9596
torch.testing.assert_close(e_a, e_b)
97+
# Padding positions (first 3 columns) should be zeroed in full output
98+
torch.testing.assert_close(h_a[:, :3, :], torch.zeros_like(h_a[:, :3, :]))

0 commit comments

Comments
 (0)