@@ -34,7 +34,7 @@ def parse_args() -> argparse.Namespace:
3434
3535 # Model
3636 parser .add_argument ("--model_family" , default = "llava_vit_sampling" )
37- parser .add_argument ("--model_name" , default = "llava_vit_base_ln " )
37+ parser .add_argument ("--model_name" , default = "ov_encoder_large " )
3838 parser .add_argument ("--model_weight" , default = "NULL" )
3939 parser .add_argument ("--num_frames" , type = int , default = 8 )
4040 parser .add_argument ("--num_tokens" , type = int , default = 1568 )
@@ -136,7 +136,7 @@ def get_feature(
136136 args: 参数配置
137137 videos: 视频数据 [B, C, T, H, W] 或图片数据 [B, C, H, W]
138138 model: 模型
139- frame_indices: 视频帧索引 [B, seq_len],用于 llava_vit_sampling
139+ frame_indices: 视频帧索引 [B, seq_len],用于 ov_encoder_large sampling
140140 total_frames: 每个视频的总帧数 [B]
141141 is_training: 是否为训练模式
142142 """
@@ -401,23 +401,11 @@ def evaluate(
401401
402402def get_model (args : argparse .Namespace ) -> nn .Module :
403403
404- if args .model_name == "hf_llava_vit_large_ln_auto " :
404+ if args .model_name == "ov_encoder_large " :
405405 model = AutoModel .from_pretrained (
406- "/video_vit/xiangan/LLaVA-ViT /onevision-encoder-large" ,
406+ "lmms-lab /onevision-encoder-large" ,
407407 trust_remote_code = True ,
408- attn_implementation = "flash_attention_2"
409- )
410- model = torch .compile (model )
411- return model
412-
413- if args .model_name == "hf_llava_vit_large_ln_remote" :
414- model = AutoModel .from_pretrained ("lmms-lab/llava-vit-large-patch14" , trust_remote_code = True )
415- model = torch .compile (model )
416- return model
417-
418- if args .model_name == "hf_llava_vit_large_ln" :
419- from model_factory .vit_ov_encoder import LlavaViTModel
420- model = LlavaViTModel .from_pretrained (args .model_weight , dtype = torch .bfloat16 )
408+ attn_implementation = "flash_attention_2" )
421409 model = torch .compile (model )
422410 return model
423411
0 commit comments