@@ -66,3 +66,35 @@ def _merge_with_image_features(self, inputs_embeds, input_ids, image_features):
6666 bridge_cls = KimiVLBridge ,
6767 visual_cls = KimiVLVit ,
6868))
69+
70+
class KimiK25Vit(HuggingFaceVit):
    """Vision-side module holder for Kimi-K25.

    ``prepare_model`` builds the ``vision_tower`` and ``mm_projector``
    sub-modules from classes shipped in the checkpoint's
    ``modeling_kimi_k25`` dynamic module. Image inputs are not handled yet:
    ``get_inputs_embeds`` raises when ``pixel_values`` is supplied and
    otherwise returns the text embeddings unchanged.
    """
    # Identity name mapping for the two sub-modules created in prepare_model.
    # NOTE(review): presumably consumed by the HuggingFaceVit base class or
    # the bridge when matching weight names -- confirm against the base class.
    module_mapping = {'vision_tower': 'vision_tower', 'mm_projector': 'mm_projector'}
    _vision_tower = ['vision_tower']
    _aligner = ['mm_projector']
    support_multimodal = False

    def prepare_model(self, hf_config: PretrainedConfig):
        """Create ``self.vision_tower`` and ``self.mm_projector``.

        Args:
            hf_config: Model config. ``hf_config.name_or_path`` locates the
                dynamic module, and ``hf_config.vision_config`` is passed to
                both the vision-tower and projector config classes.

        Raises:
            AssertionError: If ``vision_config.mm_projector_type`` is not
                ``'patchmerger'``.
        """
        class_names = ('MoonViT3dPretrainedModel', 'PatchMergerMLP', 'VisionTowerConfig', 'ProjectorConfig')
        # The class reference must be exactly '<module>.<ClassName>'; stray
        # whitespace inside the string would break the class lookup.
        MoonViT3dPretrainedModel, PatchMergerMLP, VisionTowerConfig, ProjectorConfig = [
            get_class_from_dynamic_module(f'modeling_kimi_k25.{key}', hf_config.name_or_path) for key in class_names
        ]
        projector_type = hf_config.vision_config.mm_projector_type
        assert projector_type == 'patchmerger', (
            f"Unsupported mm_projector_type: {projector_type!r} (expected 'patchmerger')")
        vit_config = VisionTowerConfig(hf_config.vision_config)
        proj_config = ProjectorConfig(hf_config.vision_config)
        self.vision_tower = MoonViT3dPretrainedModel._from_config(vit_config)
        # Keep the projector in the same dtype as the vision tower.
        self.mm_projector = PatchMergerMLP(proj_config).to(self.vision_tower.dtype)

    def get_inputs_embeds(self, inputs_embeds, **kwargs):
        """Return ``inputs_embeds`` unchanged for text-only batches.

        Args:
            inputs_embeds: Text embeddings; returned as-is.
            **kwargs: Extra inputs; only ``pixel_values`` is inspected.

        Raises:
            NotImplementedError: If ``pixel_values`` is provided (not None).
        """
        pixel_values = kwargs.pop('pixel_values', None)
        if pixel_values is not None:
            raise NotImplementedError('Kimi-K25 currently only supports plain text training.')
        return inputs_embeds
93+
94+
# Register Kimi-K25: reuses the Kimi-VL bridge together with the K25-specific
# vision wrapper. NOTE(review): the list is presumably the set of HF
# ``model_type`` strings that resolve to this entry -- confirm in ModelMeta.
register_model(
    ModelMeta(
        ModelType.kimi_k25,
        ['kimi_k25'],
        bridge_cls=KimiVLBridge,
        visual_cls=KimiK25Vit,
    ))
0 commit comments