Skip to content

Commit e0afbb6

Browse files
cursoragentmplatzer
andcommitted
fix(tabular): avoid NestedTensor CTXSEQ batches when using Opacus DP
Opacus 1.x per-sample gradient hooks hit NotImplementedError on NestedTensorCPU (aten::new_empty). For DP training, collate CTXSEQ as padded dense tensors with -1 padding; SequentialContextEmbedders already masks -1 and maps to embedding index 0. Non-DP sequential training keeps nested CTXSEQ collate for unchanged behavior. Co-authored-by: Michi Platzer <michael.platzer@gmail.com>
1 parent acc9761 commit e0afbb6

2 files changed

Lines changed: 37 additions & 12 deletions

File tree

mostlyai/engine/_tabular/argn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ def forward(self, x) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
289289
mask = None
290290
for sub_col in self.cardinalities:
291291
xs = torch.as_tensor(x[sub_col], device=self.device)
292-
xs = torch.nested.to_padded_tensor(xs, padding=-1)
292+
if xs.is_nested:
293+
xs = torch.nested.to_padded_tensor(xs, padding=-1)
293294
mask = (xs != -1).squeeze(-1)
294295
xs = torch.where(xs == -1, torch.tensor(0), xs)
295296
xs = self.get(sub_col)(xs)

mostlyai/engine/_tabular/training.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,20 @@ class BatchCollator:
138138
For sequence data, it will sample subsequences with lengths up to max_sequence_window.
139139
"""
140140

141-
def __init__(self, is_sequential: bool, max_sequence_window: int | None, device: torch.device):
141+
def __init__(
142+
self,
143+
is_sequential: bool,
144+
max_sequence_window: int | None,
145+
device: torch.device,
146+
*,
147+
use_nested_ctxseq: bool = True,
148+
):
142149
self.is_sequential = is_sequential
143150
self.max_sequence_window = max_sequence_window
144151
self.device = device
152+
# Opacus per-sample gradients do not support NestedTensor on CPU/CUDA; use padded
153+
# dense tensors for CTXSEQ when training with DP (see test_tabular_sequential DP path).
154+
self.use_nested_ctxseq = use_nested_ctxseq
145155

146156
def __call__(self, batch: list[dict]) -> dict[str, torch.Tensor]:
147157
batch = pd.DataFrame(batch)
@@ -177,15 +187,26 @@ def _convert_to_tensors(self, batch: pd.DataFrame) -> dict[str, torch.Tensor]:
177187
dim=-1,
178188
)
179189
elif column.startswith(CTXSEQ):
180-
# construct row tensors and convert the list to nested column tensor
181-
tensors[column] = torch.unsqueeze(
182-
torch.nested.as_nested_tensor(
183-
[torch.tensor(row, dtype=torch.int64, device=self.device) for row in batch[column]],
184-
dtype=torch.int64,
185-
device=self.device,
186-
),
187-
dim=-1,
188-
)
190+
if self.use_nested_ctxseq:
191+
# construct row tensors and convert the list to nested column tensor
192+
tensors[column] = torch.unsqueeze(
193+
torch.nested.as_nested_tensor(
194+
[torch.tensor(row, dtype=torch.int64, device=self.device) for row in batch[column]],
195+
dtype=torch.int64,
196+
device=self.device,
197+
),
198+
dim=-1,
199+
)
200+
else:
201+
# padded batch (variable-length rows); -1 marks padding (matches SequentialContextEmbedders)
202+
tensors[column] = torch.unsqueeze(
203+
torch.tensor(
204+
np.array(list(zip_longest(*batch[column], fillvalue=-1))).T,
205+
dtype=torch.int64,
206+
device=self.device,
207+
),
208+
dim=-1,
209+
)
189210
return tensors
190211

191212
@staticmethod
@@ -544,7 +565,10 @@ def train(
544565

545566
# and see if it's possible to make it compatible with DP
546567
batch_collator = BatchCollator(
547-
is_sequential=is_sequential, max_sequence_window=max_sequence_window, device=device
568+
is_sequential=is_sequential,
569+
max_sequence_window=max_sequence_window,
570+
device=device,
571+
use_nested_ctxseq=not with_dp,
548572
)
549573
disable_progress_bar()
550574
trn_dataset = load_dataset("parquet", data_files=[str(p) for p in workspace.encoded_data_trn.fetch_all()])[

0 commit comments

Comments
 (0)