Skip to content

Commit efe54bc

Browse files
IthacaDreamHanqingZ
authored andcommitted
fix: legacy model_type deserialization regression (langgenius#34717)
1 parent 8aa43ea commit efe54bc

2 files changed

Lines changed: 37 additions & 3 deletions

File tree

api/models/types.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import enum
22
import uuid
3-
from typing import Any
3+
from typing import Any, cast
44

55
import sqlalchemy as sa
66
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
@@ -143,8 +143,14 @@ def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
143143
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
144144
if value is None or value == "":
145145
return None
146-
# Type annotation guarantees value is str at this point
147-
return self._enum_class(value)
146+
try:
147+
# Type annotation guarantees value is str at this point
148+
return self._enum_class(value)
149+
except ValueError:
150+
value_of = getattr(self._enum_class, "value_of", None)
151+
if callable(value_of):
152+
return cast(T, value_of(value))
153+
raise
148154

149155
def compare_values(self, x: T | None, y: T | None) -> bool:
150156
if x is None or y is None:

api/tests/test_containers_integration_tests/models/test_types_enum_text.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
import sqlalchemy as sa
7+
from graphon.model_runtime.entities.model_entities import ModelType
78
from sqlalchemy import exc as sa_exc
89
from sqlalchemy import insert
910
from sqlalchemy.engine import Connection, Engine
@@ -58,6 +59,13 @@ class _ColumnTest(_Base):
5859
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
5960

6061

62+
class _LegacyModelTypeRecord(_Base):
63+
__tablename__ = "enum_text_legacy_model_type_test"
64+
65+
id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
66+
model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False)
67+
68+
6169
def _first[T](it: Iterable[T]) -> T:
6270
ls = list(it)
6371
if not ls:
@@ -201,3 +209,23 @@ def test_select_invalid_values(self, engine_with_containers: Engine):
201209
_user = session.query(_User).where(_User.id == 1).first()
202210

203211
assert str(exc.value) == "'invalid' is not a valid _UserType"
212+
213+
def test_select_legacy_model_type_values(self, engine_with_containers: Engine):
214+
insertion_sql = """
215+
INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES
216+
(1, 'text-generation'),
217+
(2, 'embeddings'),
218+
(3, 'reranking');
219+
"""
220+
with Session(engine_with_containers) as session:
221+
session.execute(sa.text(insertion_sql))
222+
session.commit()
223+
224+
with Session(engine_with_containers) as session:
225+
records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all()
226+
227+
assert [record.model_type for record in records] == [
228+
ModelType.LLM,
229+
ModelType.TEXT_EMBEDDING,
230+
ModelType.RERANK,
231+
]

0 commit comments

Comments
 (0)