@@ -319,7 +319,9 @@ def _compute_video_freqs(
319319 self , frame : int , height : int , width : int , idx : int = 0 , device : torch .device = None
320320 ) -> torch .Tensor :
321321 seq_lens = frame * height * width
322- pos_freqs , neg_freqs = self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
322+ pos_freqs , neg_freqs = (
323+ self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
324+ )
323325
324326 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
325327 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -445,7 +447,9 @@ def forward(
445447 @lru_cache_unless_export (maxsize = None )
446448 def _compute_video_freqs (self , frame , height , width , idx = 0 , device : torch .device = None ):
447449 seq_lens = frame * height * width
448- pos_freqs , neg_freqs = self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
450+ pos_freqs , neg_freqs = (
451+ self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
452+ )
449453
450454 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
451455 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -466,7 +470,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
466470 @lru_cache_unless_export (maxsize = None )
467471 def _compute_condition_freqs (self , frame , height , width , device : torch .device = None ):
468472 seq_lens = frame * height * width
469- pos_freqs , neg_freqs = self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
473+ pos_freqs , neg_freqs = (
474+ self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
475+ )
470476
471477 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
472478 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
0 commit comments