Skip to content

Commit c5b78b6

Browse files
Copilot and anxiangsir authored
Translate Chinese comments to English in eval_encoder directory (#35)
* Initial plan * Translate Chinese comments to English in eval_encoder/attentive_probe.py Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent acc4eb6 commit c5b78b6

1 file changed

Lines changed: 40 additions & 40 deletions

File tree

eval_encoder/attentive_probe.py

Lines changed: 40 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
4242
parser.add_argument("--tubelet_size", type=int, default=1)
4343
parser.add_argument("--embedding_size", type=int, default=768)
4444
parser.add_argument("--num_classes", type=int, default=0)
45-
# ===> 新增:目标帧数参数 <===
45+
# ===> New: target frame number parameter <===
4646
parser.add_argument("--target_frames", type=int, default=64,
4747
help="Target number of frames to interpolate to (default: 64)")
4848

@@ -61,7 +61,7 @@ def parse_args() -> argparse.Namespace:
6161
# Dataloader
6262
parser.add_argument("--dali_num_threads", type=int, default=2)
6363
parser.add_argument("--dali_py_num_workers", type=int, default=4)
64-
# ===> 新增 decord 线程数参数 <===
64+
# ===> New: decord thread number parameter <===
6565
parser.add_argument("--decord_num_threads", type=int, default=2,
6666
help="Number of threads for decord video reader.")
6767
parser.add_argument("--short_side_size", type=int, default=256)
@@ -73,13 +73,13 @@ def parse_args() -> argparse.Namespace:
7373
parser.add_argument("--seed", type=int, default=1)
7474
parser.add_argument("--save_report", default="fewshot_video_report/ActionRecognition")
7575

76-
# 分布式相关参数
76+
# Distributed training parameters
7777
parser.add_argument("--rank", type=int, default=0)
7878
parser.add_argument("--local_rank", type=int, default=0)
7979
parser.add_argument("--world_size", type=int, default=1)
8080
parser.add_argument("--global_rank", type=int, default=0)
8181

82-
# 新增:时序空间crop参数(默认与dali默认一致)
82+
# New: temporal and spatial crop parameters (defaults match DALI defaults)
8383
parser.add_argument("--num_temporal_crops", type=int, default=1, help="Number of temporal crops for evaluation")
8484
parser.add_argument("--num_spatial_crops", type=int, default=1, help="Number of spatial crops for evaluation")
8585

@@ -90,32 +90,32 @@ def parse_args() -> argparse.Namespace:
9090

9191
def interpolate_frame_indices(frame_indices: torch.Tensor, total_frames: torch.Tensor, target_frames: int = 64) -> torch.Tensor:
9292
"""
93-
将帧索引从原始视频帧数插值到目标帧数
93+
Interpolate frame indices from original video frame count to target frame count
9494
9595
Args:
96-
frame_indices: [B, seq_len] 原始帧索引
97-
total_frames: [B] 每个视频的总帧数
98-
target_frames: 目标帧数 (默认 64)
96+
frame_indices: [B, seq_len] original frame indices
97+
total_frames: [B] total frames per video
98+
target_frames: target frame count (default 64)
9999
100100
Returns:
101-
interpolated_indices: [B, seq_len] 插值后的帧索引,范围在 [0, target_frames-1]
101+
interpolated_indices: [B, seq_len] interpolated frame indices, range [0, target_frames-1]
102102
"""
103103
bs, seq_len = frame_indices.shape
104104
device = frame_indices.device
105105

106-
# total_frames 转换为浮点数以进行插值计算
106+
# Convert total_frames to float for interpolation calculation
107107
total_frames_float = total_frames.float().view(bs, 1) # [B, 1]
108108
frame_indices_float = frame_indices.float() # [B, seq_len]
109109

110-
# 插值公式: new_idx = (old_idx / (total_frames - 1)) * (target_frames - 1)
111-
# 处理 total_frames = 1 的情况
110+
# Interpolation formula: new_idx = (old_idx / (total_frames - 1)) * (target_frames - 1)
111+
# Handle total_frames = 1 case
112112
total_frames_safe = torch.clamp(total_frames_float - 1, min=1.0)
113113
interpolated_indices = (frame_indices_float / total_frames_safe) * (target_frames - 1)
114114

115-
# 四舍五入并转换为整数
115+
# Round and convert to integer
116116
interpolated_indices = torch.round(interpolated_indices).long()
117117

118-
# 确保索引在有效范围内
118+
# Ensure indices are in valid range
119119
interpolated_indices = torch.clamp(interpolated_indices, 0, target_frames - 1)
120120

121121
return interpolated_indices
@@ -130,19 +130,19 @@ def get_feature(
130130
is_training: bool = False
131131
) -> torch.Tensor:
132132
"""
133-
获取特征,支持视频及图片输入。
133+
Extract features, supporting both video and image input.
134134
135135
Args:
136-
args: 参数配置
137-
videos: 视频数据 [B, C, T, H, W] 或图片数据 [B, C, H, W]
138-
model: 模型
139-
frame_indices: 视频帧索引 [B, seq_len],用于 ov_encoder_large sampling
140-
total_frames: 每个视频的总帧数 [B]
141-
is_training: 是否为训练模式
136+
args: argument configuration
137+
videos: video data [B, C, T, H, W] or image data [B, C, H, W]
138+
model: model
139+
frame_indices: video frame indices [B, seq_len], used for ov_encoder_large sampling
140+
total_frames: total frames per video [B]
141+
is_training: whether in training mode
142142
"""
143143
def video_to_images(videos: torch.Tensor) -> torch.Tensor:
144144
"""
145-
将视频 [B, C, T, H, W] 展开为图片序列 [B*T, C, H, W]
145+
Unfold video [B, C, T, H, W] into image sequence [B*T, C, H, W]
146146
"""
147147
B, C, T, H, W = videos.shape
148148
images = videos.permute(0, 2, 1, 3, 4).reshape(-1, C, H, W) # [B*T, C, H, W]
@@ -159,25 +159,25 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
159159
"aimv2"
160160
]
161161
if args.model_family in list_vit_single_image:
162-
# ===> 专门图片分支 <===
162+
# ===> Image-specific branch <===
163163
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
164164
with torch.no_grad():
165-
# 如果是视频输入,将其转化为图片
165+
# If video input, convert to images
166166
B, C, T, H, W = videos.shape
167-
if videos.dim() == 5: # 视频分支 [B, C, T, H, W]
167+
if videos.dim() == 5: # Video branch [B, C, T, H, W]
168168
videos = video_to_images(videos)
169169

