Skip to content

Commit 75243b0

Browse files
committed
[general] Add strict=True to all zip() calls as per new ruff lint
1 parent e2c11c9 commit 75243b0

File tree

4 files changed

+217
-3
lines changed

4 files changed

+217
-3
lines changed

benchmark/autoshot_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, dataset_dir: str):
1717
self._scene_files = [
1818
file for file in sorted(glob.glob(os.path.join(dataset_dir, "annotations", "*.txt")))
1919
]
20-
for video_file, scene_file in zip(self._video_files, self._scene_files):
20+
for video_file, scene_file in zip(self._video_files, self._scene_files, strict=True):
2121
video_id = os.path.basename(video_file).split(".")[0]
2222
scene_id = os.path.basename(scene_file).split(".")[0]
2323
assert video_id == scene_id

benchmark/bbc_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, dataset_dir: str):
1818
file for file in sorted(glob.glob(os.path.join(dataset_dir, "fixed", "*.txt")))
1919
]
2020
assert len(self._video_files) == len(self._scene_files)
21-
for video_file, scene_file in zip(self._video_files, self._scene_files):
21+
for video_file, scene_file in zip(self._video_files, self._scene_files, strict=True):
2222
video_id = os.path.basename(video_file).replace("bbc_", "").split(".")[0]
2323
scene_id = os.path.basename(scene_file).split("-")[0]
2424
assert video_id == scene_id

scenedetect/detectors/content_detector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ def _calculate_frame_score(self, timecode: FrameTimecode, frame_img: numpy.ndarr
173173
)
174174

175175
frame_score: float = sum(
176-
component * weight for (component, weight) in zip(score_components, self._weights)
176+
component * weight
177+
for (component, weight) in zip(score_components, self._weights, strict=True)
177178
) / sum(abs(weight) for weight in self._weights)
178179

