@@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
4242 parser .add_argument ("--tubelet_size" , type = int , default = 1 )
4343 parser .add_argument ("--embedding_size" , type = int , default = 768 )
4444 parser .add_argument ("--num_classes" , type = int , default = 0 )
45- # ===> 新增:目标帧数参数 <===
45+ # ===> New: target frame number parameter <===
4646 parser .add_argument ("--target_frames" , type = int , default = 64 ,
4747 help = "Target number of frames to interpolate to (default: 64)" )
4848
@@ -61,7 +61,7 @@ def parse_args() -> argparse.Namespace:
6161 # Dataloader
6262 parser .add_argument ("--dali_num_threads" , type = int , default = 2 )
6363 parser .add_argument ("--dali_py_num_workers" , type = int , default = 4 )
64- # ===> 新增 decord 线程数参数 <===
64+ # ===> New: decord thread number parameter <===
6565 parser .add_argument ("--decord_num_threads" , type = int , default = 2 ,
6666 help = "Number of threads for decord video reader." )
6767 parser .add_argument ("--short_side_size" , type = int , default = 256 )
@@ -73,13 +73,13 @@ def parse_args() -> argparse.Namespace:
7373 parser .add_argument ("--seed" , type = int , default = 1 )
7474 parser .add_argument ("--save_report" , default = "fewshot_video_report/ActionRecognition" )
7575
76- # 分布式相关参数
76+ # Distributed training parameters
7777 parser .add_argument ("--rank" , type = int , default = 0 )
7878 parser .add_argument ("--local_rank" , type = int , default = 0 )
7979 parser .add_argument ("--world_size" , type = int , default = 1 )
8080 parser .add_argument ("--global_rank" , type = int , default = 0 )
8181
82- # 新增:时序空间crop参数(默认与dali默认一致)
82+ # New: temporal and spatial crop parameters (defaults match DALI defaults)
8383 parser .add_argument ("--num_temporal_crops" , type = int , default = 1 , help = "Number of temporal crops for evaluation" )
8484 parser .add_argument ("--num_spatial_crops" , type = int , default = 1 , help = "Number of spatial crops for evaluation" )
8585
@@ -90,32 +90,32 @@ def parse_args() -> argparse.Namespace:
9090
def interpolate_frame_indices(frame_indices: torch.Tensor, total_frames: torch.Tensor, target_frames: int = 64) -> torch.Tensor:
    """Remap per-video frame indices onto a fixed-length target timeline.

    Each video may have a different total frame count; this linearly rescales
    sampled indices from ``[0, total_frames - 1]`` to ``[0, target_frames - 1]``
    so downstream positional lookups share one timeline.

    Args:
        frame_indices: ``[B, seq_len]`` integer frame indices sampled from the
            original videos.
        total_frames: ``[B]`` total frame count of each video.
        target_frames: length of the target timeline (default 64).

    Returns:
        ``[B, seq_len]`` long tensor of remapped indices, clamped to
        ``[0, target_frames - 1]``.
    """
    bs = frame_indices.shape[0]

    # Work in float so the ratio old_idx / (total - 1) survives before rounding.
    total_frames_float = total_frames.float().view(bs, 1)  # [B, 1] broadcasts over seq_len
    frame_indices_float = frame_indices.float()            # [B, seq_len]

    # new_idx = old_idx / (total_frames - 1) * (target_frames - 1).
    # clamp(min=1) guards the degenerate total_frames == 1 case (avoids div-by-zero;
    # the single index 0 then maps to 0).
    total_frames_safe = torch.clamp(total_frames_float - 1, min=1.0)
    interpolated_indices = (frame_indices_float / total_frames_safe) * (target_frames - 1)

    # Round to the nearest integer index and keep it inside the valid range.
    interpolated_indices = torch.round(interpolated_indices).long()
    return torch.clamp(interpolated_indices, 0, target_frames - 1)