170-
if videos.dim() == 4: # 检测为图片分支 [B, C, H, W]
170+
if videos.dim() == 4: # Detected as image branch [B, C, H, W]
171171
hidden_states = model(videos)
172172
if isinstance(hidden_states, dict) and "visible_embeddings" in hidden_states:
173173
hidden_states = hidden_states["visible_embeddings"]
174174

175175
# hidden_states = hidden_states.view(B, -1, hidden_states.size(-1)) # [B, seq_len, hidden_size]
176176
hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1)) # [B, seq_len, hidden_size]
177-
# ===> 新增:sin/cos 时间位置编码(2行代码)<===
177+
# ===> New: sin/cos temporal positional encoding (2 lines of code) <===
178178
pos = torch.arange(T, device=videos.device).unsqueeze(1) * torch.exp(torch.arange(0, args.embedding_size, 2, device=videos.device) * (-math.log(10000.0) / args.embedding_size)) # [T, D/2]
179179
temporal_pos = torch.stack([torch.sin(pos), torch.cos(pos)], dim=2).flatten(1)[:, :args.embedding_size] # [T, D]
180-
hidden_states = hidden_states.view(B, T, -1, args.embedding_size) + temporal_pos.unsqueeze(0).unsqueeze(2) # 加到每帧的 tokens
180+
hidden_states = hidden_states.view(B, T, -1, args.embedding_size) + temporal_pos.unsqueeze(0).unsqueeze(2) # Add to tokens of each frame
181181
hidden_states = hidden_states.view(B, -1, args.embedding_size) # [B, T*tokens_per_frame, D]
182182
return hidden_states
183183
else:
@@ -188,17 +188,17 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
188188
with torch.no_grad():
189189
bs, C, T, H, W = videos.shape
190190
device = videos.device
191-
frame_tokens = args.frames_token_num # 每帧的 token 数量
192-
target_frames = args.target_frames # 目标帧数,默认 64
191+
frame_tokens = args.frames_token_num # Number of tokens per frame
192+
target_frames = args.target_frames # Target frame count, default 64
193193

