Ball trajectory prediction module. Given a history of ball positions and racket poses, the model predicts future ball trajectory coordinates.
| Model | Registry name | Description |
|---|---|---|
| LSTM | TPLSTM | Baseline LSTM (ball_only or ball_racket concatenation) |
| CrossLSTM-Attn | TPLSTMAttn | Low-dim cross-attention + gated residual + LSTM (used for released weights) |
Released checkpoints use CrossLSTM-Attn (TPLSTMAttn).
TrajPred/
├── configs/
│ ├── crosslstm_long_badminton.py # CrossLSTM-Attn config (badminton)
│ ├── crosslstm_long_tabletennis.py
│ ├── crosslstm_short_tennis.py
│ └── lstm_toy.py # Toy config for pipeline testing
├── checkpoints/
│ ├── crosslstm_long_badminton.pth
│ ├── crosslstm_long_tabletennis.pth
│ └── crosslstm_short_tennis.pth
├── model/
│ ├── lstm.py # TPLSTM
│ ├── cross_lstm.py # TPLSTMAttn
│ └── loss.py # Weighted MSE loss
├── dataset/
│ └── trajectory.py # BallTraj dataset (loads PKL files)
├── metrics/
│ └── traj_metric.py # ADE / FDE metrics (pixel-space)
├── hooks/
│ └── traj_visualize.py # Trajectory visualisation hook
├── train.py # Training entry point
├── test.py # Evaluation entry point
├── build_dataset.py # Build PKL datasets from predictions
├── linear_interpolate_ball_traj.py # Interpolate ball trajectory gaps
└── merge_gt_with_predictions.py # Merge racket GT with predictions
TrajPred requires pre-processed ball and racket predictions. The full pipeline:
pred_ball/ ──→ linear_interpolate ──→ interp_ball/
pred_racket/ ──→ merge_gt ──→ merged_racket/
interp_ball/ + merged_racket/ ──→ build_dataset ──→ data_traj/*.pkl
Fill short gaps (< 5 frames) in ball prediction CSVs with linear interpolation:
python linear_interpolate_ball_traj.py --data_root ../data --sport badminton

Replace racket predictions with ground-truth annotations where available, creating "soft labels":
python merge_gt_with_predictions.py --data_root ../data --sport badminton

Create sliding-window trajectory samples:
python build_dataset.py \
--data_root ../data \
--sport badminton \
--history 80 --future 20

Output: ../data/data_traj/ball_racket_badminton_h80_f20.pkl
Pre-built PKL files for all sports are provided in ../data/data_traj/.
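
The gap-filling and sliding-window steps above can be illustrated with a minimal pure-Python sketch. It is not the code in linear_interpolate_ball_traj.py or build_dataset.py: the function names, the use of `None` for a missing detection, the `stride` parameter, and the choice to skip windows that still contain gaps are all assumptions made for illustration.

```python
from typing import List, Optional, Tuple

Point = Optional[Tuple[float, float]]  # None marks a frame with no detection (assumed encoding)


def fill_short_gaps(track: List[Point], max_gap: int = 4) -> List[Point]:
    """Linearly interpolate runs of missing frames of length <= max_gap.

    max_gap=4 mirrors the "< 5 frames" rule. Longer gaps, and gaps touching
    the start or end of the track, are left as None.
    """
    filled = list(track)
    n = len(track)
    i = 0
    while i < n:
        if track[i] is not None:
            i += 1
            continue
        start = i  # first missing frame of this run
        while i < n and track[i] is None:
            i += 1
        end = i  # first detected frame after the run, or n
        gap = end - start
        if start == 0 or end == n or gap > max_gap:
            continue
        (x0, y0), (x1, y1) = track[start - 1], track[end]
        for k in range(start, end):
            t = (k - start + 1) / (gap + 1)  # fraction of the way across the gap
            filled[k] = (x0 + t * (x1 - x0), y0 + t * (y1 - y0))
    return filled


def sliding_windows(track: List[Point], history: int = 80, future: int = 20,
                    stride: int = 1):
    """Cut a track into (history, future) pairs of consecutive frames."""
    total = history + future
    samples = []
    for s in range(0, len(track) - total + 1, stride):
        window = track[s:s + total]
        if any(p is None for p in window):
            continue  # assumption: drop windows that still contain gaps
        samples.append((window[:history], window[history:]))
    return samples
```

With `--history 80 --future 20`, each sample in this sketch would span 100 consecutive frames: 80 as model input and 20 as the prediction target.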
conda activate uball
python train.py --cfg configs/crosslstm_long_badminton.py

Training runs both train() and test(). Checkpoints are saved based on best ADE (average displacement error).
Evaluate a trained checkpoint on the test set:
python test.py \
--cfg configs/crosslstm_long_badminton.py \
--ckpt checkpoints/crosslstm_long_badminton.pth

| Metric | Description |
|---|---|
| ADE | Average Displacement Error — mean L2 distance (pixels) across all predicted time steps |
| FDE | Final Displacement Error — L2 distance (pixels) at the last predicted time step |
Coordinates are de-normalised to 1920×1080 before computing metrics.
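
As a rough illustration of the two metrics for a single predicted trajectory, the sketch below de-normalises coordinates and computes ADE and FDE in pixels. It assumes normalised coordinates lie in [0, 1] and are scaled by the frame width and height; the function name `ade_fde` is hypothetical and this is not the code in metrics/traj_metric.py.

```python
import math
from typing import Sequence, Tuple

Coord = Tuple[float, float]


def ade_fde(pred: Sequence[Coord], target: Sequence[Coord],
            width: int = 1920, height: int = 1080) -> Tuple[float, float]:
    """Return (ADE, FDE) in pixels for one predicted trajectory.

    Assumes both sequences hold normalised (x, y) pairs in [0, 1] and
    have the same, non-zero length.
    """
    dists = []
    for (px, py), (tx, ty) in zip(pred, target):
        # De-normalise the per-axis error, then take the L2 distance.
        dists.append(math.hypot((px - tx) * width, (py - ty) * height))
    ade = sum(dists) / len(dists)  # mean over all predicted time steps
    fde = dists[-1]                # error at the last predicted time step
    return ade, fde
```

Averaging these per-trajectory values over the test set would give dataset-level figures, though the exact aggregation used by the repository is not stated here.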
- Embedding: Ball (2D → 64D) and racket (10D → 64D) are projected to a shared low-dimensional space.
- Cross-Attention: Ball features attend to racket features via multi-head attention, with a learnable gating coefficient alpha.
- Projection: Fused features are projected up to the LSTM hidden dimension (512D).
- LSTM: 2-layer LSTM processes the fused sequence.
- Decoder: The last hidden state is mapped to pred_len × 2 coordinates.
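
One plausible way to write the stages above as equations is given below. The additive form of the gated residual, the symbols $W_b, W_r, W_p, W_d$, and the omission of biases are assumptions; the exact formulation is defined in model/cross_lstm.py.

```latex
\begin{aligned}
b_t &= W_b\, x^{\mathrm{ball}}_t, \qquad r_t = W_r\, x^{\mathrm{racket}}_t
    && b_t, r_t \in \mathbb{R}^{64} \\
f_t &= b_t + \alpha \cdot \mathrm{MHA}\!\left(Q = b_t,\; K = r_{1:T},\; V = r_{1:T}\right)
    && \text{gated residual (assumed form)} \\
z_t &= W_p\, f_t
    && z_t \in \mathbb{R}^{512} \\
h_{1:T} &= \mathrm{LSTM}_{2\text{ layers}}\!\left(z_{1:T}\right) \\
\hat{Y} &= \mathrm{reshape}\!\left(W_d\, h_T\right)
    && \hat{Y} \in \mathbb{R}^{\mathrm{pred\_len} \times 2}
\end{aligned}
```

Here $x^{\mathrm{ball}}_t \in \mathbb{R}^{2}$ and $x^{\mathrm{racket}}_t \in \mathbb{R}^{10}$ are the per-frame inputs and $T$ is the history length.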
Loss: Weighted MSE with linearly decaying weights from 1.0 to 0.5 across the prediction horizon.
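
A minimal pure-Python sketch of such a loss for one trajectory follows. The helper names and the normalisation (mean over all pred_len × 2 coordinate values) are assumptions for illustration; model/loss.py may reduce differently.

```python
from typing import List, Optional, Sequence, Tuple

Coord = Tuple[float, float]


def horizon_weights(pred_len: int, start: float = 1.0, end: float = 0.5) -> List[float]:
    """Weights that decay linearly from `start` at step 0 to `end` at the last step."""
    if pred_len == 1:
        return [start]
    step = (end - start) / (pred_len - 1)
    return [start + i * step for i in range(pred_len)]


def weighted_mse(pred: Sequence[Coord], target: Sequence[Coord],
                 weights: Optional[Sequence[float]] = None) -> float:
    """Weighted squared error averaged over all pred_len * 2 coordinate values."""
    if weights is None:
        weights = horizon_weights(len(pred))
    total = 0.0
    for w, (px, py), (tx, ty) in zip(weights, pred, target):
        total += w * ((px - tx) ** 2 + (py - ty) ** 2)
    return total / (2 * len(pred))
```

Under this weighting, an error at the first predicted step costs twice as much as the same error at the last step, so near-term accuracy is emphasised.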