26 changes: 13 additions & 13 deletions qlib/data/dataset/__init__.py
@@ -226,23 +226,20 @@ def prepare(
------
NotImplementedError:
"""
logger = get_module_logger("DatasetH")
seg_kwargs = {"col_set": col_set}
seg_kwargs = {"col_set": col_set, "data_key": data_key}
seg_kwargs.update(kwargs)
if "data_key" in getfullargspec(self.handler.fetch).args:
seg_kwargs["data_key"] = data_key
else:
logger.info(f"data_key[{data_key}] is ignored.")

# Conflicts may happen here
# - The fetched data and the segment key may both be strings
# To resolve the conflict
# - The segment name has higher priority

# 1) Use it as a segment name first
# 1.1) directly fetch a split like "train", "valid", "test"
if isinstance(segments, str) and segments in self.segments:
return self._prepare_seg(self.segments[segments], **seg_kwargs)

# 1.2) fetch multiple splits like ["train", "valid"] or ["train", "valid", "test"]
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]

@@ -262,7 +259,7 @@ def get_max_time(segments):
def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp):
"""it will act like sort and return the max value or None"""
candidate = None
for k, seg in segments.items():
for _, seg in segments.items():
point = seg[idx]
if point is None:
# None indicates unbounded, return directly
@@ -376,6 +373,8 @@ def __init__(
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data. Its index order is <"datetime", "instrument">
This feature is essential because:
- Some samples should be excluded (e.g., by label-based filtering), but they can't be dropped at the beginning because their features are still needed to build the time-series of later samples.
None:
keep all data

@@ -661,8 +660,9 @@ class TSDatasetH(DatasetH):

DEFAULT_STEP_LEN = 30

def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs):
self.step_len = step_len
self.flt_col = flt_col
super().__init__(**kwargs)

def config(self, **kwargs):
@@ -693,10 +693,10 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
dtype = kwargs.pop("dtype", None)
if not isinstance(slc, slice):
slc = slice(*slc)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete time-series
if (flt_col := kwargs.pop("flt_col", None)) is None:
flt_col = self.flt_col

# TSDatasetH will retrieve more data for complete time-series
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
data = super()._prepare_seg(ext_slice, **kwargs)

@@ -710,8 +710,8 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:

tsds = TSDataSampler(
data=data,
start=start,
end=end,
start=slc.start,
end=slc.stop,
step_len=self.step_len,
dtype=dtype,
flt_data=flt_data,
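A minimal usage sketch of the new `flt_col` plumbing in `TSDatasetH` (the handler, segment dates, and the `"filter"` column name below are hypothetical, not part of this diff). The filter column can now be set once at construction, and a per-call `flt_col` passed through `prepare(...)` still takes priority:

```python
from qlib.data.dataset import TSDatasetH

# `my_handler` is assumed to be an already-configured DataHandlerLP whose
# loaded DataFrame contains a boolean column named "filter".
ds = TSDatasetH(
    handler=my_handler,
    segments={
        "train": ("2008-01-01", "2014-12-31"),
        "valid": ("2015-01-01", "2016-12-31"),
    },
    step_len=20,
    flt_col="filter",  # constructor-level default introduced by this diff
)

train_sampler = ds.prepare("train")                    # falls back to self.flt_col
valid_sampler = ds.prepare("valid", flt_col="filter")  # per-call value wins
```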
101 changes: 71 additions & 30 deletions qlib/data/dataset/handler.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

# coding=utf-8
from abc import abstractmethod
import warnings
from typing import Callable, Union, Tuple, List, Iterator, Optional

@@ -19,9 +20,59 @@
from . import loader as data_loader_module


# TODO: A more general handler interface which does not rely on an internal pd.DataFrame is needed.
class DataHandler(Serializable):
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]


class DataHandlerABC(Serializable):
"""
Interface for data handlers.

This class does not assume the internal data structure of the data handler.
It only defines the interface for external users (DataFrame is the data format of the fetch interface).

In the future, more detailed data handler implementations should be refactored along these guidelines.

A handler covers several components:

- [data loader] -> internal representation of the data -> data preprocessing -> adaptor for the fetch interface
- The workflow that combines them all:
The workflow may be very complicated. DataHandlerLP is one practice, but it can't satisfy all requirements,
so leaving users the flexibility to implement their own workflow is the more reasonable choice.
"""

def __init__(self, *args, **kwargs):
"""
Subclasses should prepare everything needed for fetching here.
"""
super().__init__(*args, **kwargs)

CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column

# data key
DK_R: DATA_KEY_TYPE = "raw"
DK_I: DATA_KEY_TYPE = "infer"
DK_L: DATA_KEY_TYPE = "learn"

@abstractmethod
def fetch(
self,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
data_key: DATA_KEY_TYPE = DK_I,
) -> pd.DataFrame:
pass


class DataHandler(DataHandlerABC):
"""
The motivation of DataHandler:

- It provides an implementation of DataHandlerABC that:
- serves fetch requests from an internal pre-loaded DataFrame;
- loads that DataFrame with a data loader.

The steps to use a handler:
1. initialize the data handler (call `init`).
2. use the data.
@@ -144,16 +195,14 @@ def setup_data(self, enable_cache: bool = False):
self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
# TODO: cache

CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column

def fetch(
self,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
squeeze: bool = False,
proc_func: Callable = None,
proc_func: Optional[Callable] = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -216,6 +265,8 @@ def fetch(
-------
pd.DataFrame.
"""
# DataHandler holds only one dataframe, so data_key is not used.
_ = data_key # avoid linting errors (e.g., unused-argument)
return self._fetch_data(
data_storage=self._data,
selector=selector,
@@ -230,7 +281,7 @@ def _fetch_data(
data_storage,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
):
@@ -261,16 +312,9 @@
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
elif isinstance(data_storage, BaseHandlerStorage):
if not data_storage.is_proc_func_supported():
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
)
else:
data_df = data_storage.fetch(
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
)
if proc_func is not None:
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
else:
raise TypeError(f"data_storage should be pd.DataFrame|BaseHandlerStorage, not {type(data_storage)}")

@@ -282,7 +326,7 @@ def _fetch_data(
data_df = data_df.reset_index(level=level, drop=True)
return data_df

def get_cols(self, col_set=CS_ALL) -> list:
def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list:
"""
get the column names

@@ -336,11 +380,12 @@ def get_range_iterator(
yield cur_date, self.fetch(selector, **kwargs)


DATA_KEY_TYPE = Literal["raw", "infer", "learn"]


class DataHandlerLP(DataHandler):
"""
Motivation:
- For cases where we want to use different processor workflows for learning and inference;


DataHandler with **(L)earnable (P)rocessor**

This handler will produce three pieces of data in pd.DataFrame format.
@@ -374,12 +419,8 @@ class DataHandlerLP(DataHandler):
_infer: pd.DataFrame # data for inference
_learn: pd.DataFrame # data for learning models

# data key
DK_R: DATA_KEY_TYPE = "raw"
DK_I: DATA_KEY_TYPE = "infer"
DK_L: DATA_KEY_TYPE = "learn"
# map data_key to attribute name
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"}

# process type
PTYPE_I = "independent"
@@ -622,7 +663,7 @@ def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):

# TODO: Be able to cache handler data. Save the memory for data processing

def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame:
if data_key == self.DK_R and self.drop_raw:
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
@@ -635,7 +676,7 @@ def fetch(
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: DATA_KEY_TYPE = DK_I,
data_key: DATA_KEY_TYPE = DataHandler.DK_I,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
@@ -669,7 +710,7 @@ def fetch(
proc_func=proc_func,
)

def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list:
"""
get the column names

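To make the new `DataHandlerABC` contract concrete, here is a hedged sketch of a custom handler built directly on it (the `FrameHandler` class and its use of the `fetch_df_by_*` utilities are illustrative, not part of this diff). Any object that can answer `fetch` is now a valid handler, however it stores its data:

```python
import pandas as pd

from qlib.data.dataset.handler import DATA_KEY_TYPE, DataHandlerABC
from qlib.data.dataset.utils import fetch_df_by_col, fetch_df_by_index


class FrameHandler(DataHandlerABC):
    """Hypothetical handler that serves a single pre-built DataFrame."""

    def __init__(self, df: pd.DataFrame):
        self._df = df
        super().__init__()

    def fetch(
        self,
        selector=slice(None, None),
        level="datetime",
        col_set=DataHandlerABC.CS_ALL,
        data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
    ) -> pd.DataFrame:
        _ = data_key  # only one frame here: raw/infer/learn all resolve to the same data
        df = fetch_df_by_col(self._df, col_set)
        return fetch_df_by_index(df, selector, level, fetch_orig=True)
```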
56 changes: 39 additions & 17 deletions qlib/data/dataset/storage.py
@@ -1,8 +1,10 @@
from abc import abstractmethod
import pandas as pd
import numpy as np

from .handler import DataHandler
from typing import Union, List, Callable
from typing import Union, List
from qlib.log import get_module_logger

from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col

@@ -14,14 +16,13 @@ class BaseHandlerStorage:
- Users who want custom data storage should define a subclass that inherits from BaseHandlerStorage and implement the following method
"""

@abstractmethod
def fetch(
self,
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
proc_func: Callable = None,
**kwargs,
) -> pd.DataFrame:
"""fetch data from the data storage

@@ -41,8 +42,6 @@ def fetch(
select several sets of meaningful columns; the returned data has multiple column levels
fetch_orig : bool
Return the original data instead of copy if possible.
proc_func: Callable
please refer to the doc of DataHandler.fetch

Returns
-------
@@ -51,13 +50,40 @@
"""
raise NotImplementedError("fetch is method not implemented!")

@staticmethod
def from_df(df: pd.DataFrame):
raise NotImplementedError("from_df method is not implemented!")

def is_proc_func_supported(self):
"""whether the arg `proc_func` in `fetch` method is supported."""
raise NotImplementedError("is_proc_func_supported method is not implemented!")
class NaiveDFStorage(BaseHandlerStorage):
"""Naive data storage for datahandler
- NaiveDFStorage is a naive data storage for datahandler
- NaiveDFStorage will input a pandas.DataFrame as and provide interface support for fetching data
"""

def __init__(self, df: pd.DataFrame):
self.df = df

def fetch(
self,
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
) -> pd.DataFrame:

# The following conflicts may occur
# - Does ["20200101", "20210101"] mean selecting this slice or these two days?
# To resolve this ambiguity
# - the slice interpretation has higher priority (except when level is None)
if isinstance(selector, (tuple, list)) and level is not None:
# when level is None, the argument will be passed in directly
# we don't have to convert it into slice
try:
selector = slice(*selector)
except ValueError:
get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly")

data_df = self.df
data_df = fetch_df_by_col(data_df, col_set)
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig)
return data_df


class HashingStockStorage(BaseHandlerStorage):
Expand Down Expand Up @@ -142,7 +168,7 @@ def _fetch_hash_df_by_stock(self, selector, level):

def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
fetch_orig: bool = True,
@@ -164,7 +190,3 @@ def fetch(
return fetch_stock_df_list[0]
else:
return pd.concat(fetch_stock_df_list, sort=False, copy=not fetch_orig)

def is_proc_func_supported(self):
"""the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
return False
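A small, self-contained sketch of how the new `NaiveDFStorage` resolves the selector ambiguity handled above (the toy frame is illustrative): a `(start, stop)` pair is converted into a slice by `fetch`, so it selects the whole range rather than just the two endpoint days:

```python
import pandas as pd

from qlib.data.dataset.storage import NaiveDFStorage

idx = pd.MultiIndex.from_product(
    [pd.date_range("2020-01-01", periods=4), ["SH600000", "SZ000001"]],
    names=["datetime", "instrument"],
)
df = pd.DataFrame({"feature": range(8)}, index=idx)

storage = NaiveDFStorage(df)
# Interpreted as slice("2020-01-01", "2020-01-02"), i.e., the full two-day range
sub = storage.fetch(selector=("2020-01-01", "2020-01-02"), level="datetime")
print(sub)
```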
4 changes: 3 additions & 1 deletion qlib/model/trainer.py
@@ -240,7 +240,9 @@ def __init__(
self.train_func = train_func
self._call_in_subproc = call_in_subproc

def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
def train(
self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs
) -> List[Recorder]:
"""
Given a list of `tasks`, return a list of trained Recorders. The order of the results is guaranteed.

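For completeness, a hedged usage sketch of the loosened `train` signature (assuming this class is `TrainerR`; the experiment name and task configs are placeholders, and qlib is assumed to be initialized):

```python
from qlib.model.trainer import TrainerR

# task_a / task_b are assumed to be qlib task config dicts (model + dataset + record)
trainer = TrainerR(experiment_name="my_exp")
recorders = trainer.train(tasks=[task_a, task_b])  # recorder order matches task order
```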