Skip to content

Commit b68fa0b

Browse files
authored
[Fix] Aria, Llama Vision, and OpenAI-compatible models (#641)
* Enhance Qwen model with additional parameters and improved visual handling
  - Added `system_prompt`, `interleave_visuals`, and `max_length` parameters to Qwen2_VL class.
  - Simplified device assignment logic for single process scenarios.
  - Improved visual processing by refining how visuals are handled and ensuring proper mapping to contexts.
  - Enhanced message construction to support interleaving of visuals and text based on placeholders.
  - Set default generation parameters and refined handling of generated outputs to ensure proper trimming and formatting.
* Refactor VisualPuzzles utility functions for improved readability and consistency
  - Updated string formatting to use double quotes for consistency.
  - Enhanced the `VisualPuzzles_doc_to_text` and `parse_response` functions for better clarity and structure.
  - Simplified conditional checks and improved whitespace handling in response parsing.
  - Ensured consistent handling of options and answers throughout the utility functions.
* Update model imports and parameters for Aria, LlamaVision, and OpenAICompatible
  - Replaced AutoModelForCausalLM and AutoProcessor with AriaForConditionalGeneration and AriaProcessor in the Aria model.
  - Updated the pretrained model string in LlamaVision to "meta-llama/Llama-3.2-11B-Vision-Instruct".
  - Enhanced OpenAICompatible to support AzureOpenAI and modified API key handling for better flexibility.
  - Adjusted timeout parameter and refined token handling in OpenAICompatible for improved functionality.
1 parent 9596fbd commit b68fa0b

7 files changed

Lines changed: 394 additions & 225 deletions

File tree

lmms_eval/models/aria.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from decord import VideoReader, cpu
1111
from PIL import Image
1212
from tqdm import tqdm
13-
from transformers import AutoModelForCausalLM, AutoProcessor
13+
from transformers import AriaForConditionalGeneration, AriaProcessor
1414

1515
from lmms_eval import utils
1616
from lmms_eval.api.instance import Instance
@@ -19,6 +19,8 @@
1919

2020
warnings.filterwarnings("ignore")
2121

22+
import re
23+
2224
from loguru import logger as eval_logger
2325

2426
DEFAULT_IMAGE_TOKEN = "<image>"
@@ -72,10 +74,10 @@ def __init__(
7274
dtype = getattr(torch, dtype)
7375

7476
self.max_frames_num = max_frames_num
75-
self._model = AutoModelForCausalLM.from_pretrained(pretrained, revision=revision, device_map=self.device_map, torch_dtype=torch.bfloat16, trust_remote_code=True, attn_implementation=attn_implementation)
77+
self._model = AriaForConditionalGeneration.from_pretrained(pretrained, revision=revision, device_map=self.device_map, torch_dtype=torch.bfloat16, trust_remote_code=True, attn_implementation=attn_implementation)
7678

7779
self.pretrained = pretrained
78-
self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=True)
80+
self._image_processor = AriaProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=True)
7981
self._tokenizer = self._image_processor.tokenizer
8082

8183
self._config = self._model.config

