@@ -138,10 +138,20 @@ class BatchCollator:
138138 For sequence data, it will sample subsequences with lengths up to max_sequence_window.
139139 """
140140
141- def __init__ (self , is_sequential : bool , max_sequence_window : int | None , device : torch .device ):
141+ def __init__ (
142+ self ,
143+ is_sequential : bool ,
144+ max_sequence_window : int | None ,
145+ device : torch .device ,
146+ * ,
147+ use_nested_ctxseq : bool = True ,
148+ ):
142149 self .is_sequential = is_sequential
143150 self .max_sequence_window = max_sequence_window
144151 self .device = device
152+ # Opacus per-sample gradients do not support NestedTensor on CPU/CUDA, so DP training passes
153+ # False here to collate CTXSEQ as -1-padded dense tensors (see test_tabular_sequential DP path).
154+ self .use_nested_ctxseq = use_nested_ctxseq
145155
146156 def __call__ (self , batch : list [dict ]) -> dict [str , torch .Tensor ]:
147157 batch = pd .DataFrame (batch )
@@ -177,15 +187,26 @@ def _convert_to_tensors(self, batch: pd.DataFrame) -> dict[str, torch.Tensor]:
177187 dim = - 1 ,
178188 )
179189 elif column .startswith (CTXSEQ ):
180- # construct row tensors and convert the list to nested column tensor
181- tensors [column ] = torch .unsqueeze (
182- torch .nested .as_nested_tensor (
183- [torch .tensor (row , dtype = torch .int64 , device = self .device ) for row in batch [column ]],
184- dtype = torch .int64 ,
185- device = self .device ,
186- ),
187- dim = - 1 ,
188- )
190+ if self .use_nested_ctxseq :
191+ # construct row tensors and convert the list to nested column tensor
192+ tensors [column ] = torch .unsqueeze (
193+ torch .nested .as_nested_tensor (
194+ [torch .tensor (row , dtype = torch .int64 , device = self .device ) for row in batch [column ]],
195+ dtype = torch .int64 ,
196+ device = self .device ,
197+ ),
198+ dim = - 1 ,
199+ )
200+ else :
201+ # pad variable-length rows to the longest row in the batch; -1 marks padding (matches SequentialContextEmbedders)
202+ tensors [column ] = torch .unsqueeze (
203+ torch .tensor (
204+ np .array (list (zip_longest (* batch [column ], fillvalue = - 1 ))).T ,
205+ dtype = torch .int64 ,
206+ device = self .device ,
207+ ),
208+ dim = - 1 ,
209+ )
189210 return tensors
190211
191212 @staticmethod
@@ -544,7 +565,10 @@ def train(
544565
545566 # and see if it's possible to make it compatible with DP
546567 batch_collator = BatchCollator (
547- is_sequential = is_sequential , max_sequence_window = max_sequence_window , device = device
568+ is_sequential = is_sequential ,
569+ max_sequence_window = max_sequence_window ,
570+ device = device ,
571+ use_nested_ctxseq = not with_dp ,
548572 )
549573 disable_progress_bar ()
550574 trn_dataset = load_dataset ("parquet" , data_files = [str (p ) for p in workspace .encoded_data_trn .fetch_all ()])[
0 commit comments