194194
if frame_indices is not None and total_frames is not None:
195-
# ===> 插值帧索引到 target_frames <===
195+
# ===> Interpolate frame indices to target_frames <===
196196
interpolated_indices = interpolate_frame_indices(
197197
frame_indices,
198198
total_frames.view(-1),
199199
target_frames
200200
)
201-
# ===> 计算 visible_index (基于 target_frames) <===
201+
# ===> Calculate visible_index (based on target_frames) <===
202202
per = torch.arange(frame_tokens, device=device)
203203
visible_index = (interpolated_indices.unsqueeze(-1) * frame_tokens + per).reshape(bs, -1)
204204
visible_index = visible_index.clamp_max(target_frames * frame_tokens - 1)
@@ -277,7 +277,7 @@ def train_one_experiment(
277277
head.train()
278278
train_metrics.reset()
279279
for i, batch in enumerate(loader_train):
280-
# ===> 从字典中解包数据(包括 total_frames <===
280+
# ===> Unpack data from batch dictionary (including total_frames) <===
281281
videos = batch["videos"].to(device, non_blocking=True)
282282
labels = batch["labels"].view(-1).to(device, non_blocking=True)
283283
indices = batch["indices"].to(device, non_blocking=True) # [B, seq_len]
@@ -363,9 +363,9 @@ def evaluate(
363363
total_frames = batch["total_frames"].to(device, non_blocking=True)
364364

365365
B = videos.shape[0] // num_crops
366-
# reshape为 [B, num_crops, ...]
366+
# Reshape to [B, num_crops, ...]
367367
videos = videos.view(B, num_crops, *videos.shape[1:])
368-
labels = labels.view(B, num_crops)[:, 0] # [B],同一个视频的labels一样
368+
labels = labels.view(B, num_crops)[:, 0] # [B], labels are the same for the same video
369369
indices = indices.view(B, num_crops, *indices.shape[1:])
370370
total_frames = total_frames.view(B, num_crops)[:, 0]
371371

@@ -377,9 +377,9 @@ def evaluate(
377377
logits_per_crop.append(logits)
378378
# [num_crops, B, num_classes] -> [B, num_crops, num_classes]
379379
logits_all = torch.stack(logits_per_crop, dim=1)
380-
# crop 维求平均(可 softmax 再平均/直接logit平均)
380+
# Average over crop dimension (can use softmax then average / directly average logits)
381381
logits_mean = logits_all.mean(dim=1) # [B, num_classes]
382-
# 收集
382+
# Collect results
383383
all_logits.append(logits_mean)
384384
all_targets.append(labels)
385385

@@ -501,7 +501,7 @@ def main() -> None:
501501
dali_py_num_workers=args.dali_py_num_workers,
502502
decord_num_threads=args.decord_num_threads,
503503
seed=args.seed
504-
# 训练不需要传入 num_temporal_crops/num_spatial_crops(仅eval使用)
504+
# Training does not need num_temporal_crops/num_spatial_crops (only used for evaluation)
505505
)
506506
val_loader = get_dali_dataloader(
507507
data_root_path=args.val_data_root_path,
@@ -517,8 +517,8 @@ def main() -> None:
517517
dali_py_num_workers=args.dali_py_num_workers,
518518
decord_num_threads=args.decord_num_threads,
519519
seed=1024,
520-
# num_temporal_crops=args.num_temporal_crops, # 新增!
521-
# num_spatial_crops=args.num_spatial_crops # 新增!
520+
# num_temporal_crops=args.num_temporal_crops, # New!
521+
# num_spatial_crops=args.num_spatial_crops # New!
522522
)
523523
if args.rank == 0:
524524
print("Data loaders ready.")

0 commit comments

Comments (0)