Skip to content

Commit f42e88a

Browse files
Thomas CarrollThomas Carroll
authored andcommitted
refactor: use TypedDict
1 parent a43d4b6 commit f42e88a

2 files changed

Lines changed: 14 additions & 13 deletions

File tree

tests/functional/e2e/testcases/parser.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import json
22
from pathlib import Path
3-
from typing import cast
3+
from typing import TypedDict, cast
44

55
# Path to the shared testcases directory in athena-protobufs
66
_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent
77
TESTCASES_DIR = _REPO_ROOT / "athena-protobufs" / "testcases"
88

99

10+
class TestCases(TypedDict):
11+
classification_labels: list[str]
12+
images: list[list[str | list[float]]]
13+
14+
1015
class AthenaTestCase:
1116
def __init__(
1217
self,
@@ -28,13 +33,9 @@ def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]:
2833
with Path.open(
2934
Path(TESTCASES_DIR / dirname / "expected_outputs.json"),
3035
) as f:
31-
test_cases = cast(
32-
"dict[str, list[str] | list[list[str | list[float]]]]", json.load(f)
33-
)
34-
classification_labels = cast(
35-
"list[str]", test_cases["classification_labels"]
36-
)
37-
images = cast("list[list[str | list[float]]]", test_cases["images"])
36+
test_cases: TestCases = cast("TestCases", json.load(f))
37+
classification_labels = test_cases["classification_labels"]
38+
images = test_cases["images"]
3839
return [
3940
AthenaTestCase(
4041
str(

tests/test_classify_single.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88
import grpc.aio
99
import numpy as np
1010
import pytest
11-
12-
from resolver_athena_client.client.athena_client import AthenaClient
13-
from resolver_athena_client.client.athena_options import AthenaOptions
14-
from resolver_athena_client.client.exceptions import AthenaError
15-
from resolver_athena_client.client.models import ImageData
1611
from resolver_athena_client.generated.athena.models_pb2 import (
1712
Classification,
1813
ClassificationError,
@@ -25,6 +20,11 @@
2520
RequestEncoding,
2621
)
2722

23+
from resolver_athena_client.client.athena_client import AthenaClient
24+
from resolver_athena_client.client.athena_options import AthenaOptions
25+
from resolver_athena_client.client.exceptions import AthenaError
26+
from resolver_athena_client.client.models import ImageData
27+
2828

2929
@pytest.fixture
3030
def athena_options() -> AthenaOptions:

0 commit comments

Comments
 (0)