diff --git a/ccflow/examples/tpch/__init__.py b/ccflow/examples/tpch/__init__.py index 534775e..0cf1db7 100644 --- a/ccflow/examples/tpch/__init__.py +++ b/ccflow/examples/tpch/__init__.py @@ -1,3 +1,77 @@ -from .base import * -from .data_generators import * -from .query import * +"""TPC-H example for ccflow. + +This package is a *teaching* example showing how to compose a workflow from +``CallableModel``s wired together through the ``ModelRegistry``. The +canonical usage is:: + + from ccflow import ModelRegistry + from ccflow.examples.tpch import load_config + + load_config() # populate the root ModelRegistry from conf.yaml + registry = ModelRegistry.root() + result = registry["/query/Q1"]() # run TPC-H query 1 + print(result.df.to_native()) + +To run the same example at a different TPC-H scale factor, override the +single shared backend on load (every table / answer / query references it, +so the change flows through everywhere):: + + load_config(overrides=["tpch.backend.scale_factor=1.0"]) +""" + +from pathlib import Path +from typing import List, Optional + +from ccflow import RootModelRegistry, load_config as _load_config_base + +from .data_generators import TPCHAnswerProvider, TPCHDuckDBBackend, TPCHTable, TPCHTableProvider +from .query import TPCHQuery + +__all__ = ( + "TPCHTable", + "TPCHDuckDBBackend", + "TPCHTableProvider", + "TPCHAnswerProvider", + "TPCHQuery", + "load_config", +) + + +def load_config( + config_dir: str = "", + config_name: str = "", + overrides: Optional[List[str]] = None, + *, + overwrite: bool = True, + basepath: str = "", +) -> RootModelRegistry: + """Load the TPC-H example registry into the root ``ModelRegistry``. + + Pass hydra-style ``overrides`` to reconfigure entries on load — most + usefully ``["tpch.backend.scale_factor=1.0"]`` to run the example at a + different TPC-H scale factor. Every table / answer / query references the + single ``/tpch/backend`` entry, so this one override flows through to all + 22+8 providers. 
+ + Args: + config_dir: Optional extra hydra config directory to overlay on top + of the bundled ``config/conf.yaml``. Empty string (the default) + means "use only the bundled config". + config_name: Optional config name within ``config_dir`` to load. + overrides: Hydra override strings, e.g. + ``["tpch.backend.scale_factor=1.0"]``. + overwrite: When True (the default), entries already present in the + registry are replaced. This is what you want in notebooks where + you re-call ``load_config()`` after tweaking overrides; set to + False to require a fresh registry. + basepath: Base path for resolving a relative ``config_dir``. + + Returns: + The ``RootModelRegistry`` produced by ccflow's ``load_config`` for the + bundled ``config/conf.yaml`` plus any overlay and overrides. + """ + return _load_config_base( + root_config_dir=str(Path(__file__).resolve().parent / "config"), + root_config_name="conf", + config_dir=config_dir, + config_name=config_name, + overrides=overrides, + overwrite=overwrite, + basepath=basepath, + ) diff --git a/ccflow/examples/tpch/base.py b/ccflow/examples/tpch/base.py deleted file mode 100644 index e018061..0000000 --- a/ccflow/examples/tpch/base.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Literal - -from pydantic import conint - -from ccflow import ContextBase - -__all__ = ( - "TPCHTable", - "TPCHTableContext", - "TPCHQueryContext", -) - - -TPCHTable = Literal["customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier"] - - -class TPCHTableContext(ContextBase): - table: TPCHTable - - -class TPCHQueryContext(ContextBase): - query_id: conint(ge=1, le=22) diff --git a/ccflow/examples/tpch/config/conf.yaml b/ccflow/examples/tpch/config/conf.yaml index e69de29..b888697 100644 --- a/ccflow/examples/tpch/config/conf.yaml +++ b/ccflow/examples/tpch/config/conf.yaml @@ -0,0 +1,262 @@ +# TPC-H example registry. +# +# This file is loaded by ``ccflow.examples.tpch.load_config()`` into the root +# ``ModelRegistry`` and demonstrates several ccflow features: +# +# 1. 
A *flat* model graph defined entirely in YAML — Python code defines the +# classes (``TPCHDuckDBBackend``, ``TPCHTableProvider``, ``TPCHAnswerProvider``, +# ``TPCHQuery``); this file decides which instances exist and how they are +# wired together. +# 2. Cross-references between registry entries. Strings beginning with ``/`` +# are absolute paths into the registry; ccflow's pydantic validators +# resolve them to the actual configured Python instance at config-load +# time. Resolution is by reference (not copy), so every provider below +# points at the *same* ``/tpch/backend`` instance, and ``dbgen`` runs +# exactly once for the whole registry. Order within this file does not +# matter — references are resolved after the whole file is parsed. +# 3. Explicit dependencies on a generic ``CallableModel``. ``TPCHQuery`` has +# an ``inputs: Tuple[CallableModelGenericType[NullContext, NarwhalsFrameResult], ...]`` +# field; the registry resolves each ``/table/`` reference into the +# corresponding ``TPCHTableProvider`` instance, so each query's table +# dependencies are first-class fields on that query's model instance. +# 4. ``scale_factor`` lives on a single backend entry, so loading the same +# config with a hydra override +# (``load_config(overrides=["tpch.backend.scale_factor=1.0"])``) reconfigures +# every table, answer and query consistently. + +# --------------------------------------------------------------------------- +# Shared DuckDB backend. Plain ``ccflow.BaseModel`` — not callable itself, +# but registered so all providers share one connection and one ``dbgen`` call. +# --------------------------------------------------------------------------- +tpch: + backend: + _target_: ccflow.examples.tpch.TPCHDuckDBBackend + scale_factor: 0.1 + +# --------------------------------------------------------------------------- +# Per-table providers. One instance per TPC-H table; the output schema of +# each instance is fixed by its ``table`` field. 
+# --------------------------------------------------------------------------- +table: + customer: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: customer + lineitem: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: lineitem + nation: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: nation + orders: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: orders + part: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: part + partsupp: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: partsupp + region: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: region + supplier: + _target_: ccflow.examples.tpch.TPCHTableProvider + backend: /tpch/backend + table: supplier + +# --------------------------------------------------------------------------- +# Reference answers, one per query, served straight from DuckDB's +# ``tpch_answers()`` table at the configured scale factor. 
+# --------------------------------------------------------------------------- +answer: + Q1: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 1 + Q2: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 2 + Q3: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 3 + Q4: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 4 + Q5: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 5 + Q6: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 6 + Q7: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 7 + Q8: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 8 + Q9: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 9 + Q10: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 10 + Q11: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 11 + Q12: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 12 + Q13: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 13 + Q14: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 14 + Q15: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 15 + Q16: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 16 + Q17: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 17 + Q18: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 18 + Q19: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 19 + Q20: + _target_: 
ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 20 + Q21: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 21 + Q22: + _target_: ccflow.examples.tpch.TPCHAnswerProvider + backend: /tpch/backend + query_id: 22 + +# --------------------------------------------------------------------------- +# The 22 TPC-H queries. Each ``TPCHQuery`` is the same Python class with a +# different ``query_id`` and a different tuple of table-provider inputs. +# Wiring the inputs in YAML makes each query's table dependencies explicit +# and overridable per-query. +# --------------------------------------------------------------------------- +query: + Q1: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 1 + inputs: [/table/lineitem] + Q2: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 2 + inputs: [/table/region, /table/nation, /table/supplier, /table/part, /table/partsupp] + Q3: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 3 + inputs: [/table/customer, /table/lineitem, /table/orders] + Q4: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 4 + inputs: [/table/lineitem, /table/orders] + Q5: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 5 + inputs: [/table/region, /table/nation, /table/customer, /table/lineitem, /table/orders, /table/supplier] + Q6: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 6 + inputs: [/table/lineitem] + Q7: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 7 + inputs: [/table/nation, /table/customer, /table/lineitem, /table/orders, /table/supplier] + Q8: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 8 + inputs: [/table/part, /table/supplier, /table/lineitem, /table/orders, /table/customer, /table/nation, /table/region] + Q9: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 9 + inputs: [/table/part, /table/partsupp, /table/nation, /table/lineitem, /table/orders, /table/supplier] + Q10: + _target_: ccflow.examples.tpch.TPCHQuery + 
query_id: 10 + inputs: [/table/customer, /table/nation, /table/lineitem, /table/orders] + Q11: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 11 + inputs: [/table/nation, /table/partsupp, /table/supplier] + Q12: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 12 + inputs: [/table/lineitem, /table/orders] + Q13: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 13 + inputs: [/table/customer, /table/orders] + Q14: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 14 + inputs: [/table/lineitem, /table/part] + Q15: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 15 + inputs: [/table/lineitem, /table/supplier] + Q16: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 16 + inputs: [/table/part, /table/partsupp, /table/supplier] + Q17: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 17 + inputs: [/table/lineitem, /table/part] + Q18: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 18 + inputs: [/table/customer, /table/lineitem, /table/orders] + Q19: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 19 + inputs: [/table/lineitem, /table/part] + Q20: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 20 + inputs: [/table/part, /table/partsupp, /table/nation, /table/lineitem, /table/supplier] + Q21: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 21 + inputs: [/table/lineitem, /table/nation, /table/orders, /table/supplier] + Q22: + _target_: ccflow.examples.tpch.TPCHQuery + query_id: 22 + inputs: [/table/customer, /table/orders] diff --git a/ccflow/examples/tpch/data_generators.py b/ccflow/examples/tpch/data_generators.py index bc67084..d6a0f84 100644 --- a/ccflow/examples/tpch/data_generators.py +++ b/ccflow/examples/tpch/data_generators.py @@ -13,55 +13,69 @@ """ import io -from typing import Any +from typing import Any, Literal import duckdb import polars as pl import pyarrow as pa import pyarrow.csv as pc -from pydantic import model_validator +from pydantic import conint, model_validator -from ccflow import 
CallableModel, Flow +from ccflow import BaseModel, CallableModel, Flow, NullContext from ccflow.result.narwhals import NarwhalsDataFrameResult -from .base import TPCHQueryContext, TPCHTableContext +__all__ = ("TPCHTable", "TPCHDuckDBBackend", "TPCHTableProvider", "TPCHAnswerProvider") -__all__ = ("TPCHAnswerGenerator", "TPCHDataGenerator") +TPCHTable = Literal["customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier"] -class TPCHAnswerGenerator(CallableModel): - """Generates data for the TPC-H benchmark.""" - scale_factor: float - _conn: Any = None +def _convert_schema(schema: pa.Schema) -> pa.Schema: + """Return ``schema`` with decimal fields mapped to float64 and date32 fields to ns timestamps. - @model_validator(mode="after") - def _validate(self): - if self._conn is None: - self._conn = duckdb.connect(":memory:") - self._conn.execute("INSTALL tpch; LOAD tpch") - return self - - def get_query(self, context) -> str: - return f""" - SELECT answer FROM tpch_answers() - WHERE scale_factor={self.scale_factor} AND query_nr={context.query_id} - """ - - @Flow.call() - def __call__(self, context: TPCHQueryContext) -> NarwhalsDataFrameResult: - """Generates data for the TPC-H benchmark.""" - results = self._conn.query(self.get_query(context)) - row = results.fetchone() - if row: - answer = row[0] - tbl_answer = pc.read_csv(io.BytesIO(answer.encode("utf-8")), parse_options=pc.ParseOptions(delimiter="|")) - return NarwhalsDataFrameResult(df=tbl_answer) + Narwhals' polars/pandas/etc. backends prefer these dtypes over the + DuckDB-native decimal/date32 representations. + + Only the schema is rebuilt here; the caller applies it with ``.cast``. 
+ """ + new_fields = [] + for field in schema: + if pa.types.is_decimal(field.type): + new_fields.append(pa.field(field.name, pa.float64())) + elif field.type == pa.date32(): + new_fields.append(pa.field(field.name, pa.timestamp("ns"))) else: - raise ValueError(f"No TPCH answers found for the given scale factor ({self.scale_factor}) and query number ({context.query_id}).") + new_fields.append(field) + return pa.schema(new_fields) + + +class TPCHDuckDBBackend(BaseModel): + """Shared DuckDB connection that runs ``dbgen`` once for a given scale factor. + + This is a plain ``ccflow.BaseModel``, not a ``CallableModel``. The + distinction matters in ccflow: + + * ``CallableModel`` subclasses are the only models the framework invokes + as workflow steps (via ``@Flow.call``). They represent *something to + run*. + * ``BaseModel`` subclasses live in the ``ModelRegistry`` as plain + configured Python objects — useful for shared state, connections, + configuration that other models depend on. They are not themselves + callable as workflow steps. + + This backend is shared state: it owns one DuckDB connection and ensures + ``dbgen(sf=...)`` runs exactly once. By registering it under + ``/tpch/backend`` and having every ``TPCHTableProvider`` / + ``TPCHAnswerProvider`` reference that same path, the whole example uses a + single connection regardless of how many providers are instantiated. + + Note on ``_conn`` / ``_generated``: leading-underscore annotated fields + on a ``BaseModel`` become Pydantic ``PrivateAttr``s — they are not part + of the model's public schema, so they are never serialised with it. A + backend rebuilt from its fields gets its connection from the + ``model_validator`` below and re-runs ``dbgen`` lazily on first use. + NOTE(review): pydantic v2 ``model_copy()`` copies private attributes and + skips validators, so a shallow copy shares this connection — confirm. 
+ """ + - -class TPCHDataGenerator(CallableModel): scale_factor: float _conn: Any = None _generated: bool = False @@ -73,30 +87,51 @@ def _validate(self): self._conn.execute("INSTALL tpch; LOAD tpch") return self - def _generate_if_needed(self): - if self._generated: - return - self._conn.execute(f"CALL dbgen(sf={self.scale_factor})") - self._generated = True - - def convert_schema(self, schema: pa.Schema) -> pa.Schema: - new_schema = [] - for field in schema: - if pa.types.is_decimal(field.type): - new_schema.append(pa.field(field.name, pa.float64())) - elif field.type == pa.date32(): - new_schema.append(pa.field(field.name, pa.timestamp("ns"))) - else: - new_schema.append(field) - return pa.schema(new_schema) + def _ensure_generated(self) -> None: + """Run ``dbgen`` for ``scale_factor`` on first use; later calls are no-ops.""" + if not self._generated: + self._conn.execute(f"CALL dbgen(sf={self.scale_factor})") + self._generated = True + + def get_table(self, table: TPCHTable) -> pl.DataFrame: + """Return ``table`` as a polars DataFrame, generating the data first if needed.""" + self._ensure_generated() + tbl_arrow = self._conn.query(f"SELECT * FROM {table}").to_arrow_table() + tbl_arrow = tbl_arrow.cast(_convert_schema(tbl_arrow.schema)) + # Use the polars backend by default; it's the fastest narwhals backend + # for the downstream query bodies. + return pl.from_arrow(tbl_arrow) + + def get_answer(self, query_id: int) -> pa.Table: + """Return the ``tpch_answers()`` entry for ``query_id`` at ``scale_factor`` as an Arrow table. + + The stored answer is pipe-delimited text, parsed here with ``pyarrow.csv``. + + Raises: + ValueError: If no answer row exists for this scale factor / query. 
+ """ + row = self._conn.query(f"SELECT answer FROM tpch_answers() WHERE scale_factor={self.scale_factor} AND query_nr={query_id}").fetchone() + if not row: + raise ValueError(f"No TPC-H answer found for scale_factor={self.scale_factor}, query_nr={query_id}") + return pc.read_csv(io.BytesIO(row[0].encode("utf-8")), parse_options=pc.ParseOptions(delimiter="|")) + + +class TPCHTableProvider(CallableModel): + """Provides a single TPC-H table as a Narwhals frame. + + One instance per table; the output schema is fixed by the ``table`` field. + The call takes a ``NullContext`` because the provider has no runtime + parameters — everything it needs is already on the model itself. 
The + ``= NullContext()`` default lets callers (such as ``TPCHQuery``) invoke + the provider with no arguments; ``@Flow.call`` reads the default from the + signature in that case. + """ + + backend: TPCHDuckDBBackend + table: TPCHTable + + @Flow.call + def __call__(self, context: NullContext = NullContext()) -> NarwhalsDataFrameResult: + """Fetch ``table`` from the shared backend and wrap it in a result.""" + return NarwhalsDataFrameResult(df=self.backend.get_table(self.table)) + + +class TPCHAnswerProvider(CallableModel): + """Provides the canonical reference answer for a single TPC-H query. + + One instance per query; ``query_id`` is constrained to 1-22 and the answer + is read through the shared ``backend``. ``__call__`` takes a defaulted + ``NullContext``, so the provider can be invoked with no arguments. + """ + + backend: TPCHDuckDBBackend + query_id: conint(ge=1, le=22) @Flow.call - def __call__(self, context: TPCHTableContext) -> NarwhalsDataFrameResult: - """Generates data for the TPC-H benchmark.""" - self._generate_if_needed() - tbl = self._conn.query(f"SELECT * FROM {context.table}") - tbl_arrow = tbl.to_arrow_table() - new_schema = self.convert_schema(tbl_arrow.schema) - tbl_arrow = tbl_arrow.cast(new_schema) - # Convert to Polars DataFrame to use the polars backend by default for downstream calculations (it's faster) - return NarwhalsDataFrameResult(df=pl.from_arrow(tbl_arrow)) + def __call__(self, context: NullContext = NullContext()) -> NarwhalsDataFrameResult: + """Fetch the reference answer for ``query_id`` from the shared backend.""" + return NarwhalsDataFrameResult(df=self.backend.get_answer(self.query_id)) diff --git a/ccflow/examples/tpch/query.py b/ccflow/examples/tpch/query.py index 2982284..184d3cc 100644 --- a/ccflow/examples/tpch/query.py +++ b/ccflow/examples/tpch/query.py @@ -1,51 +1,52 @@ +"""Generic TPC-H query runner. + +A single ``TPCHQuery`` class can express any of the 22 TPC-H queries; each +query gets one configured instance in the registry (``query/Q1`` ... +``query/Q22``). The query's table dependencies are explicit Pydantic fields +on each instance (via ``inputs``). 
+""" + from importlib import import_module -from typing import Dict, Tuple +from typing import Tuple -from pydantic import Field +from pydantic import conint -from ccflow import CallableModel, CallableModelGenericType, Flow +from ccflow import CallableModel, CallableModelGenericType, Flow, NullContext from ccflow.result.narwhals import NarwhalsFrameResult -from .base import TPCHQueryContext, TPCHTable, TPCHTableContext - -__all__ = ("TPCHQueryRunner",) - - -_QUERY_TABLE_MAP: Dict[int, Tuple[TPCHTable, ...]] = { - 1: ("lineitem",), - 2: ("region", "nation", "supplier", "part", "partsupp"), - 3: ("customer", "lineitem", "orders"), - 4: ("lineitem", "orders"), - 5: ("region", "nation", "customer", "lineitem", "orders", "supplier"), - 6: ("lineitem",), - 7: ("nation", "customer", "lineitem", "orders", "supplier"), - 8: ("part", "supplier", "lineitem", "orders", "customer", "nation", "region"), - 9: ("part", "partsupp", "nation", "lineitem", "orders", "supplier"), - 10: ("customer", "nation", "lineitem", "orders"), - 11: ("nation", "partsupp", "supplier"), - 12: ("lineitem", "orders"), - 13: ("customer", "orders"), - 14: ("lineitem", "part"), - 15: ("lineitem", "supplier"), - 16: ("part", "partsupp", "supplier"), - 17: ("lineitem", "part"), - 18: ("customer", "lineitem", "orders"), - 19: ("lineitem", "part"), - 20: ("part", "partsupp", "nation", "lineitem", "supplier"), - 21: ("lineitem", "nation", "orders", "supplier"), - 22: ("customer", "orders"), -} - - -class TPCHQueryRunner(CallableModel): - """Generically runs TPC-H queries from a pre-packaged repository of queries (courtesy of narwhals).""" - - table_provider: CallableModelGenericType[TPCHTableContext, NarwhalsFrameResult] - query_table_map: Dict[int, Tuple[TPCHTable, ...]] = Field(_QUERY_TABLE_MAP, validate_default=True) +__all__ = ("TPCHQuery",) + + +class TPCHQuery(CallableModel): + """Runs one TPC-H query (``q{query_id}``) against a tuple of table providers. 
+ + The query body itself comes from ``ccflow.examples.tpch.queries.q{N}.query`` + (vendored from narwhals); each input is called to produce a frame, and the + frames are passed positionally to that function in the order given. + + A few ccflow features worth noting on this class: + + * The ``inputs`` field is typed ``CallableModelGenericType[NullContext, + NarwhalsFrameResult]``. Using a ``CallableModelGenericType[C, R]`` as a + Pydantic field type causes ccflow to validate, when the registry is + loaded, that each resolved provider's ``__call__`` actually takes a + ``NullContext`` (or subclass) and returns a ``NarwhalsFrameResult`` (or + subclass). The configured ``TPCHTableProvider`` instances satisfy this + because their return type, ``NarwhalsDataFrameResult``, is a subclass + of ``NarwhalsFrameResult``. + * Each provider is invoked with no arguments (``provider()``). The + ``@Flow.call`` decorator on the provider's ``__call__`` reads the + ``context: NullContext = NullContext()`` default from the signature + when no context is passed in, so a "no-arg" call is well-defined. + """ + + query_id: conint(ge=1, le=22) + inputs: Tuple[CallableModelGenericType[NullContext, NarwhalsFrameResult], ...] @Flow.call - def __call__(self, context: TPCHQueryContext) -> NarwhalsFrameResult: - query_module = import_module(f"ccflow.examples.tpch.queries.q{context.query_id}") - inputs = (self.table_provider(TPCHTableContext(table=table)).df for table in self.query_table_map[context.query_id]) - result = query_module.query(*inputs) - return NarwhalsFrameResult(df=result) + def __call__(self, context: NullContext = NullContext()) -> NarwhalsFrameResult: + """Run ``q{query_id}`` on the frames produced by ``inputs`` and wrap the result.""" + query_module = import_module(f"ccflow.examples.tpch.queries.q{self.query_id}") + # Call every provider (in ``inputs`` order) and collect the frames; they + # are unpacked as positional arguments, so the number and order of + # ``inputs`` must match the query function's parameters. 
+ frames = tuple(provider().df for provider in self.inputs) + return NarwhalsFrameResult(df=query_module.query(*frames)) diff --git a/ccflow/tests/examples/test_tpch.py b/ccflow/tests/examples/test_tpch.py index 52aa0a0..dfc3417 100644 --- a/ccflow/tests/examples/test_tpch.py +++ b/ccflow/tests/examples/test_tpch.py @@ -3,46 +3,40 @@ import pytest from polars.testing import assert_frame_equal -from ccflow.examples.tpch import TPCHAnswerGenerator, TPCHDataGenerator, TPCHQueryContext, TPCHQueryRunner, TPCHTable, TPCHTableContext +from ccflow import ModelRegistry +from ccflow.examples.tpch import TPCHTable, load_config @pytest.fixture(scope="module") -def scale_factor(): - return 0.1 +def registry(): + # Load the TPC-H example registry from its YAML. We override the scale + # factor on the single shared backend; that one override flows through + # to all table/answer entries, and to the queries through their table + # inputs, because the providers all reference ``/tpch/backend``. + # NOTE(review): this fills the process-wide root registry and nothing here + # clears it afterwards — confirm no other test module needs a clean root. + load_config(overrides=["tpch.backend.scale_factor=0.1"], overwrite=True) + return ModelRegistry.root() -@pytest.fixture(scope="module") -def tpch_answer_generator(scale_factor): - return TPCHAnswerGenerator(scale_factor=scale_factor) - - -@pytest.fixture(scope="module") -def tpch_data_generator(scale_factor): - return TPCHDataGenerator(scale_factor=scale_factor) - - -@pytest.mark.parametrize("query_id", range(1, 23)) -def test_tpch_answer_generation(tpch_answer_generator, query_id): - context = TPCHQueryContext(query_id=query_id) - out = tpch_answer_generator(context) +@pytest.mark.parametrize("table", get_args(TPCHTable)) +def test_tpch_table_provider(registry, table): + """Every configured table provider yields a non-empty frame.""" + provider = registry[f"/table/{table}"] + out = provider() assert out is not None assert len(out.df) > 0 -@pytest.mark.parametrize("table", get_args(TPCHTable)) -def test_tpch_data_generation(tpch_data_generator, table): - context = TPCHTableContext(table=table) - out = tpch_data_generator(context) +@pytest.mark.parametrize("query_id", range(1, 23)) +def 
test_tpch_answer_provider(registry, query_id): + """Every configured answer provider yields a non-empty frame.""" + provider = registry[f"/answer/Q{query_id}"] + out = provider() assert out is not None assert len(out.df) > 0 @pytest.mark.parametrize("query_id", range(1, 23)) -def test_tpch_queries(tpch_answer_generator, tpch_data_generator, query_id): - runner = TPCHQueryRunner(table_provider=tpch_data_generator) - context = TPCHQueryContext(query_id=query_id) - answer = tpch_answer_generator(context) - out = runner(context) - assert out is not None - assert answer is not None - assert_frame_equal(out.df.to_polars(), answer.df.to_polars(), check_dtypes=False) +def test_tpch_query(registry, query_id): + """Each configured query reproduces its reference answer (dtypes are not compared).""" + query = registry[f"/query/Q{query_id}"] + answer = registry[f"/answer/Q{query_id}"] + out = query() + expected = answer() + # NOTE(review): newer polars releases spell this flag ``check_dtype`` — + # confirm which spelling the pinned polars version expects. + assert_frame_equal(out.df.to_polars(), expected.df.to_polars(), check_dtypes=False)