|
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 | import sqlalchemy as sa |
| 7 | +from graphon.model_runtime.entities.model_entities import ModelType |
7 | 8 | from sqlalchemy import exc as sa_exc |
8 | 9 | from sqlalchemy import insert |
9 | 10 | from sqlalchemy.engine import Connection, Engine |
@@ -58,6 +59,13 @@ class _ColumnTest(_Base): |
58 | 59 | long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False) |
59 | 60 |
|
60 | 61 |
|
| 62 | +class _LegacyModelTypeRecord(_Base): |
| 63 | + __tablename__ = "enum_text_legacy_model_type_test" |
| 64 | + |
| 65 | + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) |
| 66 | + model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False) |
| 67 | + |
| 68 | + |
61 | 69 | def _first[T](it: Iterable[T]) -> T: |
62 | 70 | ls = list(it) |
63 | 71 | if not ls: |
@@ -201,3 +209,23 @@ def test_select_invalid_values(self, engine_with_containers: Engine): |
201 | 209 | _user = session.query(_User).where(_User.id == 1).first() |
202 | 210 |
|
203 | 211 | assert str(exc.value) == "'invalid' is not a valid _UserType" |
| 212 | + |
| 213 | + def test_select_legacy_model_type_values(self, engine_with_containers: Engine): |
| 214 | + insertion_sql = """ |
| 215 | + INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES |
| 216 | + (1, 'text-generation'), |
| 217 | + (2, 'embeddings'), |
| 218 | + (3, 'reranking'); |
| 219 | + """ |
| 220 | + with Session(engine_with_containers) as session: |
| 221 | + session.execute(sa.text(insertion_sql)) |
| 222 | + session.commit() |
| 223 | + |
| 224 | + with Session(engine_with_containers) as session: |
| 225 | + records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all() |
| 226 | + |
| 227 | + assert [record.model_type for record in records] == [ |
| 228 | + ModelType.LLM, |
| 229 | + ModelType.TEXT_EMBEDDING, |
| 230 | + ModelType.RERANK, |
| 231 | + ] |
0 commit comments