1515import importlib
1616import os
1717import sys
18+ import unittest
1819
1920
2021# This is required to make the module import works (when the python process is running from the root of the repo)
@@ -87,11 +88,19 @@ def get_test_classes(test_file):
8788 test_module = get_test_module (test_file )
8889 for attr in dir (test_module ):
8990 attr_value = getattr (test_module , attr )
90- # ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking
91- # `all_model_classes` is not empty (which also excludes other special classes).
92- model_classes = getattr (attr_value , "all_model_classes" , [])
93- if len (model_classes ) > 0 :
94- test_classes .append (attr_value )
91+
92+ # Look for the test classes (subclass of `unittest.TestCase`) with `all_model_classes` attribute.
93+ # This also excludes `ModelTesterMixin` and `CausalLMModelTest`.
94+ if isinstance (attr_value , type ) and issubclass (attr_value , unittest .TestCase ):
95+ model_classes = getattr (attr_value , "all_model_classes" , [])
96+ # `CausalLMModelTest` (subclass of `ModelTesterMixin`) has `all_model_classes` as a class attribute with
97+ # the value being `None`. For a real test class of `CausalLMModelTest`, the value is only set during `setUp`.
98+ if model_classes is None :
99+ test_instance = attr_value ()
100+ test_instance .setUp ()
101+ model_classes = getattr (test_instance , "all_model_classes" , [])
102+ if len (model_classes ) > 0 :
103+ test_classes .append (attr_value )
95104
96105 # sort with class names
97106 return sorted (test_classes , key = lambda x : x .__name__ )
@@ -102,7 +111,12 @@ def get_model_classes(test_file):
102111 test_classes = get_test_classes (test_file )
103112 model_classes = set ()
104113 for test_class in test_classes :
105- model_classes .update (test_class .all_model_classes )
114+ all_model_classes = test_class .all_model_classes
115+ if all_model_classes is None :
116+ test_instance = test_class ()
117+ test_instance .setUp ()
118+ all_model_classes = test_instance .all_model_classes
119+ model_classes .update (all_model_classes )
106120
107121 # sort with class names
108122 return sorted (model_classes , key = lambda x : x .__name__ )
@@ -128,8 +142,15 @@ def get_test_classes_for_model(test_file, model_class):
128142 test_classes = get_test_classes (test_file )
129143
130144 target_test_classes = []
145+
131146 for test_class in test_classes :
132- if model_class in test_class .all_model_classes :
147+ all_model_classes = test_class .all_model_classes
148+ if all_model_classes is None :
149+ test_instance = test_class ()
150+ test_instance .setUp ()
151+ all_model_classes = test_instance .all_model_classes
152+
153+ if model_class in all_model_classes :
133154 target_test_classes .append (test_class )
134155
135156 # sort with class names
0 commit comments