@@ -37,6 +37,7 @@ def __init__(
3737 to_coordinate_system : Optional [CoordinateSystem ] = None ,
3838 to_pitch_dimensions : Optional [PitchDimensions ] = None ,
3939 to_orientation : Optional [Orientation ] = None ,
40+ overlay_teams : bool = False ,
4041 ):
4142 if (
4243 from_pitch_dimensions
@@ -73,9 +74,13 @@ def __init__(
7374 "You must specify the source CoordinateSystem when specifying the target CoordinateSystem"
7475 )
7576 self ._to_pitch_dimensions = to_coordinate_system .pitch_dimensions
77+ else :
78+ self ._to_pitch_dimensions = self ._from_pitch_dimensions
7679
7780 self ._from_orientation = from_orientation
7881 self ._to_orientation = to_orientation
82+
83+ self ._overlay_teams = overlay_teams
7984 if (
8085 from_orientation
8186 and not to_orientation
@@ -185,6 +190,9 @@ def transform_frame(self, frame: Frame) -> Frame:
185190 elif self ._needs_pitch_dimensions_change :
186191 frame = self .__change_frame_dimensions (frame )
187192
193+ elif self ._overlay_teams :
194+ frame = self .transform_frame_overlay_teams (frame )
195+
188196 # Flip frame based on orientation
189197 if self ._needs_orientation_change :
190198 if self .__needs_flip (
@@ -308,6 +316,55 @@ def __flip_frame(self, frame: Frame):
308316 statistics = frame .statistics ,
309317 )
310318
319+ def _get_overlay_players_coordinates (
320+ self ,
321+ player_data : PlayerData ,
322+ player_team : Team ,
323+ ball_owning_team : Team ,
324+ attacking_direction : AttackingDirection ,
325+ ):
326+ if attacking_direction == AttackingDirection .RTL :
327+ if player_team != ball_owning_team :
328+ player_data .coordinates = self .flip_point (
329+ player_data .coordinates
330+ )
331+ else :
332+ if player_team == ball_owning_team :
333+ player_data .coordinates = self .flip_point (
334+ player_data .coordinates
335+ )
336+
337+ return player_data
338+
339+ def transform_frame_overlay_teams (self , frame : Frame ):
340+ players_data = {
341+ player : self ._get_overlay_players_coordinates (
342+ player_data ,
343+ player .team ,
344+ frame .ball_owning_team ,
345+ frame .attacking_direction ,
346+ )
347+ for player , player_data in frame .players_data .items ()
348+ }
349+
350+ ball_coordinates = frame .ball_coordinates
351+ if frame .attacking_direction != AttackingDirection .RTL :
352+ ball_coordinates = self .flip_point (ball_coordinates )
353+
354+ return Frame (
355+ # doesn't change
356+ timestamp = frame .timestamp ,
357+ frame_id = frame .frame_id ,
358+ ball_owning_team = frame .ball_owning_team ,
359+ ball_state = frame .ball_state ,
360+ period = frame .period ,
361+ other_data = frame .other_data ,
362+ statistics = frame .statistics ,
363+ # changes
364+ ball_coordinates = ball_coordinates ,
365+ players_data = players_data ,
366+ )
367+
311368 def transform_event (self , event : Event ) -> Event :
312369 # Change coordinate system
313370 if self ._needs_coordinate_system_change :
@@ -375,11 +432,13 @@ def transform_dataset(
375432 to_pitch_dimensions : Optional [PitchDimensions ] = None ,
376433 to_orientation : Optional [Orientation ] = None ,
377434 to_coordinate_system : Optional [CoordinateSystem ] = None ,
435+ overlay_teams : bool = False ,
378436 ) -> Dataset :
379437 if (
380438 to_pitch_dimensions is None
381439 and to_orientation is None
382440 and to_coordinate_system is None
441+ and overlay_teams is False
383442 ):
384443 return dataset
385444
@@ -391,8 +450,20 @@ def transform_dataset(
391450 "Cannot transform to BALL_OWNING_TEAM orientation when "
392451 "dataset doesn't contain ball owning team data"
393452 )
394-
395- if to_pitch_dimensions is not None :
453+ if overlay_teams :
454+ transformer = cls (
455+ from_pitch_dimensions = dataset .metadata .pitch_dimensions ,
456+ from_orientation = dataset .metadata .orientation ,
457+ to_orientation = to_orientation ,
458+ to_pitch_dimensions = to_pitch_dimensions ,
459+ overlay_teams = overlay_teams ,
460+ )
461+ metadata = replace (
462+ dataset .metadata ,
463+ pitch_dimensions = to_pitch_dimensions ,
464+ orientation = to_orientation ,
465+ )
466+ elif to_pitch_dimensions is not None :
396467 # Transform the pitch dimensions and optionally the orientation
397468 transformer = cls (
398469 from_pitch_dimensions = dataset .metadata .pitch_dimensions ,
@@ -418,7 +489,6 @@ def transform_dataset(
418489 dataset .metadata ,
419490 coordinate_system = to_coordinate_system ,
420491 pitch_dimensions = to_coordinate_system .pitch_dimensions ,
421- orientation = to_orientation ,
422492 )
423493
424494 else :
0 commit comments