@@ -130,19 +130,19 @@ def get_feature(
130130 is_training : bool = False
131131) -> torch .Tensor :
132132 """
133- 获取特征,支持视频及图片输入。
133+ Extract features, supporting both video and image input.
134134
135135 Args:
136- args: 参数配置
137- videos: 视频数据 [B, C, T, H, W] 或图片数据 [B, C, H, W]
138- model: 模型
139- frame_indices: 视频帧索引 [B, seq_len],用于 ov_encoder_large sampling
140- total_frames: 每个视频的总帧数 [B]
141- is_training: 是否为训练模式
136+ args: argument configuration
137+ videos: video data [B, C, T, H, W] or image data [B, C, H, W]
138+ model: model
139+ frame_indices: video frame indices [B, seq_len], used for ov_encoder_large sampling
140+ total_frames: total frames per video [B]
141+ is_training: whether in training mode
142142 """
143143 def video_to_images (videos : torch .Tensor ) -> torch .Tensor :
144144 """
145- 将视频 [B, C, T, H, W] 展开为图片序列 [B*T, C, H, W]
145+ Unfold video [B, C, T, H, W] into image sequence [B*T, C, H, W]
146146 """
147147 B , C , T , H , W = videos .shape
148148 images = videos .permute (0 , 2 , 1 , 3 , 4 ).reshape (- 1 , C , H , W ) # [B*T, C, H, W]
@@ -159,25 +159,25 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
159159 "aimv2"
160160 ]
161161 if args .model_family in list_vit_single_image :
162- # ===> 专门图片分支 <===
162+ # ===> Image-specific branch <===
163163 with torch .cuda .amp .autocast (dtype = torch .bfloat16 ):
164164 with torch .no_grad ():
165- # 如果是视频输入,将其转化为图片
165+ # If video input, convert to images
166166 B , C , T , H , W = videos .shape
167- if videos .dim () == 5 : # 视频分支 [B, C, T, H, W]
167+ if videos .dim () == 5 : # Video branch [B, C, T, H, W]
168168 videos = video_to_images (videos )
169169
170- if videos .dim () == 4 : # 检测为图片分支 [B, C, H, W]
170+ if videos .dim () == 4 : # Detected as image branch [B, C, H, W]
171171 hidden_states = model (videos )
172172 if isinstance (hidden_states , dict ) and "visible_embeddings" in hidden_states :
173173 hidden_states = hidden_states ["visible_embeddings" ]
174174
175175 # hidden_states = hidden_states.view(B, -1, hidden_states.size(-1)) # [B, seq_len, hidden_size]
176176 hidden_states = hidden_states .reshape (B , - 1 , hidden_states .size (- 1 )) # [B, seq_len, hidden_size]
177- # ===> 新增: sin/cos 时间位置编码(2行代码) <===
177+ # ===> New: sin/cos temporal positional encoding (2 lines of code) <===
178178 pos = torch .arange (T , device = videos .device ).unsqueeze (1 ) * torch .exp (torch .arange (0 , args .embedding_size , 2 , device = videos .device ) * (- math .log (10000.0 ) / args .embedding_size )) # [T, D/2]
179179 temporal_pos = torch .stack ([torch .sin (pos ), torch .cos (pos )], dim = 2 ).flatten (1 )[:, :args .embedding_size ] # [T, D]
180- hidden_states = hidden_states .view (B , T , - 1 , args .embedding_size ) + temporal_pos .unsqueeze (0 ).unsqueeze (2 ) # 加到每帧的 tokens 上
180+ hidden_states = hidden_states .view (B , T , - 1 , args .embedding_size ) + temporal_pos .unsqueeze (0 ).unsqueeze (2 ) # Add to tokens of each frame
181181 hidden_states = hidden_states .view (B , - 1 , args .embedding_size ) # [B, T*tokens_per_frame, D]
182182 return hidden_states
183183 else :
@@ -188,17 +188,17 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
188188 with torch .no_grad ():
189189 bs , C , T , H , W = videos .shape
190190 device = videos .device
191- frame_tokens = args .frames_token_num # 每帧的 token 数量
192- target_frames = args .target_frames # 目标帧数,默认 64
191+ frame_tokens = args .frames_token_num # Number of tokens per frame
192+ target_frames = args .target_frames # Target frame count, default 64
193193
194194 if frame_indices is not None and total_frames is not None :
195- # ===> 插值帧索引到 target_frames <===
195+ # ===> Interpolate frame indices to target_frames <===
196196 interpolated_indices = interpolate_frame_indices (
197197 frame_indices ,
198198 total_frames .view (- 1 ),
199199 target_frames
200200 )
201- # ===> 计算 visible_index (基于 target_frames) <===
201+ # ===> Calculate visible_index (based on target_frames) <===
202202 per = torch .arange (frame_tokens , device = device )
203203 visible_index = (interpolated_indices .unsqueeze (- 1 ) * frame_tokens + per ).reshape (bs , - 1 )
204204 visible_index = visible_index .clamp_max (target_frames * frame_tokens - 1 )
@@ -277,7 +277,7 @@ def train_one_experiment(
277277 head .train ()
278278 train_metrics .reset ()
279279 for i , batch in enumerate (loader_train ):
280- # ===> 从字典中解包数据(包括 total_frames) <===
280+ # ===> Unpack data from batch dictionary (including total_frames) <===
281281 videos = batch ["videos" ].to (device , non_blocking = True )
282282 labels = batch ["labels" ].view (- 1 ).to (device , non_blocking = True )
283283 indices = batch ["indices" ].to (device , non_blocking = True ) # [B, seq_len]
@@ -363,9 +363,9 @@ def evaluate(
363363 total_frames = batch ["total_frames" ].to (device , non_blocking = True )
364364
365365 B = videos .shape [0 ] // num_crops
366- # reshape为 [B, num_crops, ...]
366+ # Reshape to [B, num_crops, ...]
367367 videos = videos .view (B , num_crops , * videos .shape [1 :])
368- labels = labels .view (B , num_crops )[:, 0 ] # [B],同一个视频的labels一样
368+ labels = labels .view (B , num_crops )[:, 0 ] # [B], labels are the same for the same video
369369 indices = indices .view (B , num_crops , * indices .shape [1 :])
370370 total_frames = total_frames .view (B , num_crops )[:, 0 ]
371371
@@ -377,9 +377,9 @@ def evaluate(
377377 logits_per_crop .append (logits )
378378 # [num_crops, B, num_classes] -> [B, num_crops, num_classes]
379379 logits_all = torch .stack (logits_per_crop , dim = 1 )
380- # 对 crop 维求平均(可 softmax 再平均/直接logit平均)
380+ # Average over crop dimension (can use softmax then average / directly average logits)
381381 logits_mean = logits_all .mean (dim = 1 ) # [B, num_classes]
382- # 收集
382+ # Collect results
383383 all_logits .append (logits_mean )
384384 all_targets .append (labels )
385385
@@ -501,7 +501,7 @@ def main() -> None:
501501 dali_py_num_workers = args .dali_py_num_workers ,
502502 decord_num_threads = args .decord_num_threads ,
503503 seed = args .seed
504- # 训练不需要传入 num_temporal_crops/num_spatial_crops(仅eval使用)
504+ # Training does not need num_temporal_crops/num_spatial_crops (only used for evaluation)
505505 )
506506 val_loader = get_dali_dataloader (
507507 data_root_path = args .val_data_root_path ,
@@ -517,8 +517,8 @@ def main() -> None:
517517 dali_py_num_workers = args .dali_py_num_workers ,
518518 decord_num_threads = args .decord_num_threads ,
519519 seed = 1024 ,
520- # num_temporal_crops=args.num_temporal_crops, # 新增!
521- # num_spatial_crops=args.num_spatial_crops # 新增!
520+ # num_temporal_crops=args.num_temporal_crops, # New!
521+ # num_spatial_crops=args.num_spatial_crops # New!
522522 )
523523 if args .rank == 0 :
524524 print ("Data loaders ready." )
0 commit comments