@@ -50,71 +50,72 @@ class TestMultimodalRepresentationOptimizer(unittest.TestCase):
5050 data_generator = None
5151 num_instances = 0
5252
53- @classmethod
54- def setUpClass (cls ):
55- cls .num_instances = 10
56- cls .mods = [ModalityType .VIDEO , ModalityType .AUDIO , ModalityType .TEXT ]
57- cls .indices = np .array (range (cls .num_instances ))
58-
59- def test_multimodal_fusion (self ):
60- task = TestTask ("MM_Fusion_Task1" , "Test1" , self .num_instances )
61-
62- audio_data , audio_md = ModalityRandomDataGenerator ().create_audio_data (
63- self .num_instances , 1000
64- )
65- text_data , text_md = ModalityRandomDataGenerator ().create_text_data (
66- self .num_instances
67- )
68-
69- audio = UnimodalModality (
70- TestDataLoader (
71- self .indices , None , ModalityType .AUDIO , audio_data , np .float32 , audio_md
72- )
73- )
74- text = UnimodalModality (
75- TestDataLoader (
76- self .indices , None , ModalityType .TEXT , text_data , str , text_md
77- )
78- )
79-
80- with patch .object (
81- Registry ,
82- "_representations" ,
83- {
84- ModalityType .TEXT : [W2V ],
85- ModalityType .AUDIO : [Spectrogram ],
86- ModalityType .TIMESERIES : [ResNet ],
87- ModalityType .VIDEO : [ResNet ],
88- ModalityType .EMBEDDING : [],
89- },
90- ):
91- registry = Registry ()
92- registry ._fusion_operators = [Average , Concatenation , LSTM ]
93- unimodal_optimizer = UnimodalOptimizer ([audio , text ], [task ], debug = False )
94- unimodal_optimizer .optimize ()
95- unimodal_optimizer .operator_performance .get_k_best_results (
96- audio , 2 , task , "accuracy"
97- )
98- m_o = MultimodalOptimizer (
99- [audio , text ],
100- unimodal_optimizer .operator_performance ,
101- [task ],
102- debug = False ,
103- min_modalities = 2 ,
104- max_modalities = 3 ,
105- )
106- fusion_results = m_o .optimize (20 )
107-
108- best_results = sorted (
109- fusion_results [task .model .name ],
110- key = lambda x : getattr (x , "val_score" )["accuracy" ],
111- reverse = True ,
112- )[:2 ]
113-
114- assert (
115- best_results [0 ].val_score ["accuracy" ]
116- >= best_results [1 ].val_score ["accuracy" ]
117- )
53+ # @classmethod
54+ # def setUpClass(cls):
55+ # cls.num_instances = 10
56+ # cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
57+ # cls.indices = np.array(range(cls.num_instances))
58+
59+ # Note: Multimodal fusion is being refactored and not yet ready for testing
60+ # def test_multimodal_fusion(self):
61+ # task = TestTask("MM_Fusion_Task1", "Test1", self.num_instances)
62+
63+ # audio_data, audio_md = ModalityRandomDataGenerator().create_audio_data(
64+ # self.num_instances, 1000
65+ # )
66+ # text_data, text_md = ModalityRandomDataGenerator().create_text_data(
67+ # self.num_instances
68+ # )
69+
70+ # audio = UnimodalModality(
71+ # TestDataLoader(
72+ # self.indices, None, ModalityType.AUDIO, audio_data, np.float32, audio_md
73+ # )
74+ # )
75+ # text = UnimodalModality(
76+ # TestDataLoader(
77+ # self.indices, None, ModalityType.TEXT, text_data, str, text_md
78+ # )
79+ # )
80+
81+ # with patch.object(
82+ # Registry,
83+ # "_representations",
84+ # {
85+ # ModalityType.TEXT: [W2V],
86+ # ModalityType.AUDIO: [Spectrogram],
87+ # ModalityType.TIMESERIES: [ResNet],
88+ # ModalityType.VIDEO: [ResNet],
89+ # ModalityType.EMBEDDING: [],
90+ # },
91+ # ):
92+ # registry = Registry()
93+ # registry._fusion_operators = [Average, Concatenation, LSTM]
94+ # unimodal_optimizer = UnimodalOptimizer([audio, text], [task], debug=False)
95+ # unimodal_optimizer.optimize()
96+ # unimodal_optimizer.operator_performance.get_k_best_results(
97+ # audio, 2, task, "accuracy"
98+ # )
99+ # m_o = MultimodalOptimizer(
100+ # [audio, text],
101+ # unimodal_optimizer.operator_performance,
102+ # [task],
103+ # debug=False,
104+ # min_modalities=2,
105+ # max_modalities=3,
106+ # )
107+ # fusion_results = m_o.optimize(20)
108+
109+ # best_results = sorted(
110+ # fusion_results[task.model.name],
111+ # key=lambda x: getattr(x, "val_score")["accuracy"],
112+ # reverse=True,
113+ # )[:2]
114+
115+ # assert (
116+ # best_results[0].val_score["accuracy"]
117+ # >= best_results[1].val_score["accuracy"]
118+ # )
118119
119120 # def test_parallel_multimodal_fusion(self):
120121 # task = TestTask("MM_Fusion_Task1", "Test2", self.num_instances)
0 commit comments