Skip to content

Commit 908a217

Browse files
committed
updated
1 parent 78d9505 commit 908a217

1 file changed

Lines changed: 5 additions & 17 deletions

File tree

eval_encoder/attentive_probe.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def parse_args() -> argparse.Namespace:
3434

3535
# Model
3636
parser.add_argument("--model_family", default="llava_vit_sampling")
37-
parser.add_argument("--model_name", default="llava_vit_base_ln")
37+
parser.add_argument("--model_name", default="ov_encoder_large")
3838
parser.add_argument("--model_weight", default="NULL")
3939
parser.add_argument("--num_frames", type=int, default=8)
4040
parser.add_argument("--num_tokens", type=int, default=1568)
@@ -136,7 +136,7 @@ def get_feature(
136136
args: 参数配置
137137
videos: 视频数据 [B, C, T, H, W] 或图片数据 [B, C, H, W]
138138
model: 模型
139-
frame_indices: 视频帧索引 [B, seq_len],用于 llava_vit_sampling
139+
frame_indices: 视频帧索引 [B, seq_len],用于 ov_encoder_large sampling
140140
total_frames: 每个视频的总帧数 [B]
141141
is_training: 是否为训练模式
142142
"""
@@ -401,23 +401,11 @@ def evaluate(
401401

402402
def get_model(args: argparse.Namespace) -> nn.Module:
403403

404-
if args.model_name == "hf_llava_vit_large_ln_auto":
404+
if args.model_name == "ov_encoder_large":
405405
model = AutoModel.from_pretrained(
406-
"/video_vit/xiangan/LLaVA-ViT/onevision-encoder-large",
406+
"lmms-lab/onevision-encoder-large",
407407
trust_remote_code=True,
408-
attn_implementation="flash_attention_2"
409-
)
410-
model = torch.compile(model)
411-
return model
412-
413-
if args.model_name == "hf_llava_vit_large_ln_remote":
414-
model = AutoModel.from_pretrained("lmms-lab/llava-vit-large-patch14", trust_remote_code=True)
415-
model = torch.compile(model)
416-
return model
417-
418-
if args.model_name == "hf_llava_vit_large_ln":
419-
from model_factory.vit_ov_encoder import LlavaViTModel
420-
model = LlavaViTModel.from_pretrained(args.model_weight, dtype=torch.bfloat16)
408+
attn_implementation="flash_attention_2")
421409
model = torch.compile(model)
422410
return model
423411

0 commit comments

Comments
 (0)