179180
# Record components and frame score if needed for analysis.
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
#
2+
# PySceneDetect: Python-Based Video Scene Detector
3+
# -------------------------------------------------------------------
4+
# [ Site: https://scenedetect.com ]
5+
# [ Docs: https://scenedetect.com/docs/ ]
6+
# [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
7+
#
8+
# Copyright (C) 2014-2024 Brandon Castellano <http://www.bcastell.com>.
9+
# PySceneDetect is licensed under the BSD 3-Clause License; see the
10+
# included LICENSE file, or visit one of the above pages for details.
11+
#
12+
""":class:`TransnetV2Detector` uses a pretrained neural network.
13+
14+
This detector is available from the command-line as the `detect-transnetv2` command.
15+
"""
16+
17+
import typing as ty
18+
import warnings
19+
from enum import Enum
20+
from logging import getLogger
21+
from pathlib import Path
22+
23+
import cv2
24+
import numpy as np
25+
26+
from scenedetect.common import FrameTimecode, Timecode
27+
from scenedetect.detector import FlashFilter, SceneDetector
28+
29+
logger = getLogger("pyscenedetect")
30+
31+
32+
class Detector:
33+
def __init__(self, threshold: float, flash_filter: FlashFilter):
34+
self.i = 0
35+
self.y_prev = 0
36+
self.threshold = threshold
37+
self.flash_filter = flash_filter
38+
39+
def push(self, ys: np.ndarray, ts: np.ndarray):
40+
predictions = (ys > self.threshold).astype(np.uint8)
41+
42+
cuts = []
43+
for y, t in zip(predictions, ts, strict=True):
44+
if self.y_prev == 0 and y == 1 and self.i > 0:
45+
cuts.append(t)
46+
self.y_prev = y
47+
self.i += 1
48+
49+
return cuts
50+
51+
52+
class Predictor:
    """Runs the TransNetV2 ONNX model over overlapping 100-frame windows.

    Each push() receives 100 frames. The previous call's frames are kept so
    consecutive inference windows overlap, and only the central predictions
    (indices 25:75) of each window are forwarded to the edge Detector —
    presumably where the model has full temporal context; confirm against
    the TransNetV2 reference implementation.
    """

    def __init__(
        self,
        model_path: ty.Union[str, Path],
        flash_filter: FlashFilter,
        onnx_providers: ty.Union[ty.List[str], None],
        threshold,
    ):
        # Imported lazily so onnxruntime is only required when this
        # detector is actually used.
        import onnxruntime as ort

        # Severity 3 == errors only; silences onnxruntime warnings.
        ort.set_default_logger_severity(3)

        providers = ort.get_available_providers() if onnx_providers is None else onnx_providers

        options = ort.SessionOptions()
        options.log_severity_level = 3

        self.session = ort.InferenceSession(model_path, sess_opt=options, providers=providers)

        # Previous window of frames/timestamps; None until the first push().
        self.pixels = None
        self.time = None

        self.det = Detector(threshold, flash_filter)

    def _inference(self, pixels: np.ndarray, time: np.ndarray):
        """Run the model on a batch of windows; return detected cut timestamps."""
        pred = np.array(self.session.run(["output"], {"input": pixels}))[0]

        cuts = []
        # Score only the middle 50 frames of each 100-frame window; the
        # 50-frame overlap between windows covers every frame exactly once.
        for batch_idx in range(pred.shape[0]):
            cuts.extend(self.det.push(pred[batch_idx, 25:75, 0], time[batch_idx, 25:75]))
        return cuts

    def push(self, pixels: np.ndarray, time: np.ndarray):
        """Consume the next 100 frames and return any cut timestamps found."""
        if self.pixels is None:
            # First window: no history exists, so pad by repeating frame 0.
            # Batch entry 0 is all-first-frame (warm-up); entry 1 is 25
            # repeats of frame 0 followed by the first 75 real frames.
            self.pixels = pixels
            self.time = time

            warmup_px = np.tile(np.expand_dims(pixels[0], axis=0), (100, 1, 1, 1))
            lead_px = np.concatenate(
                (np.tile(np.expand_dims(pixels[0], axis=0), (25, 1, 1, 1)), pixels[:75]), 0
            )
            warmup_t = np.tile(np.expand_dims(time[0], axis=0), (100,))
            lead_t = np.concatenate(
                (np.tile(np.expand_dims(time[0], axis=0), (25,)), time[:75]), 0
            )
            return self._inference(
                np.stack((warmup_px, lead_px)), np.stack((warmup_t, lead_t))
            )

        prev_px, prev_t = self.pixels, self.time
        self.pixels = pixels
        self.time = time

        # Two overlapping windows straddling the boundary between the
        # previous batch and the current one.
        batch_px = np.stack(
            (
                np.concatenate((prev_px[25:], pixels[:25]), 0),
                np.concatenate((prev_px[75:], pixels[:75]), 0),
            )
        )
        batch_t = np.stack(
            (
                np.concatenate((prev_t[25:], time[:25]), 0),
                np.concatenate((prev_t[75:], time[:75]), 0),
            )
        )
        return self._inference(batch_px, batch_t)
