Skip to content

Commit 8a1d774

Browse files
authored
add Extension mechanism and PET SMoS implementation (#34)
1 parent f6f68fd commit 8a1d774

23 files changed

Lines changed: 331 additions & 34 deletions

File tree

doc/source/reference/pandas/pose.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ Pose
77

88
A ``Pose`` specialized for trajectory data representing the state of a traffic participant for a specific point in time.
99

10-
10+
.. inheritance-diagram:: tasi.Pose
11+
:top-classes: pandas.core.frame.DataFrame
12+
:parts: 1
1113

1214
Serialization / IO / conversion
1315
************************************

doc/source/reference/pandas/trajectory.rst

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ Trajectory
55

66
.. currentmodule:: tasi
77

8-
A :py:class:`Trajectory` specialized for trajectory data representing the evolution of a traffic participant. A
9-
trajectory is a collection of :py:class:`Pose`s.
10-
8+
A :py:class:`Trajectory` specialized for trajectory data representing the
9+
evolution of a traffic participant. A trajectory is a collection of
10+
:py:class:`Pose`s.
1111
12+
.. inheritance-diagram:: tasi.Trajectory
13+
:top-classes: pandas.core.frame.DataFrame
14+
:parts: 1
1215
1316
Serialization / IO / conversion
1417
************************************
@@ -26,7 +29,7 @@ Indexing
2629
:toctree: api/
2730
2831
Trajectory.att
29-
TrajectoryDataset.trajectory
32+
3033
3134
Filtering
3235
***********
@@ -54,7 +57,7 @@ Attributes
5457
Trajectory.id
5558
Trajectory.interval
5659
Trajectory.timestamps
57-
60+
Trajectory.smos
5861
5962
.. note::
6063

doc/source/reference/pandas/trajectorydataset.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ TrajectoryDataset
77

88
A ``Dataset`` specialized for trajectory data.
99

10+
.. inheritance-diagram:: tasi.TrajectoryDataset
11+
:top-classes: pandas.core.frame.DataFrame, tasi.dataset.Dataset
12+
:parts: 1
13+
1014

1115
Constructor
1216
************

doc/source/reference/pydantic/geopose.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ Factory methods
2020
:toctree: api/
2121

2222
GeoPosePublic.from_orm
23-
GeoPosePublic.from_tasi
2423
GeoPosePublic.from_pose
2524
GeoPosePublic.model_validate
2625

tasi/base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ class IndexMixin:
6363
ID_COLUMN = "id"
6464

6565
@property
66-
def ids(self) -> pd.Index:
66+
def ids(self) -> List[int]:
6767
"""Returns the unique ids in the dataset
6868
6969
Returns:
7070
np.ndarray: A List of ids
7171
"""
7272
idx: pd.MultiIndex = self.index # type: ignore
7373

74-
return idx.get_level_values(self.ID_COLUMN).unique()
74+
return idx.get_level_values(self.ID_COLUMN).unique() # type: ignore
7575

7676

7777
class TASIBase(LocatableEntity):
@@ -231,6 +231,18 @@ def to_csv(
231231
finally:
232232
self.columns = old_columns
233233

234+
@property
235+
def woid(self) -> pd.DataFrame:
236+
"""Without the object id on the index
237+
238+
This property will return a `DataFrame` without the `id` on the index
239+
that can be used to compare trajectories with each other.
240+
241+
Returns:
242+
pd.DataFrame: A DataFrame with the `id` as a column.
243+
"""
244+
return self.reset_index("id")
245+
234246

235247
class PoseCollectionBase(CollectionBase):
236248

tasi/calculus.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from tasi.utils import has_extra
1+
from .utils import has_extra
22

33
EXTRA = has_extra("performance")
44

5-
65
if EXTRA:
76
from numba import jit, njit # pyright: ignore[reportMissingImports]
87
else:

tasi/dataset/base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import wraps
2-
from typing import Iterable, List, Tuple, Union
2+
from typing import Iterable, List, Tuple, Union, overload
33

44
import numpy as np
55
import pandas as pd
@@ -54,7 +54,15 @@ def _pose_constructor(self):
5454

5555
return Pose
5656

57-
def trajectory(self, index: Union[int, Iterable[int]], inverse: bool = False):
57+
@overload
58+
def trajectory(self, index: int, inverse: bool = False) -> Trajectory: ...
59+
60+
@overload
61+
def trajectory(self, index: Iterable[int], inverse: bool = False) -> Self: ...
62+
63+
def trajectory(
64+
self, index: Union[int, Iterable[int]], inverse: bool = False
65+
) -> Trajectory | Self:
5866
"""
5967
Select trajectory data for specific indices, or exclude them if inverse is set to True.
6068
@@ -67,7 +75,7 @@ def trajectory(self, index: Union[int, Iterable[int]], inverse: bool = False):
6775
are excluded from the resulting dataset, and all other trajectories are included. Defaults to False.
6876
6977
Returns:
70-
TrajectoryDataset: A trajectory or multiple trajectories of the dataset.
78+
tasi.Trajectory | TrajectoryDataset: A trajectory or multiple trajectories of the dataset.
7179
"""
7280

7381
if isinstance(index, (int, np.int_)):

tasi/dlr/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from typing import List, Tuple, Union
99

10+
import numpy as np
1011
import pandas as pd
1112
import requests
1213
from tqdm import tqdm
@@ -458,7 +459,7 @@ def to_tasi(self) -> TrajectoryDataset:
458459
position=self.position,
459460
velocity=self.velocity,
460461
acceleration=self.acceleration,
461-
heading=self.yaw,
462+
heading=np.deg2rad(self.yaw),
462463
classifications=self.classifications,
463464
dimension=self.dimension,
464465
)

tasi/extensions/__init__.py

Whitespace-only changes.

tasi/extensions/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Generic, TypeVar
2+
3+
from tasi import Pose, Trajectory
4+
from tasi.pose.base import CollectionBase
5+
6+
A = TypeVar("A", bound="CollectionBase", covariant=True)
7+
8+
9+
class ExtensionBase(Generic[A]):
10+
11+
def __init__(self, obj: A, *args, **kwargs) -> None:
12+
self.obj: A = obj
13+
14+
super().__init__(*args, **kwargs)
15+
16+
17+
class PoseExtensionBase(ExtensionBase[Pose]): ...
18+
19+
20+
class TrajectoryExtensionBase(ExtensionBase[Trajectory]): ...

0 commit comments

Comments
 (0)