Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions datatune/core/filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
def _is_ibis_table(obj):
try:
import ibis
return isinstance(obj, ibis.Table)
from .ibis.lazy_pipeline import LazyTable
return isinstance(obj, (ibis.Table, LazyTable))
except ImportError:
return False

Expand All @@ -23,10 +24,18 @@ def apply(llm, data):
)(llm, data)
elif _is_ibis_table(data):
from .ibis.filter_ibis import _filter_ibis
return _filter_ibis(
from .ibis.lazy_pipeline import LazyTable, FilterNode

filter_obj = _filter_ibis(
prompt=prompt,
input_fields=input_fields,
)(llm, data)
input_fields=input_fields
)
filter_obj.llm = llm

if not isinstance(data, LazyTable):
data = LazyTable(data)

return LazyTable(FilterNode(filter_obj, data))
raise TypeError(f"Unsupported data type: {type(data)}")

return apply
Expand Down
14 changes: 7 additions & 7 deletions datatune/core/ibis/filter_ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def __call__(self, llm: Callable, table: Table) -> Table:
self.llm = llm
if self.input_fields:
missing = [f for f in self.input_fields if f not in table.columns]
if missing:
error_msg = (
f"[datatune] Schema mismatch: The following input_fields were not found: {missing}. "
f"Available columns: {list(table.columns)}"
)
logger.error(error_msg)
if missing:
error_msg = (
f"[datatune] Schema mismatch: The following input_fields were not found: {missing}. "
f"Available columns: {list(table.columns)}"
)
logger.error(error_msg)

raise ValueError(error_msg)
raise ValueError(error_msg)

table = add_serialized_col(
table,
Expand Down
43 changes: 43 additions & 0 deletions datatune/core/ibis/lazy_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
class LazyTable:
    """Deferred handle around either a pending plan node or a concrete table.

    ``plan`` is stored exactly as given. When it is a ``PlanNode`` the work is
    delegated to that node's ``to_ibis``; any other object is handed back
    unchanged by ``to_ibis`` (presumably an ibis table expression — confirm
    against callers).
    """

    def __init__(self, plan):
        self.plan = plan

    def __repr__(self):
        # A readable repr lets nested plans print their structure instead of
        # ``<LazyTable object at 0x...>`` when a node shows its source.
        return f"{type(self).__name__}({self.plan!r})"

    def to_ibis(self):
        """Resolve the wrapped plan.

        Returns:
            ``self.plan.to_ibis()`` when ``plan`` is a ``PlanNode``,
            otherwise ``self.plan`` itself.
        """
        if isinstance(self.plan, PlanNode):
            return self.plan.to_ibis()
        return self.plan

    def execute(self):
        """Resolve the plan and return the result of calling ``execute()`` on it."""
        ibis_table = self.to_ibis()
        return ibis_table.execute()

    def show_plan(self):
        """Return ``repr`` of the wrapped plan as a string."""
        return repr(self.plan)

class PlanNode:
    """Base class for one step of a lazy plan; subclasses override ``to_ibis``."""

    def to_ibis(self):
        """Produce the table for this step.

        Raises:
            NotImplementedError: always on the base class; the message names
                the concrete type that failed to override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_ibis()"
        )

class MapNode(PlanNode):
    """Plan step that applies a map operation to the table from ``source``.

    Args:
        map_obj: Callable invoked as ``map_obj(map_obj.llm, table)``. It must
            already carry an ``llm`` attribute by the time ``to_ibis`` runs.
        source: Upstream object exposing ``to_ibis()``.
    """

    def __init__(self, map_obj, source):
        self.map_obj = map_obj
        self.source = source

    def __repr__(self):
        # Shows the operation and its upstream so a printed plan is readable.
        return (
            f"{type(self).__name__}"
            f"(map_obj={self.map_obj!r}, source={self.source!r})"
        )

    def to_ibis(self):
        """Resolve the upstream table, then run the map operation on it."""
        table = self.source.to_ibis()
        return self.map_obj(self.map_obj.llm, table)

class FilterNode(PlanNode):
    """Plan step that applies a filter operation to the table from ``source``.

    Args:
        filter_obj: Callable invoked as ``filter_obj(filter_obj.llm, table)``.
            It must already carry an ``llm`` attribute by the time
            ``to_ibis`` runs.
        source: Upstream object exposing ``to_ibis()``.
    """

    def __init__(self, filter_obj, source):
        self.filter_obj = filter_obj
        self.source = source

    def __repr__(self):
        # Shows the operation and its upstream so a printed plan is readable.
        return (
            f"{type(self).__name__}"
            f"(filter_obj={self.filter_obj!r}, source={self.source!r})"
        )

    def to_ibis(self):
        """Resolve the upstream table, then run the filter operation on it."""
        table = self.source.to_ibis()
        return self.filter_obj(self.filter_obj.llm, table)




14 changes: 7 additions & 7 deletions datatune/core/ibis/map_ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ def __call__(self, llm: Callable, table: Table) -> Table:
self.llm = llm
if self.input_fields:
missing = [f for f in self.input_fields if f not in table.columns]
if missing:
error_msg = (
f"[datatune] Schema mismatch: The following input_fields were not found: {missing}. "
f"Available columns: {list(table.columns)}"
)
logger.error(error_msg)
if missing:
error_msg = (
f"[datatune] Schema mismatch: The following input_fields were not found: {missing}. "
f"Available columns: {list(table.columns)}"
)
logger.error(error_msg)

raise ValueError(error_msg)
raise ValueError(error_msg)


table = add_serialized_col(
Expand Down
16 changes: 13 additions & 3 deletions datatune/core/map.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

def _is_ibis_table(obj):
try:
import ibis
return isinstance(obj, ibis.Table)
from .ibis.lazy_pipeline import LazyTable
return isinstance(obj, (ibis.Table,LazyTable))
except ImportError:
return False

Expand All @@ -24,11 +26,19 @@ def apply(llm, data):
)(llm, data)
elif _is_ibis_table(data):
from .ibis.map_ibis import _map_ibis
return _map_ibis(
from .ibis.lazy_pipeline import LazyTable, MapNode

map_obj = _map_ibis(
prompt=prompt,
output_fields=output_fields,
input_fields=input_fields,
)(llm, data)
)
map_obj.llm = llm

if not isinstance(data, LazyTable):
data = LazyTable(data)

return LazyTable(MapNode(map_obj, data))

raise TypeError(f"Unsupported data type: {type(data)}")

Expand Down