Skip to content

Commit 9596fbd

Browse files
authored
Adds PerceptionLM and PLM-VideoBench (#638)
* Implements PLM and PLM-VideoBench from 'PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding' * Updates docs. * Removes redundant code.
1 parent 43d616f commit 9596fbd

35 files changed

Lines changed: 1499 additions & 12 deletions

lmms_eval/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
logger.add(sys.stdout, level="WARNING")
1212

1313
AVAILABLE_MODELS = {
14+
"plm": "PerceptionLM",
1415
"aria": "Aria",
1516
"auroracap": "AuroraCap",
1617
"batch_gpt4": "BatchGPT4",

lmms_eval/models/plm.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from typing import List, Optional, Tuple, Union
2+
3+
import torch
4+
from accelerate import Accelerator, DistributedType
5+
from lmms_eval import utils
6+
from lmms_eval.api.instance import Instance
7+
from lmms_eval.api.model import lmms
8+
from lmms_eval.api.registry import register_model
9+
from loguru import logger as eval_logger
10+
from omegaconf import OmegaConf
11+
from PIL import Image
12+
from tqdm import tqdm
13+
14+
from apps.plm.generate import (PackedCausalTransformerGenerator,
15+
PackedCausalTransformerGeneratorArgs,
16+
load_consolidated_model_and_tokenizer)
17+
from core.args import dataclass_from_dict
18+
from core.transforms.image_transform import get_image_transform
19+
from core.transforms.video_transform import get_video_transform
20+
21+
22+
@register_model("plm")
class PerceptionLM(lmms):
    """
    Perception Language Model (PLM).

    lmms-eval wrapper around the packed PLM generator, from
    "PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding".
    Only ``generate_until`` is implemented; ``loglikelihood`` and
    ``generate_until_multi_round`` raise ``NotImplementedError``.

    TODO: paste the paper, GitHub and Hugging Face links.
    """

    def __init__(
        self,
        pretrained: str = "facebook/Perception-LM-8B",
        device: Optional[str] = "cuda",
        batch_size: Optional[Union[int, str]] = 1,
        compile_prefilling=False,
        reduce_generation_overhead=False,
        max_tokens=11264,
        **kwargs,
    ) -> None:
        """Load the PLM checkpoint, build its preprocessors and the packed generator.

        Args:
            pretrained: Checkpoint identifier handed to
                ``load_consolidated_model_and_tokenizer``.
            device: Only forwarded into the generator config; the device exposed
                by ``self.device`` is always ``cuda:<local_process_index>``.
            batch_size: Per-process batch size; converted with ``int()``.
            compile_prefilling: Forwarded into the generator config.
            reduce_generation_overhead: Forwarded into the generator config.
            max_tokens: Forwarded into the generator config.
            **kwargs: Extra entries forwarded into the generator config
                (built with ``strict=False``).
        """
        super().__init__()

        accelerator = Accelerator()
        self._device = torch.device(f"cuda:{accelerator.local_process_index}")

        # Collect all arguments into a dictionary
        args = {
            "pretrained": pretrained,
            "device": device,
            "batch_size": batch_size,
            "compile_prefilling": compile_prefilling,
            "reduce_generation_overhead": reduce_generation_overhead,
            "max_tokens": max_tokens,
            **kwargs,  # Include any additional keyword arguments
        }
        # Convert the dictionary to a dotlist format
        dotlist = [f"{key}={value}" for key, value in args.items()]
        cfg = OmegaConf.from_dotlist(dotlist)
        gen_cfg = dataclass_from_dict(PackedCausalTransformerGeneratorArgs, cfg, strict=False)

        # Load PLM model
        eval_logger.info(f"Loading PLM model from {cfg.pretrained}")
        model, tokenizer, config = load_consolidated_model_and_tokenizer(cfg.pretrained)

        # Create preprocessors (transforms)
        model_cfg = config.get("model")
        vision_input_type = model_cfg.get("vision_input_type", "thumb+tile")
        max_num_tiles = model_cfg.get("max_num_tiles", 36)
        processor = {}
        processor["image"] = get_image_transform(vision_input_type=vision_input_type, image_res=model.vision_model.image_size, max_num_tiles=max_num_tiles)
        processor["video"] = get_video_transform(image_res=model.vision_model.image_size)
        self._video_max_frames = model_cfg.get("video_max_frames", 32)

        # Create PLM generator
        eval_logger.info(f"Creating packed generator with gen_cfg: {gen_cfg}")
        generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)

        # Set the class variables
        self._tokenizer = tokenizer
        self._processor = processor
        self._model = model
        self._generator = generator
        self.batch_size_per_gpu = int(batch_size)

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            # NOTE: ``self.accelerator`` only exists in the multi-process case;
            # the ``model`` property relies on that via ``hasattr``.
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.process_index
            self._world_size = self.accelerator.num_processes
        else:
            self._rank = 0
            self._world_size = 1

    @property
    def generator(self):
        """The ``PackedCausalTransformerGenerator`` built in ``__init__``."""
        return self._generator

    @property
    def tokenizer(self):
        """The tokenizer returned alongside the model checkpoint."""
        return self._tokenizer

    @property
    def processor(self):
        """Dict of preprocessing transforms keyed by ``"image"`` and ``"video"``."""
        return self._processor

    @property
    def model(self):
        """The underlying model, unwrapped when Accelerate wrapped it."""
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        """The tokenizer's ``eos_token_id``."""
        # we use EOT because end of text is more accurate for what we're doing than end of sentence
        return self.tokenizer.eos_token_id

    @property
    def batch_size(self):
        """Per-process batch size."""
        return self.batch_size_per_gpu

    @property
    def video_max_frames(self):
        """Frame budget passed to the video transform for video files."""
        return self._video_max_frames

    @property
    def device(self):
        """The ``torch.device`` selected for this process."""
        return self._device

    @property
    def rank(self):
        """Process index (0 when running single-process)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes (1 when running single-process)."""
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported by this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError("Loglikelihood is not implemented for PLM")

    def flatten(self, input):
        """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
        return [item for group in input for item in group]

    def _build_message(self, context, visual):
        """Preprocess ``visual`` by kind and pair it with ``context``.

        Raises:
            NotImplementedError: when ``visual`` is not a video path, a PIL
                image, or a list/tuple of PIL images (e.g. ``None``).
        """
        if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")):  # Video file
            video_info = (visual, self.video_max_frames, None, None, None)
            visual, _ = self.processor["video"](video_info)
        elif isinstance(visual, Image.Image):  # Single image
            visual = visual.convert("RGB")
            visual, _ = self.processor["image"](visual)
        elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual):  # Multiple images or Video Frames
            visual = [image.convert("RGB") for image in visual]
            visual, _ = self.processor["video"]._process_multiple_images_pil(visual)
        else:
            # Text-only sample
            raise NotImplementedError("Text-only input is not yet supported.")
        return (context, visual)

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate one response per request.

        Requests are grouped by generation kwargs and batched by
        ``self.batch_size``; results are returned in the original request order.
        A single trailing ``"."`` is stripped from every generation.

        Raises:
            NotImplementedError: when a request has no usable visual input.
        """
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tokenizer.encode(x[0], add_bos=False, add_eos=False)
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            visuals = self.flatten(visuals)

            # A context without a matching visual is passed as None, which
            # _build_message rejects as text-only input.
            messages = [
                self._build_message(context, visuals[i] if i < len(visuals) else None)
                for i, context in enumerate(contexts)
            ]

            gen_kwargs = all_gen_kwargs[0]
            if "max_new_tokens" in gen_kwargs:
                self.generator.max_gen_len = gen_kwargs["max_new_tokens"]
            if "temperature" in gen_kwargs:
                self.generator.temperature = gen_kwargs["temperature"]
            # Default for PLM
            self.generator.top_p = None
            self.generator.top_k = 100

            # The generator also returns log-likelihood and greedy flags; unused here.
            generation, _, _ = self.generator.generate(messages)

            for gen, context in zip(generation, contexts):
                if gen.endswith("."):
                    gen = gen[:-1]
                res.append(gen)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), gen)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        """Not supported by this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError("Multi-round generation is not implemented yet.")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Groups the coco_karpathy_val and coco_karpathy_test tasks under the name "coco_karpathy".
group : coco_karpathy
task:
  - coco_karpathy_val
  - coco_karpathy_test
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Task config for the "test" split of yerevann/coco-karpathy (member of the coco_karpathy group).
dataset_path: yerevann/coco-karpathy
dataset_kwargs:
  token: True
task: "coco_karpathy_test"
group : "coco_karpathy"
test_split: test
output_type: generate_until
# Use the URL-based loader, matching coco_karpathy_val: it reads doc["url"],
# whereas utils.coco_doc_to_visual reads doc["image"].
doc_to_visual: !function utils.coco_doc_to_visual_karpathy
doc_to_text: "Describe the image briefly."
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 64
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
process_results: !function utils.coco_process_result_karpathy
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: coco_Bleu_4
    aggregation : !function utils.coco_bleu4
    higher_is_better : true
  - metric: coco_Bleu_3
    aggregation : !function utils.coco_bleu3
    higher_is_better : true
  - metric: coco_Bleu_2
    aggregation : !function utils.coco_bleu2
    higher_is_better : true
  - metric: coco_Bleu_1
    aggregation : !function utils.coco_bleu1
    higher_is_better : true
  - metric: coco_METEOR
    aggregation : !function utils.coco_meteor
    higher_is_better : true
  - metric: coco_ROUGE_L
    aggregation : !function utils.coco_rougel
    higher_is_better : true
  - metric: coco_CIDEr
    aggregation : !function utils.coco_cider
    higher_is_better : true
  #- metric: coco_SPICE
  #  aggregation : !function utils.coco_spice
  #  higher_is_better : true
metadata:
  - version: 0.0
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Task config for the "validation" split of yerevann/coco-karpathy (member of the coco_karpathy group).
dataset_path: yerevann/coco-karpathy
dataset_kwargs:
  token: True
task: "coco_karpathy_val"
group : "coco_karpathy"
test_split: validation
output_type: generate_until
# Loads each image from doc["url"] (see utils.coco_doc_to_visual_karpathy).
doc_to_visual: !function utils.coco_doc_to_visual_karpathy
doc_to_text: "Describe the image briefly."
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 64
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
process_results: !function utils.coco_process_result_karpathy
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: coco_Bleu_4
    aggregation : !function utils.coco_bleu4
    higher_is_better : true
  - metric: coco_Bleu_3
    aggregation : !function utils.coco_bleu3
    higher_is_better : true
  - metric: coco_Bleu_2
    aggregation : !function utils.coco_bleu2
    higher_is_better : true
  - metric: coco_Bleu_1
    aggregation : !function utils.coco_bleu1
    higher_is_better : true
  - metric: coco_METEOR
    aggregation : !function utils.coco_meteor
    higher_is_better : true
  - metric: coco_ROUGE_L
    aggregation : !function utils.coco_rougel
    higher_is_better : true
  - metric: coco_CIDEr
    aggregation : !function utils.coco_cider
    higher_is_better : true
  #- metric: coco_SPICE
  #  aggregation : !function utils.coco_spice
  #  higher_is_better : true
metadata:
  - version: 0.0

lmms_eval/tasks/coco_cap/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
77
from pycocotools.coco import COCO
88

9+
from PIL import Image
10+
import requests
11+
from io import BytesIO
12+
913
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
1014

1115
dir_name = os.path.dirname(os.path.abspath(__file__))
@@ -16,11 +20,35 @@
1620
def coco_doc_to_visual(doc):
    """Return the document's ``"image"`` converted to RGB, as a one-element list."""
    rgb_image = doc["image"].convert("RGB")
    return [rgb_image]