130+
131+
132+
class TransnetV2Detector(SceneDetector):
    """Detects fast cuts using the TransNetV2 neural network (ONNX runtime).

    Frames are downscaled to 48x27 and accumulated into 100-frame batches;
    each full batch is run through the model via `Predictor`, and resulting
    cuts pass through a `FlashFilter` to enforce a minimum scene length.
    """

    def __init__(
        self,
        model_path: ty.Union[str, Path] = "tests/resources/transnetv2.onnx",
        onnx_providers: ty.Union[ty.List[str], None] = None,
        threshold: float = 0.5,
        min_scene_len: int = 15,
        filter_mode: FlashFilter.Mode = FlashFilter.Mode.MERGE,
    ):
        """
        Arguments:
            model_path: Path to the TransNetV2 ONNX model file.
                NOTE(review): default points into the test tree — consider
                requiring an explicit path.
            onnx_providers: Explicit onnxruntime providers, or None to use
                all providers available at runtime.
            threshold: Probability above which a frame is considered a cut.
            min_scene_len: Minimum scene length, in frames, for the filter.
            filter_mode: How the flash filter handles rapid consecutive cuts.
        """
        super().__init__()

        # Double-buffered frame/timestamp storage: while one 100-frame slot
        # fills, the other still holds the previous window for overlap.
        self.px = np.zeros((2, 100, 27, 48, 3), dtype=np.uint8)
        self.time = np.zeros((2, 100), dtype=np.int64)

        # All-zero frame used to pad the final partial batch in post_process().
        self.blank = np.zeros(self.px.shape[2:], dtype=np.uint8)

        self.i = 0  # Index of the next frame within the active buffer.
        self.j = 0  # Which of the two buffers is currently being filled.

        self.predictor = Predictor(
            model_path=model_path,
            flash_filter=FlashFilter(mode=filter_mode, length=min_scene_len),
            onnx_providers=onnx_providers,
            threshold=threshold,
        )
        # TODO(https://scenedetect.com/issue/168): Figure out a better long term plan for handling
        # `min_scene_len` which should be specified in seconds, not frames.
        # NOTE(review): a second FlashFilter is handed to Predictor above but
        # appears unused there — confirm whether one of the two can be removed.
        self._flash_filter = FlashFilter(mode=filter_mode, length=min_scene_len)

    def mk_ft(self, pts: int) -> FrameTimecode:
        """Convert a PTS value into a FrameTimecode using the stream's timing.

        Requires process_frame() to have run at least once so `time_base`
        and the frame rate are known.
        """
        # t = Timecode(pts=pts, time_base=self.time_base)
        t = float(pts * self.time_base)
        return FrameTimecode(t, fps=self._fps)

    def process_frame(
        self, timecode: FrameTimecode, frame_img: np.ndarray
    ) -> ty.List[FrameTimecode]:
        """Process the next frame, returning any newly confirmed cuts."""

        # Cache timing info for mk_ft(); assumed constant for the video.
        self.time_base = timecode.time_base
        # NOTE(review): reads a private attribute of FrameTimecode — verify
        # there is no public accessor for the frame rate.
        self._fps = timecode._rate

        # 48x27 matches the model's expected input resolution (see buffer
        # shape above).
        pixels = cv2.resize(frame_img, (48, 27), interpolation=cv2.INTER_AREA)

        self.px[self.j, self.i] = pixels
        self.time[self.j, self.i] = timecode.pts
        self.i += 1

        if self.i < 100:
            return []

        # Buffer full: run inference, then switch to the other buffer.
        cuts = self.predictor.push(self.px[self.j], self.time[self.j])
        self.j = 1 - self.j
        self.i = 0

        filtered_cuts = []
        for cut in cuts:
            filtered_cuts += self._flash_filter.filter(self.mk_ft(cut), True)
        return filtered_cuts

    def post_process(self, timecode: FrameTimecode) -> ty.List[FrameTimecode]:
        """Flush buffered frames at end of video and return remaining cuts.

        Pads the partially filled buffer — plus one extra all-blank window —
        with blank frames so every real frame passes through the central
        25:75 slice that Predictor actually scores.
        """
        cuts = []

        last_time = timecode.pts
        blank_frame = self.blank[:]

        # Pad the unfilled tail of the active buffer and run it.
        self.px[self.j, self.i :] = blank_frame
        self.time[self.j, self.i :] = last_time
        cuts.extend(self.predictor.push(self.px[self.j], self.time[self.j]))

        self.j = 1 - self.j

        # One extra fully blank window flushes the overlap pipeline.
        self.px[self.j, :] = blank_frame
        self.time[self.j, :] = last_time
        cuts.extend(self.predictor.push(self.px[self.j], self.time[self.j]))

        filtered_cuts = []
        for cut in cuts:
            filtered_cuts += self._flash_filter.filter(self.mk_ft(cut), True)
        return filtered_cuts

0 commit comments

Comments
 (0)