Skip to content

Commit cd723c5

Browse files
committed
include some tests
1 parent f0dbbb3 commit cd723c5

11 files changed

Lines changed: 322 additions & 490 deletions

File tree

pysus/api/ducklake/catalog.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ class CatalogTable(Base):
3535
__table_args__ = {"schema": "pysus"}
3636

3737

38-
class Dataset(CatalogTable):
38+
class CatalogDataset(CatalogTable):
3939
__tablename__ = "datasets"
4040

4141
id = Column(Integer, primary_key=True)
@@ -138,7 +138,7 @@ class DatasetGroup(CatalogTable):
138138
)
139139

140140

141-
class File(CatalogTable):
141+
class CatalogFile(CatalogTable):
142142
__tablename__ = "files"
143143

144144
id = Column(Integer, primary_key=True)

pysus/api/ducklake/client.py

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111
from sqlalchemy import create_engine
1212
from sqlalchemy.orm import joinedload, sessionmaker
1313

14-
from .catalog import Dataset, DatasetGroup
15-
from .models import CatalogDataset, CatalogFile
14+
from .catalog import CatalogDataset, DatasetGroup
15+
from .models import Dataset, File
1616

1717

1818
class DuckLakeCredentials(BaseModel):
@@ -47,25 +47,25 @@ def _catalog_url(self) -> str:
4747
def _is_authenticated(self) -> bool:
4848
return self.credentials is not None
4949

50-
async def datasets(self, **kwargs) -> List[CatalogDataset]:
50+
async def datasets(self, **kwargs) -> List[Dataset]:
5151
if not self._Session:
5252
await self.connect()
5353

5454
def _fetch():
5555
with self._Session() as session:
5656
return (
57-
session.query(Dataset)
57+
session.query(CatalogDataset)
5858
.options(
59-
joinedload(Dataset.dataset_metadata),
60-
joinedload(Dataset.groups).joinedload(
59+
joinedload(CatalogDataset.dataset_metadata),
60+
joinedload(CatalogDataset.groups).joinedload(
6161
DatasetGroup.files
6262
),
6363
)
6464
.all()
6565
)
6666

6767
records = await anyio.to_thread.run_sync(_fetch)
68-
return [CatalogDataset(record=rec, client=self) for rec in records]
68+
return [Dataset(record=rec, client=self) for rec in records]
6969

7070
async def login(
7171
self,
@@ -102,12 +102,12 @@ def _setup_engine(self):
102102
"s3_use_ssl": "true",
103103
}
104104
if self._is_authenticated:
105-
s3_cfg[
106-
"s3_access_key_id"
107-
] = self.credentials.access_key.get_secret_value()
108-
s3_cfg[
109-
"s3_secret_access_key"
110-
] = self.credentials.secret_key.get_secret_value()
105+
s3_cfg["s3_access_key_id"] = (
106+
self.credentials.access_key.get_secret_value()
107+
)
108+
s3_cfg["s3_secret_access_key"] = (
109+
self.credentials.secret_key.get_secret_value()
110+
)
111111

112112
for key, value in s3_cfg.items():
113113
conn.exec_driver_sql(f"SET {key}='{value}';")
@@ -134,7 +134,7 @@ async def close(self):
134134

135135
async def _download_file(
136136
self,
137-
file: "CatalogFile",
137+
file: "File",
138138
output: Path,
139139
callback: Optional[Callable[[int], None]] = None,
140140
) -> Path:
@@ -185,9 +185,7 @@ async def _load_catalog(self):
185185

186186
async def _upload_catalog(self):
187187
if not self._is_authenticated:
188-
raise PermissionError(
189-
"Admin credentials required to upload catalog."
190-
)
188+
raise PermissionError("Admin credentials required to upload catalog.")
191189

192190
def _upload():
193191
self._s3_client.upload_file(

pysus/api/ducklake/models.py

Lines changed: 15 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,12 @@
1212
BaseRemoteGroup,
1313
)
1414

15-
from .catalog import Dataset, DatasetGroup, File
15+
from .catalog import CatalogDataset, DatasetGroup, CatalogFile
1616

1717

18-
class CatalogFile(BaseRemoteFile):
19-
record: File = Field(exclude=True)
20-
parent: Union["CatalogDataset", "CatalogGroup"] = Field(exclude=True)
18+
class File(BaseRemoteFile):
19+
record: CatalogFile = Field(exclude=True)
20+
parent: Union["Dataset", "Group"] = Field(exclude=True)
2121

2222
type: str = "remote"
2323

@@ -52,9 +52,7 @@ def sha256(self) -> Optional[str]:
5252
async def _download(
5353
self, output: Path, callback: Optional[Callable[[int], None]] = None
5454
) -> Path:
55-
return await self.client._download_file(
56-
self, output, callback=callback
57-
)
55+
return await self.client._download_file(self, output, callback=callback)
5856

5957
async def verify(self, path: Path) -> bool:
6058
if not self.sha256:
@@ -71,9 +69,9 @@ def _calculate():
7169
return actual_hash == self.sha256
7270

7371

74-
class CatalogGroup(BaseRemoteGroup):
72+
class Group(BaseRemoteGroup):
7573
record: DatasetGroup = Field(exclude=True)
76-
dataset: "CatalogDataset" = Field(exclude=True)
74+
dataset: "Dataset" = Field(exclude=True)
7775

7876
@property
7977
def name(self) -> str:
@@ -90,17 +88,15 @@ def long_name(self) -> str:
9088
@property
9189
def description(self) -> str:
9290
return (
93-
self.record.group_metadata.description
94-
if self.record.group_metadata
95-
else ""
91+
self.record.group_metadata.description if self.record.group_metadata else ""
9692
)
9793

98-
async def files(self, **kwargs) -> List[CatalogFile]:
99-
return [CatalogFile(record=f, parent=self) for f in self.record.files]
94+
async def files(self, **kwargs) -> List[File]:
95+
return [File(record=f, parent=self) for f in self.record.files]
10096

10197

102-
class CatalogDataset(BaseRemoteDataset):
103-
record: Dataset = Field(exclude=True)
98+
class Dataset(BaseRemoteDataset):
99+
record: CatalogDataset = Field(exclude=True)
104100
client: BaseRemoteClient = Field(exclude=True)
105101

106102
@property
@@ -123,22 +119,13 @@ def description(self) -> str:
123119
else ""
124120
)
125121

