Skip to content

Commit b1149e7

Browse files
committed
feat: Added frequency-aware one-hot encoding for dynamic frequencies like business days.
1 parent aa6d1c9 commit b1149e7

2 files changed

Lines changed: 107 additions & 94 deletions

File tree

darts/tests/utils/test_timeseries_generation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,14 @@ def test_datetime_attribute_timeseries_wrong_args(self):
451451
"day",
452452
to_offset("B"),
453453
pd.Timestamp(year=2025, month=1, day=1),
454-
ValueError,
455-
id="business_day_value_error",
454+
np.arange(31),
455+
id="day_business_daily",
456+
),
457+
pytest.param(
458+
"nanosecond",
459+
to_offset("999999ns"),
460+
pd.Timestamp(year=2000, month=1, day=1),
461+
np.arange(1000),
456462
),
457463
],
458464
)

darts/utils/timeseries_generation.py

Lines changed: 99 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import holidays
1111
import numpy as np
1212
import pandas as pd
13+
from pandas.tseries.offsets import Tick
1314

1415
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
1516
from darts.timeseries import (
@@ -57,24 +58,30 @@
5758
"weekofyear": 52 + 1,
5859
"week_of_year": 52 + 1,
5960
}
60-
PERIOD_BY_ATTRIBTUE = {
61-
"month": pd.Timedelta(days=366),
62-
"day": pd.Timedelta(days=31),
63-
"weekday": pd.Timedelta(days=7),
64-
"dayofweek": pd.Timedelta(days=7),
65-
"day_of_week": pd.Timedelta(days=7),
66-
"hour": pd.Timedelta(hours=24),
67-
"minute": pd.Timedelta(minutes=60),
68-
"second": pd.Timedelta(seconds=60),
69-
"microsecond": pd.Timedelta(microseconds=1000000),
70-
"nanosecond": pd.Timedelta(nanoseconds=1000),
71-
"quarter": pd.Timedelta(days=366), # approx
72-
"dayofyear": pd.Timedelta(days=366),
73-
"day_of_year": pd.Timedelta(days=366),
74-
"week": pd.Timedelta(weeks=53),
75-
"weekofyear": pd.Timedelta(weeks=53),
76-
"week_of_year": pd.Timedelta(weeks=53),
61+
FULL_CALENDAR_CYCLE = pd.Timedelta(days=365 * 28 + 7) # ~28 years
62+
"""The solar calendar cycle (https://en.wikipedia.org/wiki/Solar_cycle_(calendar)) of the Julian calendar."""
63+
64+
MAX_GENERATION_STEPS = 100000
65+
"""Threshold to prevent generating excessively large arrays when calculating unique datetime attribute values."""
66+
67+
ATTRIBUTE_PERIODS = {
68+
"microsecond": pd.Timedelta("1s"),
69+
"nanosecond": pd.Timedelta("1us"),
70+
"second": pd.Timedelta("1min"),
71+
"minute": pd.Timedelta("1h"),
72+
"hour": pd.Timedelta("1D"),
73+
"weekday": pd.Timedelta("1W"),
74+
"day_of_week": pd.Timedelta("1W"),
75+
"day": FULL_CALENDAR_CYCLE,
76+
"month": FULL_CALENDAR_CYCLE,
77+
"dayofyear": FULL_CALENDAR_CYCLE,
78+
"week": FULL_CALENDAR_CYCLE,
7779
}
80+
"""The time it takes for an attribute to naturally reset/wrap around.
81+
82+
For example, minutes wrap around every hour, hours wrap around every day, etc.
83+
"""
84+
7885
DATETIME_ATT_WITH_VARIABLE_MAX = [
7986
"day",
8087
"dayofyear",
@@ -83,6 +90,7 @@
8390
"weekofyear",
8491
"week_of_year",
8592
]
93+
"""Time index attributes whose maximum value varies (e.g., day of month (28, 29, 30 or 31), week of year (52 or 53))."""
8694

8795

8896
def constant_timeseries(
@@ -685,8 +693,8 @@ def _timedelta_lcm(td1: pd.Timedelta, td2: pd.Timedelta) -> pd.Timedelta:
685693

686694

687695
def unique_datetime_value_freq_aware(
688-
attribute: str, freq: pd.tseries.offsets.BaseOffset, start: pd.Timestamp
689-
) -> np.ndarray[int]:
696+
attribute: str, freq: Union[str, pd.tseries.offsets.BaseOffset], start: pd.Timestamp
697+
) -> np.ndarray[tuple[int], int]:
690698
"""Returns a sorted array of unique values that the given datetime attribute can take, based on `freq` and `start`.
691699
692700
Parameters
@@ -702,18 +710,27 @@ def unique_datetime_value_freq_aware(
702710
703711
Returns
704712
-------
705-
np.ndarray[int]
706713
Sorted array of all the unique values that the given datetime attribute can take.
707714
708715
See Also
709716
--------
710717
unique_datetime_values: When all possible values for the attribute are to be returned.
711718
712-
Warnings
713-
--------
714-
For attributes with a variable number of maximum values (day, dayofyear, day_of_year, week, weekofyear,
715-
week_of_year), this function will return all possible values as fallback, since actually computing the values
716-
would be inefficient.
719+
Notes
720+
-----
721+
This function determines unique values using one of three strategies:
722+
723+
1. **Exact Synchronization:** For fixed frequencies, it simulates the exact period where the frequency and attribute
724+
cycle align (LCM).
725+
* *Example:* ``attribute="hour", freq="2H"`` -> Returns even hours ``[0, 2, ..., 22]``.
726+
727+
2. **Calendar Simulation:** For variable frequencies (e.g., Business Days), it simulates a 28-year cycle to
728+
guarantee capturing leap years and weekday shifts.
729+
* *Example:* ``attribute="day", freq="B"`` -> Returns ``[1..31]`` (ensures Feb 29th is eventually captured).
730+
731+
3. **Heuristic Fallback:** If the simulation requires generating an excessive number of points (e.g., high-frequency
732+
data for low-frequency attributes), it assumes all theoretically possible values occur.
733+
* *Example:* ``attribute="month", freq="1min"`` -> Returns ``[1..12]`` immediately to save memory.
717734
718735
Examples
719736
--------
@@ -725,75 +742,65 @@ def unique_datetime_value_freq_aware(
725742
>>> unique_datetime_value_freq_aware("minute", "15min", pd.Timestamp("2020-01-01"))
726743
array([0, 15, 30, 45])
727744
"""
728-
raise_if_not(
729-
attribute in MAX_DATETIME_VALUES,
730-
f"Can't determine unique values for attribute `{attribute}`, required for cyclic and one-hot encodings. "
731-
f"Supported datetime attribute: {list(MAX_DATETIME_VALUES.keys())}",
732-
logger,
733-
)
734-
# Common frequencies, which are not convertable to pd.Timedelta
735-
fixed_yearly = {
736-
"month",
737-
"day",
738-
"hour",
739-
"minute",
740-
"second",
741-
"microsecond",
742-
"nanosecond",
743-
"quarter",
744-
}
745-
fixed_monthly = {"day", "hour", "minute", "second", "microsecond", "nanosecond"}
746-
fixed_attributes = {
747-
pd.tseries.offsets.YearBegin: fixed_yearly,
748-
pd.tseries.offsets.YearEnd: fixed_yearly,
749-
pd.tseries.offsets.MonthBegin: fixed_monthly,
750-
pd.tseries.offsets.MonthEnd: fixed_monthly,
751-
}
752-
if type(freq) in fixed_attributes:
753-
if attribute in fixed_attributes[type(freq)]:
754-
val = np.array([getattr(start, attribute)])
755-
if attribute in ONE_INDEXED_FREQS:
756-
val -= 1
757-
return val
758-
else:
759-
return unique_datetime_values(attribute)
760-
# Handle other frequencies
761-
freq_delta = None
745+
# 1. Get the Natural Period of the attribute (~28 years as safe default)
746+
natural_period = ATTRIBUTE_PERIODS.get(attribute, FULL_CALENDAR_CYCLE)
747+
748+
# 2. Try to convert frequency to Timedelta
749+
freq_td: Optional[pd.Timedelta] = None
762750
try:
763-
freq_delta = pd.Timedelta(freq.freqstr)
764-
except ValueError as e:
765-
if e.args and "unit abbreviation w/o a number" in e.args[0]:
766-
try:
767-
freq_delta = pd.Timedelta(1, unit=freq.freqstr)
768-
except ValueError:
769-
pass
770-
finally:
771-
if freq_delta is None:
772-
raise_log(
773-
ValueError(
774-
f"Can't convert freq `{freq.freqstr}` to pd.Timedelta, required for computing unique values for "
775-
f"attribute `{attribute}`. Please provide a frequency that can be converted to pd.Timedelta, "
776-
f"e.g. '15min', '1H', '3D', '1W'. Alternatively, use a frequency unaware encoding or omit the "
777-
"attribute."
778-
),
779-
logger,
780-
)
781-
if attribute in DATETIME_ATT_WITH_VARIABLE_MAX:
782-
# For these attributes, periods must be really long to capture all possible values
783-
#
784-
logger.warning(
785-
"Finding unique values for attribute `%s` based on frequency uses all possible values as fallback.",
786-
attribute,
787-
)
788-
return unique_datetime_values(attribute)
789-
lcm = _timedelta_lcm(freq_delta, PERIOD_BY_ATTRIBTUE[attribute])
790-
num_unique = lcm // freq_delta
791-
idx = pd.date_range(start=start, freq=freq_delta, periods=num_unique)
792-
values: pd.Index = _get_datetime_attribute_values(attribute, idx).sort_values()
793-
return values.unique().to_numpy()
751+
offset = pd.tseries.frequencies.to_offset(freq)
752+
if isinstance(offset, Tick):
753+
freq_td = pd.Timedelta(offset)
754+
755+
except (ValueError, TypeError):
756+
# Handle raw strings that to_offset might not like, but to_timedelta might
757+
# e.g., "15min" is fine, but sometimes complex strings fail to_offset
758+
pass
759+
760+
# Fallback: Try direct string-to-timedelta conversion if the above failed
761+
# This handles strings like "10us" if to_offset failed
762+
if freq_td is None:
763+
try:
764+
freq_td = pd.to_timedelta(freq)
765+
except (ValueError, TypeError):
766+
# If this fails, it is truly a variable frequency (e.g. 'M', 'B')
767+
pass
768+
769+
# 3. Dynamic Duration Calculation
770+
if freq_td is not None:
771+
# How long until the Freq and the Attribute Period sync up?
772+
total_duration = _timedelta_lcm(freq_td, natural_period)
773+
# Check how many points this requires
774+
num_points = total_duration // freq_td
775+
776+
# Safety fallback: If the interference pattern requires a large number of points
777+
if num_points > MAX_GENERATION_STEPS:
778+
return unique_datetime_values(attribute)
779+
780+
# Otherwise, simulate exact LCM duration
781+
idx = pd.date_range(start=start, periods=num_points, freq=freq_td)
782+
783+
else:
784+
# Variable frequency (e.g. 'BusinessDay')
785+
# We cannot calculate LCM easily. We fallback to the Safe Horizon (28 years).
786+
# 28 Years covers the synchronization of Weekdays, Leap Years, and Days.
787+
end_date = start + FULL_CALENDAR_CYCLE
788+
789+
# Heuristic check for variable freqs:
790+
# If we are doing 'BusinessHour' over 28 years, that is too huge.
791+
# Estimate points: 28 years / rough estimate of freq.
792+
# If freq is unknown, we just run generation with a cap.
793+
idx = pd.date_range(start=start, end=end_date, freq=freq)
794+
795+
if len(idx) > MAX_GENERATION_STEPS:
796+
return unique_datetime_values(attribute)
797+
798+
# 4. Return unique values
799+
values = _get_datetime_attribute_values(attribute, idx)
800+
return np.unique(values).astype(int)
794801

795802

796-
def unique_datetime_values(attribute: str) -> np.ndarray[int]:
803+
def unique_datetime_values(attribute: str) -> np.ndarray[tuple[int], int]:
797804
"""Returns a sorted array of all the unique values that the given datetime attribute can take.
798805
799806
Parameters
@@ -805,7 +812,7 @@ def unique_datetime_values(attribute: str) -> np.ndarray[int]:
805812
806813
Returns
807814
-------
808-
np.ndarray[int]
815+
np.ndarray[tuple[int], int]
809816
Sorted array of all the unique values that the given datetime attribute can take.
810817
811818
See Also

0 commit comments

Comments
 (0)