lmms_eval/models/llama_vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
class LlamaVision(lmms):
2828
def __init__(
2929
self,
30-
pretrained: str = "meta-llama/Llama-3.2-11B-Vision",
30+
pretrained: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
3131
revision: str = "main",
3232
device: str = "cuda",
3333
dtype: Optional[Union[str, torch.dtype]] = "auto",

lmms_eval/models/openai_compatible.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from dotenv import find_dotenv, load_dotenv
2323
from loguru import logger as eval_logger
24-
from openai import OpenAI
24+
from openai import AzureOpenAI, OpenAI
2525
from PIL import Image
2626

2727
load_dotenv(verbose=True)
@@ -32,11 +32,12 @@ class OpenAICompatible(lmms):
3232
def __init__(
3333
self,
3434
model_version: str = "grok-2-latest",
35-
timeout: int = 120,
35+
timeout: int = 10,
3636
max_retries: int = 5,
3737
max_size_in_mb: int = 20,
3838
continual_mode: bool = False,
3939
response_persistent_folder: str = None,
40+
azure_openai: bool = False,
4041
**kwargs,
4142
) -> None:
4243
super().__init__()
@@ -61,7 +62,11 @@ def __init__(
6162
self.response_cache = {}
6263
self.cache_mode = "start"
6364

64-
self.client = OpenAI(api_key=os.getenv("OPENAI_COMPATIBLE_API_KEY"), base_url=os.getenv("OPENAI_COMPATIBLE_API_URL"))
65+
self.client = (
66+
OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE"))
67+
if not azure_openai
68+
else AzureOpenAI(api_key=os.getenv("AZURE_OPENAI_API_KEY"), azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
69+
)
6570

6671
accelerator = Accelerator()
6772
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
@@ -186,14 +191,17 @@ def generate_until(self, requests) -> List[str]:
186191
if "num_beams" not in gen_kwargs:
187192
gen_kwargs["num_beams"] = 1
188193

189-
payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"]
194+
# payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"]
195+
payload["max_tokens"] = gen_kwargs["max_new_tokens"]
190196
payload["temperature"] = gen_kwargs["temperature"]
191197

192198
if "o1" in self.model_version or "o3" in self.model_version:
193199
# del payload["max_output_tokens"]
194200
del payload["temperature"]
195201
payload["reasoning_effort"] = "medium"
196202
payload["response_format"] = {"type": "text"}
203+
payload.pop("max_tokens")
204+
payload["max_completion_tokens"] = gen_kwargs["max_tokens"]
197205

198206
for attempt in range(self.max_retries):
199207
try:

lmms_eval/models/qwen2_5_vl.py

Lines changed: 103 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import re
23
from io import BytesIO
34
from typing import List, Optional, Tuple, Union
45

@@ -48,6 +49,9 @@ def __init__(
4849
use_custom_video_loader: Optional[bool] = False,
4950
fps: Optional[float] = None, # Only applicable if use_custom_video_loader is True
5051
max_image_size: Optional[int] = None, # Only applicable if use_custom_video_loader is True
52+
system_prompt: Optional[str] = "You are a helpful assistant.",
53+
interleave_visuals: Optional[bool] = False,
54+
reasoning_prompt: Optional[str] = None,
5155
**kwargs,
5256
) -> None:
5357
super().__init__()
@@ -66,12 +70,9 @@ def __init__(
6670
if accelerator.num_processes > 1:
6771
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
6872
self.device_map = f"cuda:{accelerator.local_process_index}"
69-
elif accelerator.num_processes == 1 and device_map == "auto":
70-
self._device = torch.device(device)
71-
self.device_map = device_map
7273
else:
73-
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
74-
self.device_map = f"cuda:{accelerator.local_process_index}"
74+
self._device = torch.device(device)
75+
self.device_map = device_map if device_map else device
7576

7677
if use_flash_attention_2:
7778
self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -85,10 +86,18 @@ def __init__(
8586
self.max_pixels = max_pixels
8687
self.min_pixels = min_pixels
8788
self.max_num_frames = max_num_frames
88-
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels, padding_side="left")
89-
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, padding_side="left")
89+
90+
if reasoning_prompt:
91+
self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
92+
else:
93+
self.reasoning_prompt = None
94+
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
95+
self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
96+
self.system_prompt = system_prompt
97+
self.interleave_visuals = interleave_visuals
9098

9199
self._config = self.model.config
100+
self._max_length = kwargs.get("max_length", 2048)
92101
self.batch_size_per_gpu = int(batch_size)
93102
self.use_cache = use_cache
94103

@@ -184,8 +193,11 @@ def _collate(x):
184193
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
185194
task = task[0]
186195
split = split[0]
187-
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
188-
visuals = self.flatten(visuals)
196+
visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
197+
if None in visual_list:
198+
visual_list = []
199+
else:
200+
visual_list = self.flatten(visual_list)
189201

190202
gen_kwargs = all_gen_kwargs[0]
191203

@@ -200,112 +212,116 @@ def _collate(x):
200212
elif not isinstance(until, list):
201213
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
202214

203-
# if isinstance(contexts, tuple):
204-
# contexts = list(contexts)
205-
206-
# for i in range(len(contexts)):
207-
# for j in range(32):
208-
# if f"<image {j}>" in contexts[i]:
209-
# contexts[i] = contexts[i].replace(f"<image {j}>", "<image>")
210-
# if f"\\<image {j}\\>" in contexts[i]:
211-
# contexts[i] = contexts[i].replace(f"\\<image {j}\\>", "<image>")
212-
# if "<image>" in contexts[i]:
213-
# contexts[i] = contexts[i].replace("<image>", "")
214-
# print(contexts[i])
215-
216-
# for i in range(len(contexts)):
217-
# if "<image>" in contexts[i]:
218-
# contexts[i] = contexts[i].replace("<image>", "")
219-
220-
messages = []
221-
processed_visuals = []
215+
if isinstance(contexts, tuple):
216+
contexts = list(contexts)
217+
218+
for i in range(len(contexts)):
219+
if "<image>" in contexts[i]:
220+
contexts[i] = contexts[i].replace("<image>", "")
221+
222+
batched_messages = []
222223
for i, context in enumerate(contexts):
223-
# context += "\nPlease think step by step."
224-
# if "<image>" in context:
225-
# context = context.replace("<image>", "")
224+
if "<image>" in context:
225+
context = context.replace("<image>", "")
226226

227-
message = [{"role": "system", "content": "You are a helpful assistant."}]
227+
message = [{"role": "system", "content": self.system_prompt}]
228+
if self.reasoning_prompt:
229+
context = context.strip() + self.reasoning_prompt
230+
contexts[i] = context
228231

229-
if len(visuals) > 0:
230-
visual = visuals[i] if i < len(visuals) else None
232+
processed_visuals = []
233+
for visual in visual_list:
231234
if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
232-
if self.use_custom_video_loader:
233-
visual = read_video_pyav_base64(visual, num_frm=self.max_num_frames, fps=self.fps, img_format="JPEG", max_image_size=self.max_image_size)
234-
image_contents = list(map(lambda x: f"data:image/jpeg;base64,{x}", visual))
235-
message.append({"role": "user", "content": [{"type": "video", "video": image_contents}, {"type": "text", "text": context}]})
236-
else:
237-
vr = decord.VideoReader(visual)
238-
first_frame = vr[0].asnumpy()
239-
height, width = first_frame.shape[:2]
240-
# max_pixels = height * width
241-
message.append({"role": "user", "content": [{"type": "video", "video": visual, "max_pixels": 360 * 420}, {"type": "text", "text": context}]})
242-
elif isinstance(visual, Image.Image): # Single image
235+
vr = decord.VideoReader(visual)
236+
first_frame = vr[0].asnumpy()
237+
height, width = first_frame.shape[:2]
238+
# max_pixels = height * width
239+
processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
240+
elif isinstance(visual, Image.Image): # Handle both single and multiple images
243241
base64_image = visual.convert("RGB")
244242
buffer = BytesIO()
245243
base64_image.save(buffer, format="JPEG")
246244
base64_bytes = base64.b64encode(buffer.getvalue())
247245
base64_string = base64_bytes.decode("utf-8")
248-
message.append({"role": "user", "content": [{"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"}, {"type": "text", "text": context}]})
249-
elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual): # Multiple images
250-
image_content = []
251-
for v in visual:
252-
base64_image = v.convert("RGB")
253-
buffer = BytesIO()
254-
base64_image.save(buffer, format="JPEG")
255-
base64_bytes = base64.b64encode(buffer.getvalue())
256-
base64_string = base64_bytes.decode("utf-8")
257-
image_content.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"})
258-
message.append({"role": "user", "content": image_content + [{"type": "text", "text": context}]})
259-
else:
260-
message.append({"role": "user", "content": [{"type": "text", "text": context}]})
261-
else:
262-
message.append({"role": "user", "content": [{"type": "text", "text": context}]})
263-
264-
messages.append(message)
265-
# print("message")
266-
267-
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
268-
image_inputs, video_inputs = process_vision_info(messages)
269-
inputs = self.processor(
270-
text=text,
271-
images=image_inputs,
272-
videos=video_inputs,
273-
# fps=self.fps,
274-
padding=True,
275-
return_tensors="pt",
276-
)
246+
processed_visuals.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}", "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
247+
248+
if self.interleave_visuals is False:
249+
message.append(
250+
{
251+
"role": "user",
252+
"content": processed_visuals + [{"type": "text", "text": context}],
253+
}
254+
)
255+
else: # currently support find <image x> in the context
256+
image_placeholders = re.findall(r"<image \d+>", context)
257+
content_parts = []
258+
text_parts = re.split(r"<image \d+>", context)
259+
if text_parts[0]:
260+
content_parts.append({"type": "text", "text": text_parts[0]})
261+
262+
for i, placeholder in enumerate(image_placeholders):
263+
img_idx = int(re.search(r"<image (\d+)>", placeholder).group(1)) - 1
264+
image_idx = min(img_idx, len(processed_visuals) - 1) if processed_visuals else 0
265+
if processed_visuals and image_idx < len(processed_visuals):
266+
content_parts.append(processed_visuals[image_idx])
267+
if i + 1 < len(text_parts) and text_parts[i + 1]:
268+
content_parts.append({"type": "text", "text": text_parts[i + 1]})
269+
270+
message.append(
271+
{
272+
"role": "user",
273+
"content": content_parts,
274+
}
275+
)
276+
277+
batched_messages.append(message)
278+
279+
texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batched_messages]
280+
image_inputs, video_inputs = process_vision_info(batched_messages)
281+
if video_inputs is not None:
282+
total_frames = video_inputs[0].shape[0]
283+
indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int)
284+
# Append the last frame index if not already included
285+
if total_frames - 1 not in indices:
286+
indices = np.append(indices, total_frames - 1)
287+
video_inputs[0] = video_inputs[0][indices]
288+
inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
277289

