|
| 1 | +import enum |
| 2 | +from collections.abc import Callable |
| 3 | +from datetime import datetime |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +from pysus import CACHEPATH |
| 7 | +from sqlalchemy import Column, Integer, DateTime, Enum, String, create_engine |
| 8 | +from sqlalchemy.orm import declarative_base, sessionmaker |
| 9 | + |
| 10 | +from .dadosgov import DadosGovClient |
| 11 | +from .ducklake import DuckLakeClient |
| 12 | +from .ftp import FTPClient |
| 13 | +from .models import BaseLocalFile, BaseRemoteFile |
| 14 | + |
# Declarative base class that the ORM models in this module inherit from;
# its metadata is what PySUS.__init__ passes to create_all().
Base = declarative_base()
| 16 | + |
| 17 | + |
@enum.unique
class DownloadStatus(enum.Enum):
    """Lifecycle states recorded for a locally tracked file."""

    # Default status of a newly created LocalFileState row.
    PENDING = "pending"
    # Set by PySUS.download() before the transfer starts.
    DOWNLOADING = "downloading"
    # The status PySUS.get_local_file() filters on.
    COMPLETED = "completed"
    # Set by PySUS.download() when the transfer raises.
    FAILED = "failed"
    # Not assigned anywhere in this module.
    MISSING = "missing"
| 24 | + |
| 25 | + |
class LocalFileState(Base):
    """ORM row describing one local file and the remote file it came from."""

    __tablename__ = "local_file_state"

    # Local filesystem path, stored as a string; identifies the row.
    path = Column(String, primary_key=True)

    # Where the file came from: path on the remote side and the lower-cased
    # name of the client that serves it.
    remote_path = Column(String, nullable=False)
    client_name = Column(String, nullable=False)

    # Optional metadata copied from the remote file description.
    year = Column(Integer, nullable=True)
    month = Column(Integer, nullable=True)
    state = Column(String, nullable=True)

    # Download lifecycle; rows start as PENDING unless a status is given.
    status = Column(
        Enum(DownloadStatus),
        default=DownloadStatus.PENDING,
        nullable=True,
    )

    # Declared for a content hash; nothing in this module writes it.
    sha256 = Column(String, nullable=True)

    # Naive UTC timestamp, refreshed by PySUS._update_state().
    last_synced = Column(DateTime, default=datetime.utcnow, nullable=True)
