Skip to content

Commit 4b8fdf4

Browse files
authored
Update get_test_info.py (related to tiny model creation) (#45238)
update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent e1b80de commit 4b8fdf4

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

utils/get_test_info.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import importlib
1616
import os
1717
import sys
18+
import unittest
1819

1920

2021
# This is required to make the module import works (when the python process is running from the root of the repo)
@@ -87,11 +88,19 @@ def get_test_classes(test_file):
8788
test_module = get_test_module(test_file)
8889
for attr in dir(test_module):
8990
attr_value = getattr(test_module, attr)
90-
# ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking
91-
# `all_model_classes` is not empty (which also excludes other special classes).
92-
model_classes = getattr(attr_value, "all_model_classes", [])
93-
if len(model_classes) > 0:
94-
test_classes.append(attr_value)
91+
92+
# Look for the test classes (subclass of `unittest.TestCase`) with `all_model_classes` attribute.
93+
# This also excludes `ModelTesterMixin` and `CausalLMModelTest`.
94+
if isinstance(attr_value, type) and issubclass(attr_value, unittest.TestCase):
95+
model_classes = getattr(attr_value, "all_model_classes", [])
96+
# `CausalLMModelTest` (subclass of `ModelTesterMixin`) has `all_model_classes` as a class attribute with
97+
# the value being `None`. For a real test class of `CausalLMModelTest`, the value is only set during `setUp`.
98+
if model_classes is None:
99+
test_instance = attr_value()
100+
test_instance.setUp()
101+
model_classes = getattr(test_instance, "all_model_classes", [])
102+
if len(model_classes) > 0:
103+
test_classes.append(attr_value)
95104

96105
# sort with class names
97106
return sorted(test_classes, key=lambda x: x.__name__)
@@ -102,7 +111,12 @@ def get_model_classes(test_file):
102111
test_classes = get_test_classes(test_file)
103112
model_classes = set()
104113
for test_class in test_classes:
105-
model_classes.update(test_class.all_model_classes)
114+
all_model_classes = test_class.all_model_classes
115+
if all_model_classes is None:
116+
test_instance = test_class()
117+
test_instance.setUp()
118+
all_model_classes = test_instance.all_model_classes
119+
model_classes.update(all_model_classes)
106120

107121
# sort with class names
108122
return sorted(model_classes, key=lambda x: x.__name__)
@@ -128,8 +142,15 @@ def get_test_classes_for_model(test_file, model_class):
128142
test_classes = get_test_classes(test_file)
129143

130144
target_test_classes = []
145+
131146
for test_class in test_classes:
132-
if model_class in test_class.all_model_classes:
147+
all_model_classes = test_class.all_model_classes
148+
if all_model_classes is None:
149+
test_instance = test_class()
150+
test_instance.setUp()
151+
all_model_classes = test_instance.all_model_classes
152+
153+
if model_class in all_model_classes:
133154
target_test_classes.append(test_class)
134155

135156
# sort with class names

0 commit comments

Comments
 (0)