1822

23+
def coco_doc_to_visual_karpathy(doc, timeout=30.0):
    """Download the image at ``doc["url"]`` and return it as a one-element list.

    Args:
        doc: dataset record; only its ``"url"`` field is read.
        timeout: seconds to wait for the server before giving up, so a stalled
            download cannot hang the whole evaluation run.

    Returns:
        A list holding the downloaded image converted to RGB.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(doc["url"], timeout=timeout)
    # Fail with the HTTP error itself instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    return [image.convert("RGB")]
1928

2029
def coco_doc_to_text(doc):
    """Return the fixed captioning prompt; ``doc`` is not inspected."""
    prompt = "Provide a one-sentence caption for the provided image."
    return prompt
2231

2332

33+
def coco_process_result_karpathy(doc, result):
    """Package one prediction with its references for every COCO metric.

    Args:
        doc: an instance of the eval dataset; its ``"filename"``, ``"imgid"``
            and ``"sentences"`` fields are read.
        result: ``[pred]``; an empty sequence yields an empty prediction.
    Returns:
        a dictionary mapping ``coco_<metric>`` (for each entry of
        ``COCO_METRICS``) to the same record with keys ``answer``, ``pred``,
        ``image_id`` and ``id``.
    """
    if len(result) > 0:
        prediction = result[0]
    else:
        prediction = ""

    # The image id is the number between the last "_" and the first "." that
    # follows it in the file name.
    filename = doc["filename"]
    numeric_part = filename.split("_")[-1].split(".")[0]
    image_id = int(numeric_part)
    sample_id = doc["imgid"]

    record = {
        "answer": doc["sentences"],
        "pred": prediction,
        "image_id": image_id,
        "id": sample_id,
    }

    return {f"coco_{metric}": record for metric in COCO_METRICS}
50+
51+
2452
def coco_process_result(doc, result):
2553
"""
2654
Args:

0 commit comments

Comments
 (0)