| 37 | + |
| 38 | + |
class PySUS:
    """Coordinate the remote clients and track downloaded files locally.

    The instance owns a DuckDB-backed SQLAlchemy engine in which
    ``LocalFileState`` rows are stored. The DuckLake client is created when
    the instance is entered as an async context manager; the FTP and
    DadosGov clients are created lazily on first use. All clients that were
    created are closed, and the engine disposed, on exit.
    """

    def __init__(self, db_path: str | Path = CACHEPATH / "config.db"):
        """Create the state database (and its parent directory) if needed.

        Args:
            db_path: Location of the DuckDB file holding the download state.
        """
        db_path = Path(db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)

        self.engine = create_engine(f"duckdb:///{db_path}")
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

        self._ducklake: DuckLakeClient | None = None
        self._ftp: FTPClient | None = None
        self._dadosgov: DadosGovClient | None = None

    async def __aenter__(self) -> "PySUS":
        """Create the DuckLake client, load its catalog and attach it."""
        self._ducklake = DuckLakeClient(engine=self.engine)
        await self._ducklake._load_catalog()
        self._attach_client_catalog("ducklake", self._ducklake.catalog_path)
        return self

    async def get_dadosgov(self, access_token: str) -> DadosGovClient:
        """Return the DadosGov client, connecting it on first use.

        The token is only used for the first call; later calls return the
        already-connected client unchanged.
        """
        if self._dadosgov is None:
            self._dadosgov = DadosGovClient()
            await self._dadosgov.connect(token=access_token)
        return self._dadosgov

    async def get_ftp(self) -> FTPClient:
        """Return the FTP client, connecting it on first use."""
        if self._ftp is None:
            self._ftp = FTPClient()
            await self._ftp.connect()
        return self._ftp

    async def get_local_file(
        self,
        file: BaseRemoteFile,
    ) -> BaseLocalFile | None:
        """Return the completed local copy of ``file``, if one is recorded.

        Looks up ``LocalFileState`` rows with status ``COMPLETED`` for the
        file's remote path and client. When several rows match, a path
        ending in ``.parquet`` is preferred; otherwise the first row is used.

        Returns:
            The object built by ``ExtensionFactory.instantiate`` for the
            chosen path, or ``None`` when no completed row exists.
        """
        from pysus.api.extensions import ExtensionFactory

        client_name = file.client.name.lower()

        with self.Session() as session:
            records = (
                session.query(LocalFileState)
                .filter_by(
                    remote_path=file.path,
                    client_name=client_name,
                    status=DownloadStatus.COMPLETED,
                )
                .all()
            )
            # Copy the plain strings out while the session is still open so
            # nothing below depends on detached ORM instances.
            local_paths = [record.path for record in records]

        if not local_paths:
            return None

        chosen_path = next(
            (p for p in local_paths if p.endswith(".parquet")),
            local_paths[0],
        )
        return await ExtensionFactory.instantiate(chosen_path)

    def _attach_client_catalog(self, name: str, path: str):
        """Attach the database at ``path`` read-only under alias ``name``.

        Does nothing when a database with the same absolute path is already
        listed by ``duckdb_databases()``.
        """
        abs_path = str(Path(path).absolute())
        with self.engine.connect() as conn:
            q = "SELECT database_name FROM duckdb_databases() WHERE path = ?"
            existing = conn.exec_driver_sql(q, (abs_path,)).fetchone()

            if not existing:
                # The path is embedded as a SQL string literal here, so
                # single quotes are doubled to keep the literal well-formed.
                quoted_path = abs_path.replace("'", "''")
                statement = f"ATTACH '{quoted_path}' AS {name} (READ_ONLY)"
                conn.exec_driver_sql(statement)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close every client that was created, then dispose of the engine."""
        try:
            for client in (self._ducklake, self._ftp, self._dadosgov):
                if client:
                    await client.close()
        finally:
            # Release the engine even if closing a client raised.
            self.engine.dispose()

    def _get_dest_path(self, client_name: str, remote_path: str) -> Path:
        """Map a remote path to its location under the downloads cache."""
        # Leading slashes are stripped so the remote path is appended to the
        # cache directory instead of replacing it.
        return CACHEPATH / "downloads" / client_name / remote_path.lstrip("/")

    async def _update_state(
        self,
        local_path: Path,
        remote_path: str,
        client_name: str,
        status: DownloadStatus,
        year: int | None = None,
        month: int | None = None,
        state: str | None = None,
    ):
        """Insert or update the ``LocalFileState`` row for ``local_path``.

        ``status`` and ``last_synced`` are always written. ``year``,
        ``month`` and ``state`` are written only when a value is supplied,
        so a later call without them keeps what is already stored.
        """
        with self.Session() as session:
            record = (
                session.query(LocalFileState)
                .filter_by(path=str(local_path))
                .first()
            )
            if record is None:
                record = LocalFileState(
                    path=str(local_path),
                    remote_path=remote_path,
                    client_name=client_name,
                )
                session.add(record)

            # Applied to existing rows too: download() creates the row
            # before the metadata is passed in.
            if year is not None:
                record.year = year
            if month is not None:
                record.month = month
            if state is not None:
                record.state = state

            record.status = status
            record.last_synced = datetime.utcnow()
            session.commit()

    async def download(
        self,
        file: BaseRemoteFile,
        token: str | None = None,
        callback: Callable | None = None,
    ):
        """Download ``file`` into the local cache and return the local file.

        If a completed local copy is recorded and still exists on disk it is
        returned without downloading again. Otherwise the row is marked
        ``DOWNLOADING``, the transfer is delegated to the client named by
        ``file.client.name``, and the row is marked ``COMPLETED``.

        Args:
            file: Remote file to fetch.
            token: Access token passed to ``get_dadosgov`` for that client.
            callback: Passed through to the client's ``_download_file``.

        Raises:
            ValueError: If no download logic exists for the file's client.
            RuntimeError: If the DuckLake client has not been created.
            Exception: Any error raised during the transfer is re-raised
                after the row is marked ``FAILED``.
        """
        from pysus.api.extensions import ExtensionFactory

        existing_local = await self.get_local_file(file)
        if existing_local and existing_local.path.exists():
            return existing_local

        client_name = file.client.name.lower()
        remote_path = file.path
        local_path = self._get_dest_path(client_name, remote_path)

        local_path.parent.mkdir(parents=True, exist_ok=True)

        await self._update_state(
            local_path, remote_path, client_name, DownloadStatus.DOWNLOADING
        )

        try:
            if client_name == "ducklake":
                if self._ducklake is None:
                    raise RuntimeError(
                        "DuckLake client is not initialised; "
                        "use 'async with PySUS()' before downloading."
                    )
                await self._ducklake._download_file(
                    file, local_path, callback
                )
            elif client_name == "ftp":
                client = await self.get_ftp()
                await client._download_file(file, local_path, callback)
            elif client_name == "dadosgov":
                client = await self.get_dadosgov(token)
                await client._download_file(file, local_path, callback)
            else:
                raise ValueError(
                    f"No download logic for client: {client_name}"
                )

            # Mark the transfer as finished so get_local_file() can find it.
            await self._update_state(
                local_path=local_path,
                remote_path=remote_path,
                client_name=client_name,
                status=DownloadStatus.COMPLETED,
                year=file.year,
                month=file.month,
                state=file.state,
            )
            return await ExtensionFactory.instantiate(local_path)

        except Exception:
            await self._update_state(
                local_path, remote_path, client_name, DownloadStatus.FAILED
            )
            raise

    async def download_to_parquet(
        self,
        file: BaseRemoteFile,
        token: str | None = None,
        callback: Callable | None = None,
    ):
        """Download ``file`` and convert it to Parquet when supported.

        The file is fetched with ``download``. If the resulting object has a
        ``to_parquet`` method, it is called, the converted file is recorded
        as ``COMPLETED`` and returned; otherwise the downloaded file is
        returned as is.
        """
        local_file = await self.download(
            file=file,
            token=token,
            callback=callback,
        )

        if not hasattr(local_file, "to_parquet"):
            return local_file

        parquet_file = await local_file.to_parquet()

        # Record the same metadata as download() does; getattr keeps the
        # cached path working for remote files lacking these attributes.
        await self._update_state(
            local_path=parquet_file.path,
            remote_path=file.path,
            client_name=file.client.name.lower(),
            status=DownloadStatus.COMPLETED,
            year=getattr(file, "year", None),
            month=getattr(file, "month", None),
            state=getattr(file, "state", None),
        )
        return parquet_file
0 commit comments