278290
if self.device_map == "auto":
279291
inputs = inputs.to("cuda")
280292
else:
281293
inputs = inputs.to(self.device)
282294

283-
if "max_new_tokens" not in gen_kwargs:
284-
gen_kwargs["max_new_tokens"] = 4096
285-
if "temperature" not in gen_kwargs:
286-
gen_kwargs["temperature"] = 0
287-
if "top_p" not in gen_kwargs:
288-
gen_kwargs["top_p"] = None
289-
if "num_beams" not in gen_kwargs:
290-
gen_kwargs["num_beams"] = 1
295+
# Set default generation kwargs
296+
default_gen_kwargs = {
297+
"max_new_tokens": 128,
298+
"temperature": 0.0, # Set to 0 for greedy default
299+
"top_p": None,
300+
"num_beams": 1,
301+
}
302+
# Update with provided kwargs
303+
current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}
291304

292305
pad_token_id = self.tokenizer.pad_token_id
293306

294307
cont = self.model.generate(
295308
**inputs,
296309
eos_token_id=self.tokenizer.eos_token_id,
297310
pad_token_id=pad_token_id,
298-
do_sample=True if gen_kwargs["temperature"] > 0 else False,
299-
temperature=gen_kwargs["temperature"],
300-
top_p=gen_kwargs["top_p"],
301-
num_beams=gen_kwargs["num_beams"],
302-
max_new_tokens=gen_kwargs["max_new_tokens"],
311+
do_sample=True if current_gen_kwargs["temperature"] > 0 else False,
312+
temperature=current_gen_kwargs["temperature"],
313+
top_p=current_gen_kwargs["top_p"],
314+
num_beams=current_gen_kwargs["num_beams"],
315+
max_new_tokens=current_gen_kwargs["max_new_tokens"],
303316
use_cache=self.use_cache,
304317
)
305318

306319
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
307320
answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
308321
for i, ans in enumerate(answers):
322+
for term in until:
323+
if len(term) > 0:
324+
ans = ans.split(term)[0]
309325
answers[i] = ans
310326

311327
for ans, context in zip(answers, contexts):

0 commit comments

Comments
 (0)