@@ -28,22 +28,25 @@ BaseModel::BaseModel(const std::string &modelSource,
2828 memorySizeLowerBound = std::filesystem::file_size (modelPath);
2929}
3030
31- std::vector<std::vector<int32_t >> BaseModel::getAllInputShapes () {
32- if (!module ) {
31+ std::vector<std::vector<int32_t >>
32+ BaseModel::getAllInputShapes (std::string methodName) {
33+ if (!module ) {
3334 throw std::runtime_error (" getInputShape called on unloaded model" );
3435 }
35- auto method_meta = module ->method_meta (" forward " );
36+ auto method_meta = module ->method_meta (methodName );
3637
3738 if (!method_meta.ok ()) {
38- throw std::runtime_error (" Failed to load forward " );
39+ throw std::runtime_error (" Failed to load method: " + methodName );
3940 }
4041 std::vector<std::vector<int32_t >> output;
4142 std::size_t numInputs = method_meta->num_inputs ();
4243 output.reserve (numInputs);
4344 for (std::size_t input = 0 ; input < numInputs; ++input) {
4445 auto input_meta = method_meta->input_tensor_meta (input);
4546 if (!input_meta.ok ()) {
46- throw std::runtime_error (" Failed to load forward input" );
47+ throw std::runtime_error (
48+ " Failed to load input no: " + std::to_string (input) + " for method " +
49+ methodName);
4750 }
4851 auto shape = input_meta->sizes ();
4952 output.emplace_back (std::vector<int32_t >(shape.begin (), shape.end ()));
@@ -55,11 +58,4 @@ std::size_t BaseModel::getMemoryLowerBound() { return memorySizeLowerBound; }
5558
5659void BaseModel::unload () { module .reset (nullptr ); }
5760
58- Result<std::vector<EValue>> BaseModel::forwardET (const EValue &input_value) {
59- if (!module ) {
60- throw std::runtime_error (" Forward called on unloaded model" );
61- }
62- return module ->forward (input_value);
63- }
64-
6561} // namespace rnexecutorch
0 commit comments