@@ -88,6 +88,12 @@ TEST_F(TrainingModuleTest, JointGraphTest) {
8888 ASSERT_EQ (param.find (" linear.weight" )->second .dim (), 2 );
8989 ASSERT_EQ (param.find (" linear.bias" )->second .sizes ()[0 ], 3 );
9090 ASSERT_EQ (param.find (" linear.bias" )->second .dim (), 1 );
91+
92+ // Test attributes for pte only model
93+ auto attributes_res = mod.named_attributes (" forward" );
94+ ASSERT_EQ (attributes_res.error (), Error::Ok);
95+ auto & attributes = attributes_res.get ();
96+ ASSERT_EQ (attributes.size (), 0 );
9197}
9298
9399TEST_F (TrainingModuleTest, NonTrainingModuleTest) {
@@ -153,3 +159,43 @@ TEST_F(TrainingModuleTest, SeperateDataTest) {
153159 ASSERT_EQ (res.error (), Error::Ok);
154160 ASSERT_EQ (res.get ().size (), 1 );
155161}
162+
163+ TEST_F (TrainingModuleTest, DataExternalConstantsTest) {
164+ // Test the external constants are loaded correctly.
165+ const char * ptd_path = std::getenv (" ET_MODULE_ADD_MUL_DATA_PATH" );
166+ Result<FileDataLoader> data_map_loader_res = FileDataLoader::from (ptd_path);
167+ ASSERT_EQ (data_map_loader_res.error (), Error::Ok);
168+
169+ auto data_map_loader =
170+ std::make_unique<torch::executor::util::FileDataLoader>(
171+ std::move (data_map_loader_res.get ()));
172+
173+ const char * pte_path = std::getenv (" ET_MODULE_ADD_MUL_PROGRAM_PATH" );
174+ Result<FileDataLoader> pte_loader_res = FileDataLoader::from (pte_path);
175+ ASSERT_EQ (pte_loader_res.error (), Error::Ok);
176+
177+ auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
178+ std::move (pte_loader_res.get ()));
179+
180+ auto mod = executorch::extension::training::TrainingModule (
181+ std::move (pte_loader),
182+ nullptr ,
183+ nullptr ,
184+ nullptr ,
185+ std::move (data_map_loader));
186+
187+ // Test Attributes for pte + ptd model containing external constants
188+ auto attributes_res = mod.named_attributes (" forward" );
189+ ASSERT_EQ (attributes_res.error (), Error::Ok);
190+ auto & attributes = attributes_res.get ();
191+ ASSERT_EQ (attributes.size (), 2 );
192+ ASSERT_NE (attributes.find (" a" ), attributes.end ());
193+ ASSERT_NE (attributes.find (" b" ), attributes.end ());
194+
195+ ASSERT_EQ (attributes.find (" a" )->second .sizes ()[0 ], 2 );
196+ ASSERT_EQ (attributes.find (" a" )->second .sizes ()[1 ], 2 );
197+ ASSERT_EQ (attributes.find (" a" )->second .dim (), 2 );
198+ ASSERT_EQ (attributes.find (" b" )->second .sizes ()[0 ], 2 );
199+ ASSERT_EQ (attributes.find (" b" )->second .sizes ()[0 ], 2 );
200+ ASSERT_EQ (attributes.find (" b" )->second .dim (), 2 );
201+ }
0 commit comments