Skip to content

Commit 97ad27e

Browse files
authored
Add statistics attribute to DataRecord (#302)
* add Statistic class * introduce frame factory to handle passing of statistics attribute * add xg, psxg and obv for statsbomb * add xg and psxg for opta * add xg and psxg for wyscout
1 parent 3b40d02 commit 97ad27e

21 files changed

Lines changed: 261 additions & 13 deletions

File tree

kloppy/domain/models/common.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,85 @@ class DatasetFlag(Flag):
928928
BALL_STATE = 2
929929

930930

931+
@dataclass
class Statistic(ABC):
    """Base class for a named statistic attached to a data record."""

    # Excluded from the generated ``__init__``; subclasses either assign it
    # in ``__post_init__`` or redeclare it as a regular init field.
    name: str = field(init=False)
934+
935+
936+
@dataclass
class ScalarStatistic(Statistic):
    """A statistic that is fully described by a single float ``value``."""

    value: float
939+
940+
941+
@dataclass
class ExpectedGoals(ScalarStatistic):
    """Expected goals (xG) scalar statistic."""

    def __post_init__(self):
        # ``name`` is not an init argument, so it is fixed here.
        self.name = "xG"
947+
948+
949+
@dataclass
class PostShotExpectedGoals(ScalarStatistic):
    """Post-shot expected goals (PSxG) scalar statistic."""

    def __post_init__(self):
        # ``name`` is not an init argument, so it is fixed here.
        self.name = "PSxG"
955+
956+
957+
@dataclass
class ActionValue(Statistic):
    """
    Action value statistic built from before/after scoring and conceding
    values.

    Unlike the scalar statistics, ``name`` is a regular init argument here.
    Each of the four component values is optional; every derived property
    returns ``None`` when one of the components it needs is missing.
    """

    name: str
    action_value_scoring_before: Optional[float] = None
    action_value_scoring_after: Optional[float] = None
    action_value_conceding_before: Optional[float] = None
    action_value_conceding_after: Optional[float] = None

    @property
    def offensive_value(self) -> Optional[float]:
        """Change in the scoring value (after minus before), or ``None``."""
        before = self.action_value_scoring_before
        after = self.action_value_scoring_after
        if None in (before, after):
            return None
        return after - before

    @property
    def defensive_value(self) -> Optional[float]:
        """Change in the conceding value (after minus before), or ``None``."""
        before = self.action_value_conceding_before
        after = self.action_value_conceding_after
        if None in (before, after):
            return None
        return after - before

    @property
    def value(self) -> Optional[float]:
        """Scoring change minus conceding change, or ``None`` if incomplete."""
        scoring_before = self.action_value_scoring_before
        scoring_after = self.action_value_scoring_after
        conceding_before = self.action_value_conceding_before
        conceding_after = self.action_value_conceding_after

        components = (
            scoring_before,
            scoring_after,
            conceding_before,
            conceding_after,
        )
        if None in components:
            return None

        scoring_delta = scoring_after - scoring_before
        conceding_delta = conceding_after - conceding_before
        return scoring_delta - conceding_delta
1008+
1009+
9311010
@dataclass
9321011
class DataRecord(ABC):
9331012
"""
@@ -945,6 +1024,7 @@ class DataRecord(ABC):
9451024
next_record: Optional["DataRecord"] = field(init=False)
9461025
period: Period
9471026
timestamp: timedelta
1027+
statistics: List[Statistic]
9481028
ball_owning_team: Optional[Team]
9491029
ball_state: Optional[BallState]
9501030

kloppy/domain/services/event_factory.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def create_event(event_cls: Type[T], **kwargs) -> T:
4848
if "freeze_frame" not in kwargs:
4949
kwargs["freeze_frame"] = None
5050

51+
if "statistics" not in kwargs:
52+
kwargs["statistics"] = []
53+
5154
all_kwargs = dict(**kwargs, **extra_kwargs)
5255

5356
relevant_kwargs = {
@@ -66,7 +69,9 @@ def create_event(event_cls: Type[T], **kwargs) -> T:
6669
f"The following arguments were skipped: {skipped_kwargs}"
6770
)
6871

69-
return event_cls(**relevant_kwargs)
72+
event = event_cls(**relevant_kwargs)
73+
74+
return event
7075

7176

7277
class EventFactory:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import dataclasses
2+
import warnings
3+
from dataclasses import fields
4+
5+
from kloppy.domain import Frame
6+
7+
8+
def create_frame(**kwargs) -> Frame:
    """
    Do the actual construction of a frame.

    This method does a couple of things:
    1. Fill in some arguments when not passed
       (``statistics`` defaults to an empty list).
    2. Pass only arguments that are accepted by the Frame class.

    A warning is emitted naming every keyword argument that was not
    forwarded to ``Frame``.
    """
    kwargs.setdefault("statistics", [])

    # Keep init fields that were either passed explicitly or carry a plain
    # default. Fields without a plain default that were not passed are left
    # out so the dataclass itself decides what to do with them.
    relevant_kwargs = {
        field.name: kwargs.get(field.name, field.default)
        for field in fields(Frame)
        if field.init
        and (
            field.name in kwargs
            or field.default is not dataclasses.MISSING
        )
    }

    # Compare the key sets directly: relevant_kwargs may contain defaulted
    # fields that are absent from kwargs, so comparing lengths could hide
    # arguments that were dropped.
    skipped_kwargs = set(kwargs) - set(relevant_kwargs)
    if skipped_kwargs:
        warnings.warn(
            f"The following arguments were skipped: {skipped_kwargs}"
        )

    return Frame(**relevant_kwargs)

kloppy/domain/services/transformers/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def __change_frame_coordinate_system(self, frame: Frame):
220220
for key, player_data in frame.players_data.items()
221221
},
222222
other_data=frame.other_data,
223+
statistics=frame.statistics,
223224
)
224225

225226
def __change_frame_dimensions(self, frame: Frame):
@@ -246,6 +247,7 @@ def __change_frame_dimensions(self, frame: Frame):
246247
for key, player_data in frame.players_data.items()
247248
},
248249
other_data=frame.other_data,
250+
statistics=frame.statistics,
249251
)
250252

251253
def __change_point_coordinate_system(
@@ -303,6 +305,7 @@ def __flip_frame(self, frame: Frame):
303305
ball_coordinates=self.flip_point(frame.ball_coordinates),
304306
players_data=players_data,
305307
other_data=frame.other_data,
308+
statistics=frame.statistics,
306309
)
307310

308311
def transform_event(self, event: Event) -> Event:

kloppy/infra/serializers/code/sportscode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset:
6868
labels=parse_labels(instance),
6969
ball_state=None,
7070
ball_owning_team=None,
71+
statistics=[],
7172
)
7273
period.end_timestamp = end_timestamp
7374
codes.append(code)

kloppy/infra/serializers/event/statsbomb/helpers.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
Player,
1212
PlayerData,
1313
PositionType,
14+
ActionValue,
15+
Provider,
1416
)
17+
from kloppy.domain.services.frame_factory import create_frame
1518
from kloppy.exceptions import DeserializationError
1619

1720

@@ -21,6 +24,25 @@ def parse_str_ts(timestamp: str) -> float:
2124
return timedelta(seconds=int(h) * 3600 + int(m) * 60 + float(s))
2225

2326

27+
def parse_obv_values(raw_event: dict) -> Optional["ActionValue"]:
    """
    Build an "OBV" ActionValue from the OBV fields of a raw event.

    Args:
        raw_event: Raw event dictionary; the ``obv_*`` keys are optional.

    Returns:
        An ``ActionValue`` named ``"OBV"`` populated with every OBV field
        that is present and not ``None``, or ``None`` when the event carries
        no OBV data at all.
    """
    obv_mapping = {
        "obv_for_before": "action_value_scoring_before",
        "obv_against_before": "action_value_conceding_before",
        "obv_for_after": "action_value_scoring_after",
        "obv_against_after": "action_value_conceding_after",
    }
    game_state_values_data = {
        kloppy_name: raw_event[sb_name]
        for sb_name, kloppy_name in obv_mapping.items()
        if raw_event.get(sb_name) is not None
    }

    # Return None explicitly: previously the result variable was only bound
    # when OBV data existed, so events without it raised UnboundLocalError.
    if not game_state_values_data:
        return None

    return ActionValue(name="OBV", **game_state_values_data)
44+
45+
2446
def get_team_by_id(team_id: int, teams: List[Team]) -> Team:
2547
"""Get a team by its id."""
2648
if str(team_id) == teams[0].team_id:
@@ -129,7 +151,7 @@ def get_player_from_freeze_frame(player_data, team, i):
129151
+ event.timestamp.total_seconds() * FREEZE_FRAME_FPS
130152
)
131153

132-
return Frame(
154+
frame = create_frame(
133155
frame_id=frame_id,
134156
ball_coordinates=Point3D(
135157
x=event.coordinates.x, y=event.coordinates.y, z=0
@@ -141,3 +163,5 @@ def get_player_from_freeze_frame(player_data, team, i):
141163
ball_owning_team=event.ball_owning_team,
142164
other_data={"visible_area": visible_area},
143165
)
166+
167+
return frame

kloppy/infra/serializers/event/statsbomb/specification.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@
2727
FormationType,
2828
PositionType,
2929
CounterAttackQualifier,
30+
ExpectedGoals,
31+
PostShotExpectedGoals,
3032
)
3133
from kloppy.exceptions import DeserializationError
3234
from kloppy.infra.serializers.event.statsbomb.helpers import (
3335
parse_str_ts,
3436
get_team_by_id,
3537
get_period_by_id,
3638
parse_coordinates,
39+
parse_obv_values,
3740
)
3841

3942

@@ -293,6 +296,7 @@ def deserialize(self, event_factory: EventFactory) -> List[Event]:
293296
return events
294297

295298
def _parse_generic_kwargs(self) -> Dict:
299+
game_state_value = parse_obv_values(self.raw_event)
296300
return {
297301
"period": self.period,
298302
"timestamp": parse_str_ts(self.raw_event["timestamp"]),
@@ -311,6 +315,7 @@ def _parse_generic_kwargs(self) -> Dict:
311315
),
312316
"related_event_ids": self.raw_event.get("related_events", []),
313317
"raw_event": self.raw_event,
318+
"statistics": [game_state_value] if game_state_value else [],
314319
}
315320

316321
def _create_aerial_won_event(
@@ -548,6 +553,16 @@ def _create_events(
548553
EVENT_TYPE.SHOT, shot_dict
549554
) + _get_body_part_qualifiers(shot_dict)
550555

556+
for statistic_cls, prop_name in {
557+
ExpectedGoals: "statsbomb_xg",
558+
PostShotExpectedGoals: "shot_execution_xg",
559+
}.items():
560+
value = shot_dict.get(prop_name, None)
561+
if value is not None:
562+
generic_event_kwargs["statistics"].append(
563+
statistic_cls(value=value)
564+
)
565+
551566
shot_event = event_factory.build_shot(
552567
result=result,
553568
qualifiers=qualifiers,

kloppy/infra/serializers/event/statsperform/deserializer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
GoalkeeperActionType,
3636
CounterAttackQualifier,
3737
PositionType,
38+
ExpectedGoals,
39+
PostShotExpectedGoals,
3840
)
3941
from kloppy.exceptions import DeserializationError
4042
from kloppy.infra.serializers.event.deserializer import EventDataDeserializer
@@ -152,6 +154,9 @@
152154
EVENT_QUALIFIER_FORMATION_PLAYER_IDS = 30
153155
EVENT_QUALIFIER_FORMATION_PLAYER_POSITIONS = 131
154156

157+
EVENT_QUALIFIER_XG = 321
158+
EVENT_QUALIFIER_POST_SHOT_XG = 322
159+
155160
event_type_names = {
156161
1: "pass",
157162
2: "offside pass",
@@ -403,13 +408,27 @@ def _parse_shot(raw_event: OptaEvent) -> Dict:
403408
y=100 - result_coordinates.y,
404409
)
405410

406-
return dict(
411+
event_info = dict(
407412
coordinates=coordinates,
408413
result=result,
409414
result_coordinates=result_coordinates,
410415
qualifiers=qualifiers,
411416
)
412417

418+
statistics = []
419+
for event_qualifier, statistic in zip(
420+
[EVENT_QUALIFIER_XG, EVENT_QUALIFIER_POST_SHOT_XG],
421+
[ExpectedGoals, PostShotExpectedGoals],
422+
):
423+
xg_value = raw_event.qualifiers.get(event_qualifier)
424+
if xg_value:
425+
statistics.append(statistic(value=float(xg_value)))
426+
427+
if statistics:
428+
event_info["statistics"] = statistics
429+
430+
return event_info
431+
413432

414433
def _parse_goalkeeper_events(raw_event: OptaEvent) -> Dict:
415434
qualifiers = _get_event_qualifiers(raw_event.qualifiers)

kloppy/infra/serializers/event/wyscout/deserializer_v3.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
ShotResult,
3838
TakeOnResult,
3939
Team,
40+
FormationType,
41+
CarryResult,
42+
ExpectedGoals,
43+
PostShotExpectedGoals,
4044
)
4145
from kloppy.exceptions import DeserializationError, DeserializationWarning
4246
from kloppy.utils import performance_logging
@@ -269,10 +273,20 @@ def _parse_shot(raw_event: Dict) -> Dict:
269273
elif raw_event["shot"]["bodyPart"] == "right_foot":
270274
qualifiers.append(BodyPartQualifier(value=BodyPart.RIGHT_FOOT))
271275

276+
statistics = []
277+
for statistic_cls, prop_name in {
278+
ExpectedGoals: "xg",
279+
PostShotExpectedGoals: "postShotXg",
280+
}.items():
281+
value = raw_event["shot"].get(prop_name, None)
282+
if value is not None:
283+
statistics.append(statistic_cls(value=value))
284+
272285
return {
273286
"result": result,
274287
"result_coordinates": _create_shot_result_coordinates(raw_event),
275288
"qualifiers": qualifiers,
289+
"statistics": statistics,
276290
}
277291

278292

kloppy/infra/serializers/tracking/metrica_csv.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Player,
2121
PlayerData,
2222
)
23+
from kloppy.domain.services.frame_factory import create_frame
2324
from kloppy.infra.serializers.tracking.deserializer import (
2425
TrackingDataDeserializer,
2526
)
@@ -188,7 +189,7 @@ def deserialize(
188189
**away_partial_frame.players_data,
189190
}
190191

191-
frame = Frame(
192+
frame = create_frame(
192193
frame_id=frame_id,
193194
timestamp=timedelta(seconds=frame_id / frame_rate)
194195
- period.start_timestamp,

0 commit comments

Comments
 (0)