Skip to content

Commit cfd0090

Browse files
authored
fix: torchdataset for lance does not support S3 (#4045)
Signed-off-by: jukejian <jukejian@bytedance.com>
1 parent 084d73a commit cfd0090

2 files changed

Lines changed: 15 additions & 10 deletions

File tree

python/python/lance/torch/data.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import math
1212
import warnings
1313
from pathlib import Path
14-
from typing import Callable, Dict, Iterable, List, Literal, Optional, Union
14+
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union
1515

1616
import pyarrow as pa
1717

@@ -182,6 +182,7 @@ def __init__(
182182
dataset: Union[torch.utils.data.Dataset, str, Path],
183183
batch_size: int,
184184
*args,
185+
dataset_options: Optional[Dict[str, Any]] = None,
185186
columns: Optional[Union[List[str], Dict[str, str]]] = None,
186187
filter: Optional[str] = None,
187188
samples: Optional[int] = 0,
@@ -237,7 +238,8 @@ def __init__(
237238
"""
238239
super().__init__()
239240
if isinstance(dataset, (str, Path)):
240-
dataset = lance.dataset(dataset)
241+
dataset_options = dataset_options or {}
242+
dataset = lance.dataset(dataset, **dataset_options)
241243
self.dataset = dataset
242244
self.columns = columns
243245
self.batch_size = batch_size
@@ -378,16 +380,18 @@ def _blob_columns(self) -> List[str]:
378380

379381

380382
class SafeLanceDataset(torch.utils.data.Dataset):
381-
def __init__(self, uri):
383+
def __init__(self, uri, *, dataset_options=None, **kwargs):
384+
super().__init__(**kwargs)
382385
self.uri = uri
386+
self.dataset_options = dataset_options or {}
383387
self._len = self._safe_preload()
384-
self._ds = None # Deferred initialization
388+
self._ds = None
385389

386390
def _safe_preload(self):
387391
"""Main-process safe metadata loading"""
388-
ds = lance.dataset(self.uri)
392+
ds = lance.dataset(self.uri, **self.dataset_options)
389393
length = ds.count_rows()
390-
del ds # Critical: release before spawning
394+
del ds
391395
return length
392396

393397
def __len__(self):

python/python/tests/torch_tests/test_data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,11 @@ def test_blob_api(tmp_path: Path):
292292
)
293293
tbl = pa.Table.from_arrays([ints, vals], schema=schema)
294294

295-
ds = lance.write_dataset(tbl, tmp_path / "data.lance")
295+
uri = tmp_path / "data.lance"
296+
dataset = lance.write_dataset(tbl, uri)
297+
296298
torch_ds = LanceDataset(
297-
ds,
298-
batch_size=4,
299+
uri, batch_size=4, dataset_options={"version": dataset.version}
299300
)
300301
with pytest.raises(NotImplementedError):
301302
next(iter(torch_ds))
@@ -314,7 +315,7 @@ def to_tensor_fn(batch, *args, **kwargs):
314315
return {"int": ints, "val": vals}
315316

316317
torch_ds = LanceDataset(
317-
ds,
318+
dataset,
318319
batch_size=4,
319320
to_tensor_fn=to_tensor_fn,
320321
)

0 commit comments

Comments
 (0)