126-
async def content(
127-
self, **kwargs
128-
) -> List[Union[CatalogGroup, CatalogFile]]:
122+
async def content(self, **kwargs) -> List[Union[Group, File]]:
129123
items = []
130124

131125
if self.record.groups:
132-
items.extend(
133-
[
134-
CatalogGroup(record=g, dataset=self)
135-
for g in self.record.groups
136-
]
137-
)
126+
items.extend([Group(record=g, dataset=self) for g in self.record.groups])
138127

139128
if self.record.files:
140-
items.extend(
141-
[CatalogFile(record=f, parent=self) for f in self.record.files]
142-
)
129+
items.extend([File(record=f, parent=self) for f in self.record.files])
143130

144131
return items

pysus/api/ducklake/storage.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

pysus/api/extensions.py

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -124,9 +124,13 @@ def sniff():
124124
async def load(self) -> pd.DataFrame:
125125
encoding = await self._get_encoding()
126126
separator = await self._get_sep()
127-
return await anyio.to_thread.run_sync(
128-
pd.read_csv, self.path, sep=separator, encoding=encoding
129-
)
127+
128+
def _read_sync():
129+
return pd.read_csv(
130+
self.path, sep=separator, encoding=encoding, low_memory=False
131+
)
132+
133+
return await anyio.to_thread.run_sync(_read_sync)
130134

131135
async def stream(
132136
self,
@@ -135,7 +139,7 @@ async def stream(
135139
encoding = await self._get_encoding()
136140
separator = await self._get_sep()
137141

138-
def _get_reader():
142+
def _get_reader_sync():
139143
return pd.read_csv(
140144
self.path,
141145
sep=separator,
@@ -145,7 +149,7 @@ def _get_reader():
145149
low_memory=False,
146150
)
147151

148-
reader = await anyio.to_thread.run_sync(_get_reader)
152+
reader = await anyio.to_thread.run_sync(_get_reader_sync)
149153
for chunk in reader:
150154
yield chunk
151155
await anyio.sleep(0)
@@ -514,6 +518,12 @@ def _read():
514518

515519
return await anyio.to_thread.run_sync(_read)
516520

521+
async def list_members(self) -> List[str]:
522+
return [self.path.stem]
523+
524+
async def open_member(self, member_name: str) -> bytes:
525+
return await self.load()
526+
517527
async def extract(
518528
self, target_dir: Optional[Path] = CACHEPATH
519529
) -> List[BaseLocalFile]:
@@ -549,6 +559,14 @@ def _list():
549559

550560
return await anyio.to_thread.run_sync(_list)
551561

562+
async def open_member(self, member_name: str) -> bytes:
563+
def _read():
564+
with tarfile.open(self.path) as t:
565+
f = t.extractfile(member_name)
566+
return f.read() if f else b""
567+
568+
return await anyio.to_thread.run_sync(_read)
569+
552570
async def extract(
553571
self, target_dir: Optional[Path] = CACHEPATH
554572
) -> List[BaseLocalFile]:

0 commit comments

Comments
 (0)