Skip to content

Commit 78d6c72

Browse files
eamonn-zhCopilot
andauthored
feat: add ReVSI evaluation (#1307)
* add the ReVSI benchmark * add the ReVSI benchmark * feat: enhance REVSI metrics and aggregation functions Co-authored-by: Copilot <copilot@github.com> --------- Co-authored-by: Copilot <copilot@github.com>
1 parent ed09b9a commit 78d6c72

7 files changed

Lines changed: 237 additions & 0 deletions

File tree

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
dataset_path: 3dlg-hcvc/ReVSI
2+
test_split: test
3+
dataset_kwargs:
4+
token: True
5+
cache_dir: revsi
6+
video: True
7+
output_type: generate_until
8+
process_docs: !function utils.process_docs
9+
doc_to_visual: !function utils.revsi_doc_to_visual
10+
doc_to_text: !function utils.revsi_doc_to_text
11+
doc_to_target: "ground_truth"
12+
generation_kwargs:
13+
max_new_tokens: 16
14+
temperature: 0
15+
do_sample: false
16+
process_results: !function utils.revsi_process_results
17+
metric_list:
18+
- metric: overall_acc
19+
aggregation: !function utils.revsi_aggregate_overall
20+
higher_is_better: true
21+
- metric: object_abs_distance_acc
22+
aggregation: !function utils.revsi_aggregate_object_abs_distance_acc
23+
higher_is_better: true
24+
- metric: object_counting_acc
25+
aggregation: !function utils.revsi_aggregate_object_counting_acc
26+
higher_is_better: true
27+
- metric: object_rel_direction_acc
28+
aggregation: !function utils.revsi_aggregate_object_rel_direction_acc
29+
higher_is_better: true
30+
- metric: object_rel_distance_acc
31+
aggregation: !function utils.revsi_aggregate_object_rel_distance_acc
32+
higher_is_better: true
33+
- metric: object_size_estimation_acc
34+
aggregation: !function utils.revsi_aggregate_object_size_estimation_acc
35+
higher_is_better: true
36+
- metric: room_size_estimation_acc
37+
aggregation: !function utils.revsi_aggregate_room_size_estimation_acc
38+
higher_is_better: true
39+
- metric: route_planning_acc
40+
aggregation: !function utils.revsi_aggregate_route_planning_acc
41+
higher_is_better: true
42+
lmms_eval_specific_kwargs:
43+
default:
44+
pre_prompt: "These are frames of a video."
45+
mcq_post_prompt: "Answer with the option's letter from the given choices directly."
46+
nq_post_prompt: "Answer the question using a single integer or decimal number."
47+
metadata:
48+
- version: 1.0

lmms_eval/tasks/revsi/revsi.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
group: revsi
2+
task:
3+
- revsi_all_frame
4+
- revsi_64_frame
5+
- revsi_32_frame
6+
- revsi_16_frame
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
dataset_name: 16_frame
2+
task: revsi_16_frame
3+
include: _default_template_yaml
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
dataset_name: 32_frame
2+
task: revsi_32_frame
3+
include: _default_template_yaml
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
dataset_name: 64_frame
2+
task: revsi_64_frame
3+
include: _default_template_yaml
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
dataset_name: all_frame
2+
task: revsi_all_frame
3+
include: _default_template_yaml

lmms_eval/tasks/revsi/utils.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import os
2+
import datasets
3+
import numpy as np
4+
import pandas as pd
5+
from huggingface_hub.constants import HF_HOME
6+
from lmms_eval.utils import resolve_cache_dir
7+
from lmms_eval.tasks._task_utils.default_template_yaml import load_default_template_yaml
8+
9+
10+
MCQ_QUESTION_TYPES = [
11+
"object_rel_direction_forward_easy",
12+
"object_rel_direction_backward_easy",
13+
"object_rel_direction_forward_hard",
14+
"object_rel_direction_backward_hard",
15+
"object_rel_distance_closest",
16+
"object_rel_distance_farthest",
17+
"route_planning",
18+
]
19+
20+
21+
NQ_QUESTION_TYPES = [
22+
"object_counting_single",
23+
"object_counting_multiple",
24+
"object_abs_distance",
25+
"object_size_estimation",
26+
"room_size_estimation_single",
27+
"room_size_estimation_multiple"
28+
]
29+
30+
31+
REVSI_METRICS = [
32+
"overall_acc",
33+
"object_abs_distance_acc",
34+
"object_counting_acc",
35+
"object_rel_direction_acc",
36+
"object_rel_distance_acc",
37+
"object_size_estimation_acc",
38+
"room_size_estimation_acc",
39+
"route_planning_acc",
40+
]
41+
42+
43+
COMPOSITE_METRICS = {
44+
"object_rel_direction_acc": [
45+
"object_rel_direction_forward_easy",
46+
"object_rel_direction_backward_easy",
47+
"object_rel_direction_forward_hard",
48+
"object_rel_direction_backward_hard",
49+
],
50+
"object_rel_distance_acc": [
51+
"object_rel_distance_closest",
52+
"object_rel_distance_farthest",
53+
],
54+
"object_counting_acc": [
55+
"object_counting_single",
56+
"object_counting_multiple",
57+
],
58+
"room_size_estimation_acc": [
59+
"room_size_estimation_single",
60+
"room_size_estimation_multiple",
61+
],
62+
}
63+
64+
65+
config = load_default_template_yaml(__file__)
66+
cache_dir = resolve_cache_dir(config["dataset_kwargs"]["cache_dir"], base_dir=HF_HOME)
67+
68+
69+
def revsi_doc_to_visual(doc):
70+
video_path = os.path.join(cache_dir, f"{doc['num_frames']}_frame", f"{doc['scene_id']}.mp4")
71+
if not os.path.exists(video_path):
72+
raise FileExistsError(f"video path:{video_path} does not exist.")
73+
return [video_path]
74+
75+
76+
def revsi_doc_to_text(doc, lmms_eval_specific_kwargs=None):
77+
question = doc["question"]
78+
pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
79+
if doc["question_type"] in NQ_QUESTION_TYPES:
80+
post_prompt = lmms_eval_specific_kwargs.get("nq_post_prompt", "")
81+
return "\n".join([pre_prompt, question, post_prompt]).strip()
82+
elif doc["question_type"] in MCQ_QUESTION_TYPES:
83+
options = "Options:\n" + "\n".join(doc["options"])
84+
post_prompt = lmms_eval_specific_kwargs.get("mcq_post_prompt", "")
85+
return "\n".join([pre_prompt, question, options, post_prompt]).strip()
86+
87+
88+
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
89+
if os.getenv("LMMS_EVAL_SHUFFLE_DOCS", None):
90+
return dataset.shuffle(seed=42)
91+
return dataset
92+
93+
94+
def _mean_relative_accuracy(pred, target, start, end, interval):
95+
num_pts = (end - start) / interval + 2
96+
conf_intervs = np.linspace(start, end, int(num_pts))
97+
acc = (abs(pred - target) / target) <= (1 - conf_intervs)
98+
return acc.mean()
99+
100+
101+
def revsi_process_results(doc, results):
102+
pred_answer = str(results[0]).strip().split(" ")[0].rstrip(".").strip()
103+
gt_answer = doc["ground_truth"]
104+
if doc["question_type"] in MCQ_QUESTION_TYPES:
105+
acc = 1.0 if pred_answer.lower() == gt_answer.lower() else 0.0
106+
elif doc["question_type"] in NQ_QUESTION_TYPES:
107+
try:
108+
acc = _mean_relative_accuracy(float(pred_answer), float(gt_answer), 0.5, 0.95, 0.05)
109+
except:
110+
acc = 0.0
111+
doc["acc"] = acc
112+
return {metric: doc for metric in REVSI_METRICS}
113+
114+
115+
def _collapse_question_types(output, metric_name, question_types):
116+
question_type_metrics = [
117+
f"{question_type}_acc" for question_type in question_types if f"{question_type}_acc" in output
118+
]
119+
if not question_type_metrics:
120+
return
121+
output[metric_name] = np.mean([output.pop(metric) for metric in question_type_metrics])
122+
123+
124+
def _compute_all_subscores(results) -> dict:
125+
results = pd.DataFrame(results)
126+
output = {
127+
f"{question_type}_acc": per_question_type["acc"].mean()
128+
for question_type, per_question_type in results.groupby("question_type")
129+
}
130+
131+
for metric_name, question_types in COMPOSITE_METRICS.items():
132+
_collapse_question_types(output, metric_name, question_types)
133+
134+
output["overall_acc"] = sum(output.values()) / len(output) if output else 0.0
135+
return output
136+
137+
138+
def _aggregate_metric(results, metric_name):
139+
return _compute_all_subscores(results).get(metric_name, 0.0)
140+
141+
142+
def revsi_aggregate_overall(results):
143+
return _aggregate_metric(results, "overall_acc")
144+
145+
146+
def revsi_aggregate_object_abs_distance_acc(results):
147+
return _aggregate_metric(results, "object_abs_distance_acc")
148+
149+
150+
def revsi_aggregate_object_counting_acc(results):
151+
return _aggregate_metric(results, "object_counting_acc")
152+
153+
154+
def revsi_aggregate_object_rel_direction_acc(results):
155+
return _aggregate_metric(results, "object_rel_direction_acc")
156+
157+
158+
def revsi_aggregate_object_rel_distance_acc(results):
159+
return _aggregate_metric(results, "object_rel_distance_acc")
160+
161+
162+
def revsi_aggregate_object_size_estimation_acc(results):
163+
return _aggregate_metric(results, "object_size_estimation_acc")
164+
165+
166+
def revsi_aggregate_room_size_estimation_acc(results):
167+
return _aggregate_metric(results, "room_size_estimation_acc")
168+
169+
170+
def revsi_aggregate_route_planning_acc(results):
171+
return _aggregate_metric(results, "route_planning_acc")

0 commit comments

Comments
 (0)