diff --git a/src/webapp/database.py b/src/webapp/database.py index 28dcc139..7fe974b0 100644 --- a/src/webapp/database.py +++ b/src/webapp/database.py @@ -2,10 +2,10 @@ import uuid import datetime -from typing import Set, List +from typing import Set, List, Any from contextvars import ContextVar +import enum import sqlalchemy -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.mutable import MutableDict, MutableList from sqlalchemy import ( Column, @@ -16,17 +16,33 @@ String, UniqueConstraint, Text, + Enum, + Boolean, JSON, Integer, BigInteger, + Index, + event, +) +from sqlalchemy.orm import ( + DeclarativeBase, + sessionmaker, + Session, + relationship, + mapped_column, + Mapped, + Mapper, ) -from sqlalchemy.orm import sessionmaker, Session, relationship, mapped_column, Mapped from sqlalchemy.sql import func from sqlalchemy.pool import StaticPool from .config import engine_vars, ssl_env_vars, setup_database_vars from .authn import get_password_hash, get_api_key_hash -Base = declarative_base() + +class Base(DeclarativeBase): + pass + + LocalSession = None local_session: ContextVar[Session] = ContextVar("local_session") db_engine = None @@ -44,7 +60,21 @@ DATETIME_TESTING = datetime.datetime(2024, 12, 26, 19, 37, 59, 753357) -def init_db(env: str): +@event.listens_for(Mapper, "before_insert") +@event.listens_for(Mapper, "before_update") +def validate_string_lengths(mapper, connection, target): + for column in mapper.columns: + col_type = column.type + if isinstance(col_type, String) and col_type.length: + val = getattr(target, column.name, None) + if isinstance(val, str) and len(val) > col_type.length: + raise ValueError( + f"Value for '{column.name}' exceeds max length " + f"{col_type.length}: {len(val)} characters provided" + ) + + +def init_db(env: str) -> None: """Initialize the database for LOCAL and DEV environemtns for ease of use.""" # add some sample users to the database for development utility. 
if env not in ("LOCAL", "DEV"): @@ -101,7 +131,9 @@ class InstTable(Base): all other tables except for AccountHistory and JobTable.""" __tablename__ = "inst" - id = Column(Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) # Linked children tables. accounts: Mapped[Set["AccountTable"]] = relationship(back_populates="inst") @@ -112,22 +144,43 @@ class InstTable(Base): back_populates="inst" ) models: Mapped[Set["ModelTable"]] = relationship(back_populates="inst") + schemas_registry: Mapped[List["SchemaRegistryTable"]] = relationship( + "SchemaRegistryTable", back_populates="inst", cascade="all, delete-orphan" + ) - name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False, unique=True) + name: Mapped[str] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=False, unique=True + ) # If retention unset, the Datakind default is used. File-level retentions overrides # this value. retention_days: Mapped[int] = mapped_column(nullable=True) # The emails for which self sign up will be allowed for this institution and will automatically be assigned to this institution. # The dict structure is {email: AccessType string} - allowed_emails = Column(MutableDict.as_mutable(JSON)) + allowed_emails: Mapped[dict[str, str]] = mapped_column( + MutableDict.as_mutable(JSON()), + nullable=False, + default=dict, + ) # Schemas that are allowed for validation. - schemas = Column(MutableList.as_mutable(JSON)) - state = Column(String(VAR_CHAR_LENGTH), nullable=True) + schemas: Mapped[list[str]] = mapped_column( + MutableList.as_mutable(JSON()), nullable=False, default=list + ) + state: Mapped[str | None] = mapped_column(String(VAR_CHAR_LENGTH), nullable=True) # Only populated for PDP schools. 
- pdp_id = Column(String(VAR_CHAR_LENGTH), nullable=True) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - created_by = Column(Uuid(as_uuid=True), nullable=True) + pdp_id: Mapped[str | None] = mapped_column(String(VAR_CHAR_LENGTH), nullable=True) + created_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + onupdate=func.now(), + nullable=False, + default=func.now(), + ) + created_by: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=True) # Within the institutions, the set of name + state should be unique __table_args__ = (UniqueConstraint("name", "state", name="inst_name_state_uc"),) @@ -137,27 +190,41 @@ class ApiKeyTable(Base): """API KEYS should match the format generated by `openssl rand -hex 32`""" __tablename__ = "apikey" - id = Column(Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) # A hash of the key_value, so the user must store the generated key_value secretly. - hashed_key_value = Column( + hashed_key_value: Mapped[str | None] = mapped_column( String(VAR_CHAR_STANDARD_LENGTH), nullable=False, unique=True ) # Set the foreign key to link to the institution table. 
- inst_id = Column( + inst_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("inst.id", ondelete="CASCADE"), nullable=True, ) inst: Mapped["InstTable"] = relationship(back_populates="apikeys") - created_by = Column(Uuid(as_uuid=True), nullable=False) - notes = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True) + created_by: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=False) + notes: Mapped[str | None] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=True + ) # Whether this key allows changing the enduser. ONLY SET FOR THE FRONTEND KEY. Can only be set when the API key has DATAKINDER access type as this allows Datakinder level endusers. allows_enduser: Mapped[bool] = mapped_column(nullable=True) - access_type = Column(String(VAR_CHAR_LENGTH), nullable=False) - created_at = mapped_column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + access_type: Mapped[str] = mapped_column(String(VAR_CHAR_LENGTH), nullable=False) + created_at = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + onupdate=func.now(), + nullable=False, + default=func.now(), + ) # API key must be valid and not deleted. deleted: Mapped[bool] = mapped_column(nullable=True) valid: Mapped[bool] = mapped_column(nullable=True) @@ -174,7 +241,9 @@ class AccountTable(Base): The user accounts table""" __tablename__ = "users" # Name to be compliant with Laravel. 
- id = Column(Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) # Set account histories to be children account_histories: Mapped[List["AccountHistoryTable"]] = relationship( @@ -182,44 +251,74 @@ class AccountTable(Base): ) # Set the foreign key to link to the institution table. - inst_id = Column( + inst_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("inst.id", ondelete="CASCADE"), nullable=True, ) inst: Mapped["InstTable"] = relationship(back_populates="accounts") - name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) - email = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False, unique=True) - google_id = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True) - azure_id = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True) + name: Mapped[str] = mapped_column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) + email: Mapped[str] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=False, unique=True + ) + google_id: Mapped[str | None] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=True + ) + azure_id: Mapped[str | None] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=True + ) - email_verified_at = Column(DateTime(timezone=True), nullable=True) - password = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) - two_factor_secret = Column(Text, nullable=True) - two_factor_recovery_codes = Column(Text, nullable=True) - two_factor_confirmed_at = Column(DateTime(timezone=True), nullable=True) + email_verified_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + password: Mapped[str] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=False + ) + two_factor_secret: Mapped[str | None] = mapped_column(Text, nullable=True) + two_factor_recovery_codes: Mapped[str | None] = mapped_column(Text, nullable=True) + 
two_factor_confirmed_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) - remember_token = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True) + remember_token: Mapped[str | None] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=True + ) # Required for team integration with laravel - current_team_id = Column(Uuid(as_uuid=True), nullable=True) - access_type = Column(String(VAR_CHAR_LENGTH), nullable=True) - # profile_photo_path = Column(String(VAR_CHAR_LENGTH), nullable=True) - created_at = mapped_column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + current_team_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), nullable=True + ) + access_type: Mapped[str | None] = mapped_column( + String(VAR_CHAR_LENGTH), nullable=True + ) + # profile_photo_path : Mapped[dict[str, str]] = mapped_column(String(VAR_CHAR_LENGTH), nullable=True) + created_at = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + onupdate=func.now(), + nullable=False, + default=func.now(), + ) class AccountHistoryTable(Base): """The user history table""" __tablename__ = "account_history" - id = Column(Integer, primary_key=True) # Auto-increment should be default - timestamp = Column( + id: Mapped[int] = mapped_column( + Integer, primary_key=True + ) # Auto-increment should be default + timestamp: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) # Set the parent foreign key to link to the users table. 
- account_id = Column( + account_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, @@ -227,15 +326,17 @@ class AccountHistoryTable(Base): account: Mapped["AccountTable"] = relationship(back_populates="account_histories") # This field is nullable if the action was taken by a Datakinder. - inst_id = Column( + inst_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("inst.id", ondelete="CASCADE"), nullable=True, ) inst: Mapped["InstTable"] = relationship(back_populates="account_histories") - action = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) - resource_id = Column(Uuid(as_uuid=True), nullable=False) + action: Mapped[str] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=False + ) + resource_id: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=False) # An intermediary association table allows bi-directional many-to-many between files and batches. @@ -261,13 +362,15 @@ class FileTable(Base): """The file table""" __tablename__ = "file" - name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) - id = Column(Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) batches: Mapped[Set["BatchTable"]] = relationship( secondary=association_table, back_populates="files" ) # Set the parent foreign key to link to the institution table. - inst_id = Column( + inst_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("inst.id", ondelete="CASCADE"), nullable=False, @@ -276,25 +379,41 @@ class FileTable(Base): # The size to the nearest mb. # size_mb: Mapped[int] = mapped_column(nullable=False) # Who uploaded the file. For SST generated files, this field would be null. 
- uploader = Column(Uuid(as_uuid=True), nullable=True) + uploader: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=True) # Can be PDP_SFTP, MANUAL_UPLOAD etc. May be empty for generated files. - source = Column(String(VAR_CHAR_LENGTH), nullable=True) + source: Mapped[str | None] = mapped_column(String(VAR_CHAR_LENGTH), nullable=True) # The schema type(s) of this file. - schemas = Column(MutableList.as_mutable(JSON), nullable=False) + schemas: Mapped[list[str]] = mapped_column( + MutableList.as_mutable((JSON())), + nullable=False, + default=list, + ) # If null, the following is non-deleted. # The deleted field indicates whether there is a pending deletion request on the data. # The data may stil be available to Datakind debug role in a soft-delete state but for all # intents and purposes is no longer accessible by the app. deleted: Mapped[bool] = mapped_column(nullable=True) # When the deletion request was made - deleted_at = Column(DateTime(timezone=True), nullable=True) + deleted_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) retention_days: Mapped[int] = mapped_column(nullable=True) # Whether the file was generated by SST. (e.g. was it input or output) - sst_generated: Mapped[bool] = mapped_column(nullable=False) + sst_generated: Mapped[bool] = mapped_column(nullable=False, default=False) # Whether the file was approved (in the case of output) or valid for input. 
- valid: Mapped[bool] = mapped_column(nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + valid: Mapped[bool] = mapped_column(nullable=False, default=False) + created_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + onupdate=func.now(), + nullable=False, + default=func.now(), + ) # Within a given institution, there should be no duplicated file names. __table_args__ = (UniqueConstraint("name", "inst_id", name="file_name_inst_uc"),) @@ -304,10 +423,12 @@ class BatchTable(Base): """The batch table""" __tablename__ = "batch" - id = Column(Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) # Set the parent foreign key to link to the institution table. - inst_id = Column( + inst_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("inst.id", ondelete="CASCADE"), nullable=False, @@ -318,18 +439,30 @@ class BatchTable(Base): secondary=association_table, back_populates="batches" ) - name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) - created_by = Column(Uuid(as_uuid=True)) + name: Mapped[str] = mapped_column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) + created_by: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True)) # If null, the following is non-deleted. deleted: Mapped[bool] = mapped_column(nullable=True) # If true, the batch is ready for use. completed: Mapped[bool] = mapped_column(nullable=True) # The time the deletion request was set. 
- deleted_at = Column(DateTime(timezone=True), nullable=True) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + deleted_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + created_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + onupdate=func.now(), + nullable=False, + default=func.now(), + ) # If a batch is deleted, the uuid of the user in the updated_by section is the deleter. - updated_by = Column(Uuid(as_uuid=True), nullable=True) + updated_by: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=True) # Within a given institution, there should be no duplicated batch names. __table_args__ = (UniqueConstraint("name", "inst_id", name="batch_name_inst_uc"),) @@ -338,10 +471,12 @@ class ModelTable(Base): """The model table""" __tablename__ = "model" - id = Column(Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4) + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) # Set the parent foreign key to link to the institution table. - inst_id = Column( + inst_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("inst.id", ondelete="CASCADE"), nullable=False, @@ -350,20 +485,32 @@ class ModelTable(Base): jobs: Mapped[Set["JobTable"]] = relationship(back_populates="model") - name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) + name: Mapped[str] = mapped_column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) # What configuration of schemas are allowed (list of maps e.g. 
[PDP Course : 1 + PDP Cohort : 1, X_schema :1 + Y_schema: 2]) - schema_configs = Column(JSON, nullable=True) - created_by = Column(Uuid(as_uuid=True), nullable=True) + schema_configs: Mapped[dict[str, str]] = mapped_column(JSON(), nullable=True) + created_by: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=True) # If null, the following is non-deleted. deleted: Mapped[bool] = mapped_column(nullable=True) # If true, the model has been approved and is ready for use. valid: Mapped[bool] = mapped_column(nullable=True) # The time the deletion request was set. - deleted_at = Column(DateTime(timezone=True), nullable=True) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + deleted_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + created_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + onupdate=func.now(), + nullable=False, + default=func.now(), + ) # version is unused. version is not currently supported. The webapp only knows about the name of the model and any usages of a model will only use the live version. - version = Column(Integer, default=0) + version: Mapped[int] = mapped_column(Integer, default=0) # Within a given institution, there should be no duplicated model names. __table_args__ = (UniqueConstraint("name", "inst_id", name="model_name_inst_uc"),) @@ -373,28 +520,123 @@ class JobTable(Base): """The job table""" __tablename__ = "job" - id = Column(BigInteger, primary_key=True) + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) # Set the parent foreign key to link to the institution table. 
- model_id = Column( + model_id: Mapped[uuid.UUID] = mapped_column( Uuid(as_uuid=True), ForeignKey("model.id", ondelete="CASCADE"), nullable=False, ) model: Mapped["ModelTable"] = relationship(back_populates="jobs") - created_by = Column(Uuid(as_uuid=True), nullable=False) + created_by: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=False) # The time the deletion request was set. - triggered_at = Column(DateTime(timezone=True), nullable=False) - batch_name = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=False) + triggered_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + batch_name: Mapped[str] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=False + ) # The following will be empty if not completed or if job errored out. Getting additional details will require a call to the Databricks table. - output_filename = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True) + output_filename: Mapped[str | None] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=True + ) # Whether the file was approved. output_valid: Mapped[bool] = mapped_column(nullable=True, default=False) - err_msg = Column(String(VAR_CHAR_STANDARD_LENGTH), nullable=True) + err_msg: Mapped[str | None] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=True + ) completed: Mapped[bool] = mapped_column(nullable=True) +class DocType(enum.Enum): + base = "base" + extension = "extension" + + +class SchemaRegistryTable(Base): + """ + Stores versioned schema documents: + - Base schema (doc_type=base, is_pdp=False, inst_id NULL) + - PDP shared extension (doc_type=extension, is_pdp=True, inst_id NULL) + - Custom institution extension (doc_type=extension, is_pdp=False, inst_id=) + Layers can reference a parent (extends_schema_id) that they extend. 
+ """ + + __tablename__ = "schema_registry" + schema_id: Mapped[int] = mapped_column( + Integer, primary_key=True, autoincrement=True + ) + doc_type: Mapped[DocType] = mapped_column( + Enum(DocType, native_enum=False), nullable=False + ) + # Nullable: NULL for base and PDP shared extension + inst_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("inst.id", ondelete="RESTRICT", onupdate="CASCADE"), nullable=True + ) + is_pdp: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + version_label: Mapped[str] = mapped_column( + String(VAR_CHAR_STANDARD_LENGTH), nullable=False + ) + extends_schema_id: Mapped[int | None] = mapped_column( + Integer, + ForeignKey( + "schema_registry.schema_id", ondelete="SET NULL", onupdate="CASCADE" + ), + nullable=True, + ) + json_doc: Mapped[dict] = mapped_column( + MutableDict.as_mutable(JSON()), nullable=False + ) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + created_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + default=func.now(), + ) + + # ---------------- Relationships ---------------- + inst: Mapped["InstTable | None"] = relationship( + "InstTable", + back_populates="schemas_registry", # paired with InstTable.schemas_registry + ) + + parent_schema: Mapped["SchemaRegistryTable | None"] = relationship( + "SchemaRegistryTable", + remote_side="SchemaRegistryTable.schema_id", + foreign_keys=[extends_schema_id], + back_populates="child_schemas", + ) + + child_schemas: Mapped[List["SchemaRegistryTable"]] = relationship( + "SchemaRegistryTable", + back_populates="parent_schema", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + UniqueConstraint("doc_type", "version_label", name="uq_base_version"), + UniqueConstraint("is_pdp", "version_label", name="uq_pdp_version"), + UniqueConstraint("inst_id", "version_label", name="uq_inst_version"), + Index("idx_schema_active_base", 
"doc_type", "is_active"), + Index("idx_schema_active_pdp", "is_pdp", "is_active"), + Index("idx_schema_active_inst", "inst_id", "is_active"), + ) + + # Convenience: identify logical namespace + @property + def namespace(self) -> str: + if self.doc_type == DocType.base: + return "base" + if self.is_pdp: + return "pdp" + if self.inst_id: + return f"inst:{self.inst_id}" + return "unknown" + + def get_session(): """Get the session.""" sess: Session = LocalSession() @@ -435,7 +677,7 @@ def connect_tcp_socket( username=engine_args["DB_USER"], password=engine_args["DB_PASS"], host=engine_args["INSTANCE_HOST"], - port=engine_args["DB_PORT"], + port=int(engine_args["DB_PORT"]), database=engine_args["DB_NAME"], ), connect_args=connect_args, @@ -471,7 +713,7 @@ def init_connection_pool() -> sqlalchemy.engine.Engine: return connect_tcp_socket(engine_vars, ssl_args) -def setup_db(env: str): +def setup_db(env: str) -> Any: """Setup Database. Called by all environments.""" # initialize connection pool global db_engine diff --git a/src/webapp/gcsutil.py b/src/webapp/gcsutil.py index f871e272..91a988db 100644 --- a/src/webapp/gcsutil.py +++ b/src/webapp/gcsutil.py @@ -8,7 +8,7 @@ from .config import gcs_vars, databricks_vars from .validation import validate_file_reader -from typing import Any, List +from typing import Any, List, Optional, Dict import logging # Set the logging @@ -241,7 +241,7 @@ def download_file( raise ValueError(file_name + ": File not found.") blob.download_to_filename(destination_file_name) - def move_file(self, bucket_name: str, prev_name: str, new_name: str): + def move_file(self, bucket_name: str, prev_name: str, new_name: str) -> None: """Rename a file.""" storage_client = storage.Client() bucket = storage_client.bucket(bucket_name) @@ -256,7 +256,7 @@ def move_file(self, bucket_name: str, prev_name: str, new_name: str): bucket.copy_blob(blob, bucket, new_name) blob.delete() - def delete_file(self, bucket_name: str, file_name: str): + def delete_file(self, 
bucket_name: str, file_name: str) -> None: """Delete a file.""" storage_client = storage.Client() bucket = storage_client.bucket(bucket_name) @@ -268,7 +268,12 @@ def delete_file(self, bucket_name: str, file_name: str): blob.delete() def validate_file( - self, bucket_name: str, file_name: str, allowed_schemas: list[str] + self, + bucket_name: str, + file_name: str, + allowed_schemas: list[str], + base_schema: dict, + inst_schema: Optional[Dict[Any, Any]] = None, ) -> List[str]: """Validate that a file is one of the allowed schemas.""" client = storage.Client() @@ -278,7 +283,9 @@ def validate_file( schems: List[str] = [] try: with blob.open("r") as file: - schemas = validate_file_reader(file, allowed_schemas) + schemas = validate_file_reader( + file, allowed_schemas, base_schema, inst_schema + ) schems = [str(s) for s in schemas.get("schemas", [])] logging.debug( f"If you see this file validation was successful {schems}" @@ -294,7 +301,7 @@ def validate_file( logging.debug("If you see this file validation was complete") return schems - def get_file_contents(self, bucket_name: str, file_name: str): + def get_file_contents(self, bucket_name: str, file_name: str) -> Any: """Returns a file as a bytes object.""" storage_client = storage.Client() bucket = storage_client.get_bucket(bucket_name) diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index bc2b762d..932e75d2 100644 --- a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -37,6 +37,8 @@ BatchTable, FileTable, InstTable, + SchemaRegistryTable, + DocType, ) from ..databricks import DatabricksControl @@ -901,12 +903,66 @@ def validation_helper( allowed_schemas = infer_models_from_filename(file_name, "pdp") inferred_schemas: list[str] = [] + # ----------------------- Fetch base schema from DB ------------------------------- + base_schema = ( + local_session.get() + .execute( + select(SchemaRegistryTable.json_doc) + .where( + SchemaRegistryTable.doc_type == DocType.base, + 
SchemaRegistryTable.is_active.is_(True), + ) + .limit(1) + ) + .scalar_one_or_none() + ) + if base_schema is None: + raise RuntimeError("No active base schema found") + # ----------------------- Fetch inst specific extension schema from DB --------------------- + inst = ( + local_session.get() + .execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id))) + .scalar_one_or_none() + ) + if inst is None: + raise ValueError(f"Institution {inst_id} not found") + + if inst.pdp_id: # institution is PDP + inst_schema = ( + local_session.get() + .execute( + select(SchemaRegistryTable.json_doc) + .where( + SchemaRegistryTable.is_pdp.is_(True), + SchemaRegistryTable.is_active.is_(True), + ) + .limit(1) + ) + .scalar_one_or_none() + ) + else: # custom (or none) + inst_schema = ( + local_session.get() + .execute( + select(SchemaRegistryTable.json_doc) + .where( + SchemaRegistryTable.inst_id == inst.id, + SchemaRegistryTable.is_active.is_(True), + ) + .limit(1) + ) + .scalar_one_or_none() + ) + + # ----------------------- File validation logic -------------------------------------- try: inferred_schemas = storage_control.validate_file( get_external_bucket_name(inst_id), file_name, allowed_schemas, + base_schema, + inst_schema, ) logging.debug( f"!!!!!!!!!!Inferred Schemas was successful {list(inferred_schemas)}" @@ -1330,9 +1386,8 @@ def get_training_support_overview( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) -@router.get("/{inst_id}/training/model-cards/{run_id}/{model_name}") +@router.get("/{inst_id}/training/model-cards/{model_name}") def get_model_cards( - run_id: str, model_name: str, inst_id: str, current_user: Annotated[BaseUser, Depends(get_current_active_user)], diff --git a/src/webapp/routers/data_test.py b/src/webapp/routers/data_test.py index 90bd6030..cc529934 100644 --- a/src/webapp/routers/data_test.py +++ b/src/webapp/routers/data_test.py @@ -4,6 +4,7 @@ from unittest import mock from collections import Counter 
from fastapi.testclient import TestClient +from typing import Any import pytest import sqlalchemy from sqlalchemy.pool import StaticPool @@ -20,6 +21,8 @@ FileTable, BatchTable, InstTable, + SchemaRegistryTable, + DocType, Base, get_session, ) @@ -42,30 +45,30 @@ def counter_repr(x): return {frozenset(Counter(item).items()) for item in x} -def same_file_orderless(a_elem: DataInfo, b_elem: DataInfo): +def same_file_orderless(a_elem: DataInfo, b_elem: DataInfo): # type: ignore """Compares two DataInfo objects.""" if ( - a_elem["inst_id"] != b_elem["inst_id"] - or counter_repr(a_elem["batch_ids"]) != counter_repr(b_elem["batch_ids"]) - or a_elem["name"] != b_elem["name"] - or a_elem["uploader"] != b_elem["uploader"] - or a_elem["deleted"] != b_elem["deleted"] - or a_elem["source"] != b_elem["source"] - or a_elem["deletion_request_time"] != b_elem["deletion_request_time"] - or a_elem["retention_days"] != b_elem["retention_days"] - or a_elem["sst_generated"] != b_elem["sst_generated"] - or a_elem["valid"] != b_elem["valid"] - or a_elem["uploaded_date"] != b_elem["uploaded_date"] + a_elem["inst_id"] != b_elem["inst_id"] # type: ignore + or counter_repr(a_elem["batch_ids"]) != counter_repr(b_elem["batch_ids"]) # type: ignore + or a_elem["name"] != b_elem["name"] # type: ignore + or a_elem["uploader"] != b_elem["uploader"] # type: ignore + or a_elem["deleted"] != b_elem["deleted"] # type: ignore + or a_elem["source"] != b_elem["source"] # type: ignore + or a_elem["deletion_request_time"] != b_elem["deletion_request_time"] # type: ignore + or a_elem["retention_days"] != b_elem["retention_days"] # type: ignore + or a_elem["sst_generated"] != b_elem["sst_generated"] # type: ignore + or a_elem["valid"] != b_elem["valid"] # type: ignore + or a_elem["uploaded_date"] != b_elem["uploaded_date"] # type: ignore ): return False return True -def same_orderless(a: DataOverview, b: DataOverview): +def same_orderless(a: DataOverview, b: DataOverview) -> bool: """Compares two DataOverview 
objects.""" - for a_elem in a["batches"]: + for a_elem in a["batches"]: # type: ignore found = False - for b_elem in b["batches"]: + for b_elem in b["batches"]: # type: ignore if a_elem["batch_id"] != b_elem["batch_id"]: continue found = True @@ -82,9 +85,9 @@ def same_orderless(a: DataOverview, b: DataOverview): return False if not found: return False - for a_elem in a["files"]: + for a_elem in a["files"]: # type: ignore found = False - for b_elem in b["files"]: + for b_elem in b["files"]: # type: ignore if a_elem["data_id"] != b_elem["data_id"]: continue found = True @@ -157,6 +160,14 @@ def session_fixture(): created_at=DATETIME_TESTING, updated_at=DATETIME_TESTING, ), + SchemaRegistryTable( + doc_type=DocType.base, # ✅ fix this + is_pdp=False, + version_label="1.0.0", + json_doc={"version": "1.0.0", "base": {"data_models": {}}}, + is_active=True, + created_at=DATETIME_TESTING, + ), batch_1, file_1, FileTable( @@ -181,7 +192,7 @@ def session_fixture(): @pytest.fixture(name="client") -def client_fixture(session: sqlalchemy.orm.Session): +def client_fixture(session: sqlalchemy.orm.Session) -> Any: """Unit test mocks setup.""" def get_session_override(): @@ -203,7 +214,7 @@ def storage_control_override(): app.dependency_overrides.clear() -def test_read_inst_all_input_files(client: TestClient): +def test_read_inst_all_input_files(client: TestClient) -> Any: """Test GET /institutions//input.""" response = client.get("/institutions/" + uuid_to_str(UUID_INVALID) + "/input") @@ -217,9 +228,9 @@ def test_read_inst_all_input_files(client: TestClient): "/institutions/" + uuid_to_str(USER_VALID_INST_UUID) + "/input" ) assert response.status_code == 200 - assert same_orderless( + assert same_orderless( # type: ignore response.json(), - { + { # type: ignore "batches": [ { "batch_id": "5b2420f3103546ab90eb74d5df97de43", @@ -270,7 +281,7 @@ def test_read_inst_all_input_files(client: TestClient): ) -def test_read_inst_all_output_files(client: TestClient): +def 
test_read_inst_all_output_files(client: TestClient) -> Any: """Test GET /institutions//output.""" MOCK_STORAGE.list_blobs_in_folder.return_value = [] response = client.get("/institutions/" + uuid_to_str(UUID_INVALID) + "/output") @@ -285,9 +296,9 @@ def test_read_inst_all_output_files(client: TestClient): "/institutions/" + uuid_to_str(USER_VALID_INST_UUID) + "/output" ) assert response.status_code == 200 - assert same_orderless( + assert same_orderless( # type: ignore response.json(), - { + { # type: ignore "batches": [ { "batch_id": "5b2420f3103546ab90eb74d5df97de43", @@ -338,7 +349,7 @@ def test_read_inst_all_output_files(client: TestClient): ) -def test_read_batch_info(client: TestClient): +def test_read_batch_info(client: TestClient) -> Any: """Test GET /institutions//batch/.""" response = client.get( "/institutions/" @@ -360,9 +371,9 @@ def test_read_batch_info(client: TestClient): + uuid_to_str(BATCH_UUID) ) assert response.status_code == 200 - assert same_orderless( + assert same_orderless( # type: ignore response.json(), - { + { # type: ignore "batches": [ { "batch_id": "5b2420f3103546ab90eb74d5df97de43", @@ -413,7 +424,7 @@ def test_read_batch_info(client: TestClient): ) -def test_read_file_id_info(client: TestClient): +def test_read_file_id_info(client: TestClient) -> Any: """Test GET /institutions//file-id/.""" response = client.get( "/institutions/" @@ -435,9 +446,9 @@ def test_read_file_id_info(client: TestClient): + uuid_to_str(FILE_UUID_1) ) assert response.status_code == 200 - assert same_file_orderless( + assert same_file_orderless( # type: ignore response.json(), - { + { # type: ignore "name": "file_input_one", "data_id": "f0bb3a206d924254afed6a72f43c562a", "batch_ids": ["5b2420f3103546ab90eb74d5df97de43"], @@ -454,7 +465,7 @@ def test_read_file_id_info(client: TestClient): ) -def test_retrieve_file_as_bytes(client: TestClient): +def test_retrieve_file_as_bytes(client: TestClient) -> Any: """Test GET /institutions//output-file-contents/.""" 
response = client.get( "/institutions/" @@ -479,7 +490,7 @@ def test_retrieve_file_as_bytes(client: TestClient): assert response.text == '{"detail":"No such output file exists."}' -def test_create_batch(client: TestClient): +def test_create_batch(client: TestClient) -> None: """Test POST /institutions//batch.""" response = client.post( "/institutions/" + uuid_to_str(UUID_INVALID) + "/batch", @@ -517,7 +528,7 @@ def test_create_batch(client: TestClient): assert len(response.json()["file_names_to_ids"]) == 1 -def test_update_batch(client: TestClient): +def test_update_batch(client: TestClient) -> None: """Test PATCH /institutions//batch.""" response = client.patch( "/institutions/" @@ -555,7 +566,7 @@ def test_update_batch(client: TestClient): } -def test_validate_success_batch(client: TestClient): +def test_validate_success_batch(client: TestClient) -> None: """Test PATCH /institutions//batch.""" MOCK_STORAGE.validate_file.return_value = ["UNKNOWN"] @@ -606,7 +617,7 @@ def test_validate_success_batch(client: TestClient): assert response_sftp.json()["source"] == "PDP_SFTP" -def test_validate_failure_batch(client: TestClient): +def test_validate_failure_batch(client: TestClient) -> None: """Test PATCH /institutions//batch.""" MOCK_STORAGE.validate_file.return_value = ["COURSE"] # Authorized. 
diff --git a/src/webapp/validation.py b/src/webapp/validation.py index 2c13e047..90583fb6 100644 --- a/src/webapp/validation.py +++ b/src/webapp/validation.py @@ -5,7 +5,6 @@ from typing import Any import json -import os import re from typing import Union, List, Dict, Optional import logging @@ -16,9 +15,14 @@ from fuzzywuzzy import fuzz -def validate_file_reader(filename: str, allowed_schema: list[str]) -> dict[str, Any]: +def validate_file_reader( + filename: str, + allowed_schema: list[str], + base_schema: dict, + inst_schema: Optional[Dict[Any, Any]] = None, +) -> dict[str, Any]: """Validates given a filename.""" - return validate_dataset(filename, allowed_schema) + return validate_dataset(filename, base_schema, inst_schema, allowed_schema) class HardValidationError(Exception): @@ -153,6 +157,8 @@ def build_schema(specs: Dict[str, dict]) -> DataFrameSchema: def validate_dataset( filename: str, + base_schema: dict, + ext_schema: Optional[Dict[Any, Any]] = None, models: Union[str, List[str], None] = None, institution_id: str = "pdp", ) -> Dict[str, Any]: @@ -160,18 +166,6 @@ def validate_dataset( df = df.rename(columns={c: normalize_col(c) for c in df.columns}) incoming = set(df.columns) - # 1) load schemas - BASE_DIR = os.path.dirname(os.path.abspath(__file__)) - base_schema_path = os.path.join(BASE_DIR, "validation_schemas/base_schema.json") - base_schema = load_json(base_schema_path) - ext_schema = None - - extension_schema_path = os.path.join( - BASE_DIR, f"validation_schemas/{institution_id}_schema_extension.json" - ) - if extension_schema_path and os.path.exists(extension_schema_path): - ext_schema = load_json(extension_schema_path) - # 2) merge requested models if models is None: model_list = [] diff --git a/src/webapp/validation_test.py b/src/webapp/validation_test.py index aa69bb65..92bc1f48 100644 --- a/src/webapp/validation_test.py +++ b/src/webapp/validation_test.py @@ -28,11 +28,11 @@ } } -MOCK_EXT_SCHEMA = {"institutions": {"pdp": {"data_models": 
{}}}} +MOCK_EXT_SCHEMA: dict = {"institutions": {"pdp": {"data_models": {}}}} @pytest.fixture -def tmp_csv_file(tmp_path: Path): +def tmp_csv_file(tmp_path: Path) -> str: df = pd.DataFrame({"foo_col": [1, 2], "bar_col": ["a", "b"]}) file_path = tmp_path / "test.csv" df.to_csv(file_path, index=False) @@ -47,7 +47,12 @@ def test_validate_file_reader_passes(tmp_csv_file): mock_load.side_effect = lambda path: ( MOCK_BASE_SCHEMA if "base" in path else MOCK_EXT_SCHEMA ) - result = validate_file_reader(tmp_csv_file, ["test_model"]) + result = validate_file_reader( + tmp_csv_file, + ["test_model"], + base_schema=MOCK_BASE_SCHEMA, + inst_schema=MOCK_EXT_SCHEMA, + ) assert result["validation_status"] == "passed" assert result["schemas"] == ["test_model"] @@ -65,5 +70,10 @@ def test_validate_file_reader_fails_missing_required(tmp_path): MOCK_BASE_SCHEMA if "base" in path else MOCK_EXT_SCHEMA ) with pytest.raises(HardValidationError) as exc_info: - validate_file_reader(str(file_path), ["test_model"]) + validate_file_reader( + str(file_path), + ["test_model"], + base_schema=MOCK_BASE_SCHEMA, + inst_schema=MOCK_EXT_SCHEMA, + ) assert "Missing required columns" in str(exc_info.value)