diff --git a/.gitignore b/.gitignore index b37594314..76a9eb5c8 100644 --- a/.gitignore +++ b/.gitignore @@ -498,3 +498,15 @@ opencl_program.cc opencl_program.cc platforms/mac/tnn.xcodeproj/project.xcworkspace/xcuserdata/darrenyao.xcuserdatad/UserInterfaceState.xcuserstate platforms/mac/tnn.xcodeproj/xcuserdata/darrenyao.xcuserdatad/xcschemes/xcschememanagement.plist + +# build output +platforms/ios/tnn.bundle/ +platforms/ios/tnn.framework/ +scripts/build_aarch64_macos/ +scripts/build_macos_native/ + +# tmp dir +tmp/ + +# finetune_demo +scripts/finetune_demo**/ diff --git a/CMakeLists.txt b/CMakeLists.txt index fad86599f..5130e402f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(TNN_VERSION "${TNN_MAJOR_VERSION}.${TNN_MINOR_VERSION}.${TNN_PATCH_VERSION}. option(TNN_CPU_ENABLE "Enable Cpu" ON) option(TNN_X86_ENABLE "Enable X86" OFF) -option(TNN_ARM_ENABLE "Enable Arm" OFF) +option(TNN_ARM_ENABLE "Enable Arm" ON) option(TNN_ARM82_ENABLE "Enable Arm82" OFF) option(TNN_METAL_ENABLE "Enable Metal" OFF) option(TNN_OPENCL_ENABLE "Enable OpenCL" OFF) @@ -38,7 +38,7 @@ option(TNN_SYMBOL_HIDE "Enable Hide Symbol Visibility" ON) option(TNN_OPENMP_ENABLE "Enable OpenMP" OFF) option(TNN_BUILD_SHARED "Build Shared Library" ON) option(TNN_OPENVINO_BUILD_SHARED "Build Shared Openvino Library" OFF) -option(TNN_TEST_ENABLE "Enable Test" OFF) +option(TNN_TEST_ENABLE "Enable Test" ON) option(TNN_UNIT_TEST_ENABLE "Enable Test" OFF) option(TNN_PROFILER_ENABLE "Enable Profiler" OFF) option(TNN_QUANTIZATION_ENABLE "Enable Quantization" OFF) @@ -51,7 +51,7 @@ option(TNN_ONNX2TNN_ENABLE "Enable ONNX2TNN Converter" OFF) option(TNN_TNN2MEM_ENABLE "Enable tnn2mem" OFF) option(TNN_BUILD_BENCHMARK_TEST_LIB_ENABLE "Enable Build Benchmark Test Lib" OFF) option(TNN_GLIBCXX_USE_CXX11_ABI_ENABLE "Enable Use CXX11 ABI" ON) -option(TNN_TRAIN_ENABLE "Enable train module" OFF) +option(TNN_TRAIN_ENABLE "Enable train module" ON) option(TNN_METAL_FLOAT32 "Enable Metal Float32" OFF) 
option(TNN_COREML_FLOAT32 "Enable Float32 CoreML Model" ON) option(TNN_DYNAMIC_RANGE_QUANTIZATION_ENABLE "Enable Dynamic Range Quantization" OFF) diff --git a/include/tnn/core/common.h b/include/tnn/core/common.h index 7ce3ec306..ec83a8633 100644 --- a/include/tnn/core/common.h +++ b/include/tnn/core/common.h @@ -172,11 +172,11 @@ struct PUBLIC TrainConfig { // loss LossFunc loss_func = LOSS_FUNC_DEFAULT; // if loss_func is not default, the following informations are used to create loss layer - std::string target_layer = ""; // the layer whose output is used to calculate loss, default is the last layer + std::vector target_layers; // the layers whose outputs are used to calculate losses, default is the last layer bool auto_add_prob_layer = true; // add softmax or sigmoid layer before loss layer // target used to calculate loss - std::string ground_truth_name = ""; // the ground truth, provide by model inputs - DimsVector ground_truth_shape = {}; // the shape of the ground truth + std::vector ground_truth_names; // the ground truths, provide by model inputs + std::vector ground_truth_shapes; // the shapes of the ground truths // solver SolverType solver_type = SOLVER_TYPE_SGD; @@ -189,8 +189,8 @@ struct PUBLIC TrainConfig { }; struct PUBLIC TrainingFeedback { - std::string loss_name = ""; - float loss_value = 0.0; + std::vector loss_names; + std::vector loss_values; std::string global_step_name = ""; int global_step_value = 0; }; diff --git a/source/tnn/core/default_network.cc b/source/tnn/core/default_network.cc index eaf82be02..8573c7d8b 100644 --- a/source/tnn/core/default_network.cc +++ b/source/tnn/core/default_network.cc @@ -608,9 +608,8 @@ Status DefaultNetwork::Forward() { auto layer = layers_[cnt]; std::vector inputs = layer->GetInputBlobs(); std::vector outputs = layer->GetOutputBlobs(); - { - + #if DUMP_INPUT_BLOB if (runtime_model_ == RUNTIME_MODE_NORMAL || runtime_model_ == RUNTIME_MODE_BACKWARD) { // InputBlob data in dumped into files in NCHW_FLOAT 
format as default diff --git a/source/tnn/core/instance.cc b/source/tnn/core/instance.cc index 05d40d891..186d0cc81 100644 --- a/source/tnn/core/instance.cc +++ b/source/tnn/core/instance.cc @@ -65,8 +65,12 @@ Status Instance::SaveTrainedModel(const std::string& model_path) { Status Instance::GetTrainingFeedback(TrainingFeedback& feed_back) { RETURN_ON_NEQ(network_->GetTrainingFeedback(feed_back), TNN_OK); std::shared_ptr mat; - GetOutputMat(mat, MatConvertParam(), feed_back.loss_name); - feed_back.loss_value = *(reinterpret_cast(mat->GetData())); + // Rebuild loss_values from scratch so a reused TrainingFeedback does not accumulate values from earlier calls. + feed_back.loss_values.clear(); + for (const auto &loss_name : feed_back.loss_names) { + RETURN_ON_NEQ(GetOutputMat(mat, MatConvertParam(), loss_name), TNN_OK); + feed_back.loss_values.push_back(*(reinterpret_cast(mat->GetData()))); + } GetOutputMat(mat, MatConvertParam(), feed_back.global_step_name); feed_back.global_step_value = *(reinterpret_cast(mat->GetData())); return TNN_OK; diff --git a/source/tnn/interpreter/net_structure.h b/source/tnn/interpreter/net_structure.h index c7da301b5..54748c94f 100644 --- a/source/tnn/interpreter/net_structure.h +++ b/source/tnn/interpreter/net_structure.h @@ -67,6 +67,11 @@ struct NetStructure { std::set blobs; ModelType source_model_type = MODEL_TYPE_TNN; +#ifdef TNN_TRAIN + std::vector loss_names; + std::vector loss_grad_names; +#endif + public: std::shared_ptr Copy() { std::shared_ptr net_structure(new NetStructure()); diff --git a/source/tnn/train/default_train_network.cc b/source/tnn/train/default_train_network.cc index ed586ba68..256b4135e 100644 --- a/source/tnn/train/default_train_network.cc +++ b/source/tnn/train/default_train_network.cc @@ -11,7 +11,6 @@ // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
- #include "tnn/train/default_train_network.h" #include "tnn/train/gradient/gradient_layer.h" @@ -37,6 +36,8 @@ Status DefaultTrainNetwork::Init(NetworkConfig &net_config, ModelConfig &model_c enable_const_folder); RETURN_ON_NEQ(ret, TNN_OK); + RETURN_ON_NEQ(CopyLossAndLossGradNames(interpreter), TNN_OK); + RETURN_ON_NEQ(InitTrainingStatus(), TNN_OK); RETURN_ON_NEQ(InitRuntimeInfo(), TNN_OK); @@ -44,10 +45,36 @@ Status DefaultTrainNetwork::Init(NetworkConfig &net_config, ModelConfig &model_c return TNN_OK; } +Status DefaultTrainNetwork::CopyLossAndLossGradNames(AbstractModelInterpreter *interpreter) { + auto default_interpreter = dynamic_cast(interpreter); + CHECK_PARAM_NULL(default_interpreter); + + const NetStructure *net_structure = default_interpreter->GetNetStructure(); + if (net_structure == nullptr) { + LOGE("ERROR: net_structure is nil, network_type may not support\n"); + return Status(TNNERR_NULL_PARAM, "net_structure is nil, network_type may not support"); + } + + loss_names_ = net_structure->loss_names; + loss_grad_names_ = net_structure->loss_grad_names; + // SetLossGrad() indexes both lists with the same position, so they must be non-empty and equally long. + if (loss_names_.empty()) { + LOGE("DefaultTrainNetwork::CopyLossAndLossGradNames ERROR, cannot get loss names\n"); + return Status(TNNERR_TRAIN_ERROR, "cannot get loss names"); + } + if (loss_grad_names_.size() != loss_names_.size()) { + LOGE("DefaultTrainNetwork::CopyLossAndLossGradNames ERROR, got %d loss grad names for %d loss names\n", static_cast<int>(loss_grad_names_.size()), static_cast<int>(loss_names_.size())); + return Status(TNNERR_TRAIN_ERROR, "loss grad names do not match loss names"); + } + return TNN_OK; +} + Status DefaultTrainNetwork::GetAllInputBlobs(BlobMap &blobs) { blob_manager_->GetAllInputBlobs(blobs); // loss grad is assumed to be one - blobs.erase(loss_grad_name_); + for (const auto &loss_grad_name : loss_grad_names_) { + blobs.erase(loss_grad_name); + } // global step init value is assumed to be zero blobs.erase(global_step_init_name_); return TNN_OK; @@ -81,34 +108,24 @@ Status DefaultTrainNetwork::TrainStep() { } Status DefaultTrainNetwork::GetTrainingFeedback(TrainingFeedback &feed_back) { - feed_back.loss_name = 
loss_name_; + for (const auto & loss_name : loss_names_) { + feed_back.loss_names.push_back(loss_name); + } feed_back.global_step_name = global_step_name_; return TNN_OK; } Status DefaultTrainNetwork::InitTrainingStatus() { - LayerInfo *loss_layer = nullptr; - LayerInfo *loss_grad_layer = nullptr; - int cnt = 0; + std::vector loss_layers; + std::vector loss_grad_layers; + int cnt = 0; for (auto layer : net_structure_->layers) { if (layer->type == LAYER_GRADIENT) { - loss_grad_layer = layer.get(); break; } - loss_layer = layer.get(); cnt++; } forward_layer_count_ = cnt; - if (!loss_layer) { - LOGE("DefaultTrainNetwork::InitTrainingStatus ERROR, cannot get loss layer\n"); - return Status(TNNERR_TRAIN_ERROR, "cannot get loss layer"); - } - if (!loss_grad_layer) { - LOGE("DefaultTrainNetwork::InitTrainingStatus ERROR, cannot get loss grad layer\n"); - return Status(TNNERR_TRAIN_ERROR, "cannot get loss grad layer"); - } - loss_name_ = loss_layer->outputs[0]; - loss_grad_name_ = loss_grad_layer->inputs.back(); LayerInfo *solver_layer_info = net_structure_->layers.back().get(); if (!solver_layer_info) { @@ -131,50 +148,55 @@ Status DefaultTrainNetwork::InitRuntimeInfo() { } Status DefaultTrainNetwork::SetLossGrad() { - Blob *loss_blob = blob_manager_->GetBlob(loss_name_); - if (!loss_blob) { - LOGE("DefaultTrainNetwork::SetLossGrad get loss_blob failed\n"); - return Status(TNNERR_TRAIN_ERROR, "get loss_blob failed!"); - } - auto loss_data_count = DimsVectorUtils::Count(loss_blob->GetBlobDesc().dims); - if (loss_data_count != 1) { - LOGE( - "DefaultTrainNetwork::SetLossGrad only support loss data count = 1 now, got %d. 
Try to change loss " - "function type or loss target layer!\n", - loss_data_count); - return Status(TNNERR_TRAIN_ERROR, - "loss data count not supported, try to change loss function type or loss target layer!"); - } + for (int loss_idx = 0; loss_idx < loss_names_.size(); ++loss_idx) { + const auto loss_name = loss_names_[loss_idx]; + const auto loss_grad_name = loss_grad_names_[loss_idx]; + + Blob *loss_blob = blob_manager_->GetBlob(loss_name); + if (!loss_blob) { + LOGE("DefaultTrainNetwork::SetLossGrad get loss_blob failed\n"); + return Status(TNNERR_TRAIN_ERROR, "get loss_blob failed!"); + } + auto loss_data_count = DimsVectorUtils::Count(loss_blob->GetBlobDesc().dims); + if (loss_data_count != 1) { + LOGE( + "DefaultTrainNetwork::SetLossGrad only support loss data count = 1 now, got %d. Try to change loss " + "function type or loss target layer!\n", + loss_data_count); + return Status(TNNERR_TRAIN_ERROR, + "loss data count not supported, try to change loss function type or loss target layer!"); + } - std::shared_ptr mat(new Mat(DEVICE_ARM, NCHW_FLOAT, {loss_data_count})); - if (!mat || !mat->GetData()) { - LOGE("DefaultTrainNetwork::SetLossGrad create mat failed\n"); - return Status(TNNERR_TRAIN_ERROR, "create mat failed"); - } + std::shared_ptr mat(new Mat(DEVICE_ARM, NCHW_FLOAT, {loss_data_count})); + if (!mat || !mat->GetData()) { + LOGE("DefaultTrainNetwork::SetLossGrad create mat failed\n"); + return Status(TNNERR_TRAIN_ERROR, "create mat failed"); + } - // init loss grad as one - auto ptr = reinterpret_cast(mat->GetData()); - for (int i = 0; i < loss_data_count; ++i) { - ptr[i] = 1.0; - } + // init loss grad as one + auto ptr = reinterpret_cast(mat->GetData()); + for (int i = 0; i < loss_data_count; ++i) { + ptr[i] = 1.0; + } - Blob *loss_grad = blob_manager_->GetBlob(loss_grad_name_); - if (!loss_grad) { - LOGE("DefaultTrainNetwork::SetLossGrad get loss_grad failed\n"); - return Status(TNNERR_TRAIN_ERROR, "get loss_grad failed!"); - } + Blob *loss_grad = 
blob_manager_->GetBlob(loss_grad_name); + if (!loss_grad) { + LOGE("DefaultTrainNetwork::SetLossGrad get loss_grad failed\n"); + return Status(TNNERR_TRAIN_ERROR, "get loss_grad failed!"); + } - // create blob convert - std::shared_ptr blob_converter = std::make_shared(loss_grad); + // create blob convert + std::shared_ptr blob_converter = std::make_shared(loss_grad); - // get command queue - void *command_queue = nullptr; - RETURN_ON_NEQ(GetCommandQueue(&command_queue), TNN_OK); + // get command queue + void *command_queue = nullptr; + RETURN_ON_NEQ(GetCommandQueue(&command_queue), TNN_OK); - Status status = blob_converter->ConvertFromMatAsync(*(mat.get()), MatConvertParam(), command_queue); - if (status != TNN_OK) { - LOGE("DefaultTrainNetwork::SetLossGrad, ConvertFromMatAsync Error: %s\n", status.description().c_str()); - return status; + Status status = blob_converter->ConvertFromMatAsync(*(mat.get()), MatConvertParam(), command_queue); + if (status != TNN_OK) { + LOGE("DefaultTrainNetwork::SetLossGrad, ConvertFromMatAsync Error: %s\n", status.description().c_str()); + return status; + } } return TNN_OK; diff --git a/source/tnn/train/default_train_network.h b/source/tnn/train/default_train_network.h index b55c3391d..011fb8381 100644 --- a/source/tnn/train/default_train_network.h +++ b/source/tnn/train/default_train_network.h @@ -51,13 +51,15 @@ class DefaultTrainNetwork : public DefaultNetwork { virtual Status SetGradientLayerRuntimeInfo(); virtual Status SetSolverLayerRuntimeInfo(); + virtual Status CopyLossAndLossGradNames(AbstractModelInterpreter *interpreter); + std::map input_to_grad_map_; std::map grad_to_resource_map_; std::vector need_refresh_layers_; - std::string loss_name_; - std::string loss_grad_name_; + std::vector loss_names_; + std::vector loss_grad_names_; std::string global_step_name_; std::string global_step_init_name_; diff --git a/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.cc 
b/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.cc index 08c74f572..b298c398d 100644 --- a/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.cc +++ b/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.cc @@ -50,8 +50,8 @@ namespace optimizer { bool NetOptimizerInsertLossAndGradient::IsSupported(const NetworkConfig &net_config) { bool is_support = false; - train_config = net_config.train_config; - if (train_config.run_mode == TRAIN_MODE_TRAIN) { + train_config_ = net_config.train_config; + if (train_config_.run_mode == TRAIN_MODE_TRAIN) { auto device = net_config.device_type; if (device == DEVICE_ARM || device == DEVICE_NAIVE) { is_support = true; @@ -79,7 +79,7 @@ namespace optimizer { return TNN_OK; } - if (train_config.trainable_layers.empty() && !train_config.train_the_whole_model) { + if (train_config_.trainable_layers.empty() && !train_config_.train_the_whole_model) { return Status(TNNERR_TRAIN_ERROR, "train mode but trainable_layers is empty"); } @@ -96,73 +96,92 @@ namespace optimizer { } Status NetOptimizerInsertLossAndGradient::InsertLossLayer(NetStructure *net_structure) { - if (train_config.loss_func == LOSS_FUNC_DEFAULT) { + if (train_config_.loss_func == LOSS_FUNC_DEFAULT) { // the last layer should output loss auto loss_layer = net_structure->layers.back(); - loss_blob_ = loss_layer->outputs[0]; + loss_blobs_.push_back(loss_layer->outputs[0]); + net_structure->loss_names.push_back(loss_layer->outputs[0]); return TNN_OK; } // target blob - if (train_config.ground_truth_name.empty() || train_config.ground_truth_shape.empty()) { + if (train_config_.ground_truth_names.empty() || train_config_.ground_truth_shapes.empty()) { LOGE( "NetOptimizerInsertLossAndGradient::InsertLossLayer, loss func is %d, please set target name and shape " "to calculate loss\n", - train_config.loss_func); + train_config_.loss_func); return Status(TNNERR_TRAIN_ERROR, "loss layer will be added, but target(ground truth) name and 
shape is empty!"); } - net_structure->inputs_shape_map[train_config.ground_truth_name] = train_config.ground_truth_shape; - net_structure->blobs.insert(train_config.ground_truth_name); - // target layer - std::shared_ptr target_layer = GetTargetLayer(net_structure); - if (target_layer == nullptr || target_layer->outputs.size() <= 0) { + // get target layers + auto target_layers = GetTargetLayers(net_structure); + if (target_layers.empty()) { return Status(TNNERR_TRAIN_ERROR, "get target layer error"); } - - // probability layer - std::shared_ptr prob_layer = GetOrCreateProbability(target_layer); - if (prob_layer == nullptr) { - return Status(TNNERR_TRAIN_ERROR, "get or create prob layer error"); + for (const auto &layer : target_layers) { + if (layer->outputs.size() <= 0) { + return Status(TNNERR_TRAIN_ERROR, "get target layer error"); + } } - if (prob_layer != target_layer) { - auto prob_input = target_layer->outputs[0]; - prob_layer->inputs.push_back(prob_input); - auto prob_output = prob_input + kProbSuffix.at(train_config.loss_func); - prob_layer->outputs.push_back(prob_output); - net_structure->layers.push_back(prob_layer); - net_structure->blobs.insert(prob_output); + if (target_layers.size() != train_config_.ground_truth_names.size() || target_layers.size() != train_config_.ground_truth_shapes.size()) { + return Status(TNNERR_TRAIN_ERROR, "size of target layers should eq size of target names and target shapes"); } - // cross entropy - std::shared_ptr entropy_layer = - CreateCrossEntropy(prob_layer->name + kEntropySuffix.at(train_config.loss_func)); - if (entropy_layer == nullptr) { - return Status(TNNERR_TRAIN_ERROR, "create entropy layer error"); - } else { - auto entropy_input = prob_layer->outputs[0]; - entropy_layer->inputs.push_back(entropy_input); - entropy_layer->inputs.push_back(train_config.ground_truth_name); - auto entropy_output = entropy_input + kEntropySuffix.at(train_config.loss_func); - entropy_layer->outputs.push_back(entropy_output); - net_structure->layers.push_back(entropy_layer); - net_structure->blobs.insert(entropy_output); - 
} + // insert prob, loss and reduce mean layers + for (int i = 0; i < train_config_.ground_truth_names.size(); ++i) { + auto target_layer = target_layers[i]; + auto ground_truth_name = train_config_.ground_truth_names[i]; + auto ground_truth_shape = train_config_.ground_truth_shapes[i]; - // reduce mean - std::shared_ptr reduce_layer = CreateReduceMean(entropy_layer->name + loss_suffix); - if (reduce_layer == nullptr) { - return Status(TNNERR_TRAIN_ERROR, "create reduce mean layer error"); - } else { - auto reduce_input = entropy_layer->outputs[0]; - reduce_layer->inputs.push_back(reduce_input); - auto reduce_output = reduce_input + loss_suffix; - reduce_layer->outputs.push_back(reduce_output); - net_structure->layers.push_back(reduce_layer); - net_structure->blobs.insert(reduce_output); - net_structure->outputs.insert(reduce_output); - loss_blob_ = reduce_output; + net_structure->inputs_shape_map[ground_truth_name] = ground_truth_shape; + net_structure->blobs.insert(ground_truth_name); + + // insert probability layer + std::shared_ptr prob_layer = GetOrCreateProbability(target_layer); + if (prob_layer == nullptr) { + return Status(TNNERR_TRAIN_ERROR, "get or create prob layer error"); + } + if (prob_layer != target_layer) { + auto prob_input = target_layer->outputs[0]; + prob_layer->inputs.push_back(prob_input); + auto prob_output = prob_input + kProbSuffix.at(train_config_.loss_func); + prob_layer->outputs.push_back(prob_output); + net_structure->layers.push_back(prob_layer); + net_structure->blobs.insert(prob_output); + } + + // cross entropy + std::shared_ptr entropy_layer = + CreateCrossEntropy(prob_layer->name + kEntropySuffix.at(train_config_.loss_func)); + if (entropy_layer == nullptr) { + return Status(TNNERR_TRAIN_ERROR, "create entropy layer error"); + } else { + auto entropy_input = prob_layer->outputs[0]; + entropy_layer->inputs.push_back(entropy_input); + entropy_layer->inputs.push_back(ground_truth_name); + auto entropy_output = entropy_input + 
kEntropySuffix.at(train_config_.loss_func); + entropy_layer->outputs.push_back(entropy_output); + net_structure->layers.push_back(entropy_layer); + net_structure->blobs.insert(entropy_output); + } + + // reduce mean + std::shared_ptr reduce_layer = CreateReduceMean(entropy_layer->name + loss_suffix, + ground_truth_shape.size()); + if (reduce_layer == nullptr) { + return Status(TNNERR_TRAIN_ERROR, "create reduce mean layer error"); + } else { + auto reduce_input = entropy_layer->outputs[0]; + reduce_layer->inputs.push_back(reduce_input); + auto reduce_output = reduce_input + loss_suffix; + reduce_layer->outputs.push_back(reduce_output); + net_structure->layers.push_back(reduce_layer); + net_structure->blobs.insert(reduce_output); + net_structure->outputs.insert(reduce_output); + loss_blobs_.push_back(reduce_output); + net_structure->loss_names.push_back(reduce_output); + } } return TNN_OK; @@ -193,7 +212,7 @@ namespace optimizer { std::vector output_grads; for (auto forward_output : forward_layer->outputs) { grad_layer->inputs.push_back(forward_output); - if (forward_output != loss_blob_) { + if (std::find(loss_blobs_.begin(), loss_blobs_.end(), forward_output) == loss_blobs_.end()) { if (blob_to_grad_map.find(forward_output) != blob_to_grad_map.end()) { output_grads.push_back(blob_to_grad_map[forward_output]); } else { @@ -208,13 +227,15 @@ namespace optimizer { output_grads.push_back(loss_grad); net_structure->blobs.insert(loss_grad); net_structure->inputs_shape_map.insert({loss_grad, {1}}); + net_structure->loss_grad_names.push_back(loss_grad); } } grad_layer->inputs.insert(grad_layer->inputs.end(), output_grads.begin(), output_grads.end()); // resource buffer gradients - if (train_config.train_the_whole_model || - (train_config.trainable_layers.find(forward_layer->name) != train_config.trainable_layers.end())) { + if (train_config_.train_the_whole_model || + (train_config_.trainable_layers.find(forward_layer->name) != + train_config_.trainable_layers.end())) { 
const auto &resource_map = net_resource->resource_map; if (resource_map.find(forward_layer->name) != resource_map.end()) { auto grad_param = dynamic_cast(grad_layer->param.get()); @@ -263,29 +284,41 @@ namespace optimizer { return TNN_OK; } - std::shared_ptr NetOptimizerInsertLossAndGradient::GetTargetLayer(NetStructure *net_structure) { - std::shared_ptr last_layer; - if (train_config.target_layer.empty()) { + std::vector> NetOptimizerInsertLossAndGradient::GetTargetLayers(NetStructure *net_structure) { + if (train_config_.target_layers.empty()) { LOGD("NetOptimizerInsertLossAndGradient::InsertLossLayer, target layer is empty, use the last layer\n"); - last_layer = net_structure->layers.back(); - } else { - for (auto layer : net_structure->layers) { - if (layer->name == train_config.target_layer) { - last_layer = layer; - } + return std::vector> {net_structure->layers.back()}; + } + + std::vector> mid_result; + for (auto layer : net_structure->layers) { + if (std::find(train_config_.target_layers.begin(), train_config_.target_layers.end(), + layer->name) != train_config_.target_layers.end()) { + mid_result.push_back(layer); } } - return last_layer; + + // make the order of layers in result the same as train_config_.target_layers + std::vector> result; + for (auto name : train_config_.target_layers) { + for (auto layer : mid_result) { + if (layer->name == name) { + result.push_back(layer); + } + } + } + + return result; } std::shared_ptr NetOptimizerInsertLossAndGradient::GetOrCreateProbability( std::shared_ptr target_layer) { - if (train_config.loss_func == LOSS_FUNC_BINARY_CROSS_ENTROPY) { - if (train_config.auto_add_prob_layer && target_layer->type != LAYER_SIGMOID) { + if (train_config_.loss_func == LOSS_FUNC_BINARY_CROSS_ENTROPY) { + if (train_config_.auto_add_prob_layer && target_layer->type != LAYER_SIGMOID) { std::shared_ptr new_layer = std::shared_ptr(new LayerInfo()); new_layer->type = LAYER_SIGMOID; new_layer->type_str = "Sigmoid"; - new_layer->name = 
target_layer->name + kProbSuffix.at(train_config.loss_func); + new_layer->name = target_layer->name + kProbSuffix.at(train_config_.loss_func); LayerParam *param = new LayerParam(); new_layer->param = std::shared_ptr(param); new_layer->param->type = new_layer->type_str; @@ -298,12 +331,12 @@ namespace optimizer { target_layer->name.c_str()); return target_layer; } - } else if (train_config.loss_func == LOSS_FUNC_CATEGORICAL_CROSS_ENTROPY) { - if (train_config.auto_add_prob_layer && target_layer->type != LAYER_SOFTMAX) { + } else if (train_config_.loss_func == LOSS_FUNC_CATEGORICAL_CROSS_ENTROPY) { + if (train_config_.auto_add_prob_layer && target_layer->type != LAYER_SOFTMAX) { std::shared_ptr new_layer = std::shared_ptr(new LayerInfo()); new_layer->type = LAYER_SOFTMAX; new_layer->type_str = "Softmax"; - new_layer->name = target_layer->name + kProbSuffix.at(train_config.loss_func); + new_layer->name = target_layer->name + kProbSuffix.at(train_config_.loss_func); SoftmaxLayerParam *param = new SoftmaxLayerParam(); new_layer->param = std::shared_ptr(param); new_layer->param->type = new_layer->type_str; @@ -326,10 +359,10 @@ namespace optimizer { std::shared_ptr NetOptimizerInsertLossAndGradient::CreateCrossEntropy(const std::string &name) { std::shared_ptr new_layer = std::shared_ptr(new LayerInfo()); - if (train_config.loss_func == LOSS_FUNC_BINARY_CROSS_ENTROPY) { + if (train_config_.loss_func == LOSS_FUNC_BINARY_CROSS_ENTROPY) { new_layer->type = LAYER_BINARY_CROSSENTROPY; new_layer->type_str = "BinaryCrossEntropy"; - } else if (train_config.loss_func == LOSS_FUNC_CATEGORICAL_CROSS_ENTROPY) { + } else if (train_config_.loss_func == LOSS_FUNC_CATEGORICAL_CROSS_ENTROPY) { new_layer->type = LAYER_CATEGORICAL_CROSSENTROPY; new_layer->type_str = "CategoricalCrossEntropy"; } else { @@ -344,7 +377,7 @@ namespace optimizer { return new_layer; } - std::shared_ptr NetOptimizerInsertLossAndGradient::CreateReduceMean(const std::string &name) { + std::shared_ptr 
NetOptimizerInsertLossAndGradient::CreateReduceMean(const std::string &name, int ndims) { std::shared_ptr new_layer = std::shared_ptr(new LayerInfo()); new_layer->type = LAYER_REDUCE_MEAN; new_layer->type_str = "ReduceMean"; @@ -353,7 +386,7 @@ namespace optimizer { new_layer->param = std::shared_ptr(param); new_layer->param->type = new_layer->type_str; new_layer->param->name = new_layer->name; - for (int i = 0; i < train_config.ground_truth_shape.size(); ++i) { + for (int i = 0; i < ndims; ++i) { param->axis.push_back(i); } return new_layer; @@ -383,8 +416,8 @@ namespace optimizer { new_layer->param = std::shared_ptr(param); new_layer->param->type = new_layer->type_str; new_layer->param->name = new_layer->name; - param->type = train_config.solver_type; - param->learning_rate = train_config.solver_params.learning_rate; + param->type = train_config_.solver_type; + param->learning_rate = train_config_.solver_params.learning_rate; return new_layer; } @@ -396,7 +429,7 @@ namespace optimizer { layer_names_set.insert(ly->name); } - for (auto name : train_config.trainable_layers) { + for (auto name : train_config_.trainable_layers) { if (layer_names_set.find(name) == layer_names_set.end()) { LOGE("NetOptimizerInsertLossAndGradient::GetNeedGradLayers, specified trainable layer: %s not found.\n", name.c_str()); return Status(TNNERR_TRAIN_ERROR, "specified tranable layer not found."); @@ -411,8 +444,8 @@ namespace optimizer { } for (auto &layer : structure->layers) { - if (train_config.train_the_whole_model || - (train_config.trainable_layers.find(layer->name) != train_config.trainable_layers.end())) { + if (train_config_.train_the_whole_model || + (train_config_.trainable_layers.find(layer->name) != train_config_.trainable_layers.end())) { need_grad_layers.insert(layer->name); LOGD("Layer need to calculate grad: %s\n", layer->name.c_str()); continue; @@ -481,7 +514,7 @@ namespace optimizer { } for (auto &layer : structure->layers) { - DeepVisit(layer.get(), 
train_config.trainable_layers, blob_to_layer, need_grad_layers, + DeepVisit(layer.get(), train_config_.trainable_layers, blob_to_layer, need_grad_layers, structure->inputs_shape_map); } diff --git a/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.h b/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.h index bf1ae529b..0c33d3a75 100644 --- a/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.h +++ b/source/tnn/train/optimizer/net_optimizer_insert_loss_and_gradient.h @@ -37,10 +37,10 @@ namespace optimizer { private: Status InsertLossLayer(NetStructure* net_structure); - std::shared_ptr GetTargetLayer(NetStructure* net_structure); + std::vector> GetTargetLayers(NetStructure *net_structure); std::shared_ptr GetOrCreateProbability(std::shared_ptr last_layer); std::shared_ptr CreateCrossEntropy(const std::string& name); - std::shared_ptr CreateReduceMean(const std::string& name); + std::shared_ptr CreateReduceMean(const std::string& name, int ndims); Status InsertGradientLayers(NetStructure* net_structure, NetResource* net_resource); Status GetNeedGradLayers(NetStructure* net_structure, std::set& need_grad_layers); @@ -49,9 +49,9 @@ namespace optimizer { Status InsertGradientUpdateLayer(NetStructure* net_structure); std::shared_ptr CreateSolver(const std::string& name); - TrainConfig train_config; + TrainConfig train_config_; - std::string loss_blob_; + std::vector loss_blobs_; std::vector resource_grads_; }; diff --git a/test/flags.h b/test/flags.h index 7f44a0e01..be9c66fba 100644 --- a/test/flags.h +++ b/test/flags.h @@ -117,11 +117,11 @@ static const char train_mode_message[] = "train mode: PREDICT, TRAIN, default PR DECLARE_string(lf); static const char loss_function_message[] = "loss function: BCE, CCE, default treat last layer as loss layer"; DECLARE_string(tl); -static const char target_layer_message[] = "target layer to calculate loss: default use the last layer"; +static const char target_layer_message[] = "list 
of target layers to calculate loss: default use the last layer"; DECLARE_bool(ap); static const char auto_add_probability_layer[] = "add sigmoid or softmax before calculating loss: default true"; DECLARE_string(gts); -static const char ground_truth_shape_message[] = "ground truth and its shape: name[n,c,h,w]"; +static const char ground_truth_shape_message[] = "list of ground truths and their shapes: name1[n1,c1,h1,w1] name2[n2,c2,h2,w2] ..."; DECLARE_string(st); static const char solver_type_message[] = "solver type: SGD, default SGD"; DECLARE_double(lr); diff --git a/test/test.cc b/test/test.cc index 8e2d6be88..fad6dc8f4 100644 --- a/test/test.cc +++ b/test/test.cc @@ -50,6 +50,19 @@ namespace TNN_NS { namespace test { + // Logs the global step and all loss values of one training step as a single line. + void LogFeedbackStep(const TrainingFeedback &feed_back) { + char buf[1024]; // local, not static: no state shared between calls or threads + snprintf(buf, sizeof(buf), "Training step: %d, loss:", feed_back.global_step_value); + std::string output(buf); + for (const auto loss_value : feed_back.loss_values) { + snprintf(buf, sizeof(buf), " %f", loss_value); + output += buf; + } + output += "\n"; + LOGI("%s", output.c_str()); + } + int Run(int argc, char* argv[]) { // parse command line params if (!ParseAndCheckCommandLine(argc, argv)) @@ -114,7 +127,7 @@ namespace test { } TrainingFeedback feed_back; ret = instance->GetTrainingFeedback(feed_back); - LOGI("Training step: %d, loss: %f\n", feed_back.global_step_value, feed_back.loss_value); + LogFeedbackStep(feed_back); } #endif // TNN_TRAIN for(auto element : output_converters_map) { @@ -158,7 +171,7 @@ namespace test { } TrainingFeedback feed_back; ret = instance->GetTrainingFeedback(feed_back); - LOGI("Training step: %d, loss: %f\n", feed_back.global_step_value, feed_back.loss_value); + LogFeedbackStep(feed_back); } #endif // TNN_TRAIN @@ -281,34 +294,55 @@ namespace test { } } - static std::pair> GetInputShape(const std::string message) { - std::string name; - std::vector dims; - if(!message.empty()) { + static std::vector split_message(std::string message, std::string 
delimiter) { + std::vector result; + size_t pos; + while (!delimiter.empty() && (pos = message.find(delimiter)) != std::string::npos) { + if (pos > 0) result.push_back(message.substr(0, pos)); // skip empty tokens from leading/repeated delimiters + message.erase(0, pos + delimiter.length()); + } + if (message.length() != 0) { + result.push_back(message); + } + return result; + } + + static std::vector>> GetInputShapes(const std::string message) { + std::vector>> result; + + auto tokens = split_message(message, " "); + for (const auto &token : tokens) { + std::string name; + std::vector dims; std::string delimiter = "["; - std::ptrdiff_t p1 = 0, p2; - p2 = message.find(delimiter, p1); - name = message.substr(p1, p2 -p1); + size_t p1 = 0, p2 = token.find(delimiter, p1); + if (p2 == std::string::npos) { LOGE("invalid shape \"%s\", expected name[d0,d1,...]\n", token.c_str()); continue; } + name = token.substr(p1, p2 -p1); p1 = p2 + 1; delimiter = ","; while (true) { - p2 = message.find(delimiter, p1); + p2 = token.find(delimiter, p1); if (p2 != std::string::npos) { - dims.push_back(atoi(message.substr(p1, p2 - p1).c_str())); + dims.push_back(atoi(token.substr(p1, p2 - p1).c_str())); p1 = p2 + 1; } else { - dims.push_back(atoi(message.substr(p1, message.length() - 1 - p1).c_str())); + dims.push_back(atoi(token.substr(p1, token.length() - 1 - p1).c_str())); break; } } + result.emplace_back(name, dims); } - return {name, dims}; + + return result; } InputShapesMap GetInputShapesMap() { InputShapesMap input_shape; if(!FLAGS_is.empty()) { - input_shape.insert(GetInputShape(FLAGS_is)); + auto input_shapes = GetInputShapes(FLAGS_is); + for (const auto &item : input_shapes) { + input_shape.insert(item); + } } return input_shape; } @@ -401,14 +435,19 @@ namespace test { } if (train_config.loss_func != LOSS_FUNC_DEFAULT) { + // target layers if (!FLAGS_tl.empty()) { - train_config.target_layer = FLAGS_tl; + train_config.target_layers = split_message(FLAGS_tl, " "); } + // auto add prob layer train_config.auto_add_prob_layer = FLAGS_ap; + // target names and target shapes if (!FLAGS_gts.empty()) { - auto target_and_shape = GetInputShape(FLAGS_gts); - train_config.ground_truth_name = 
target_and_shape.first; - train_config.ground_truth_shape = target_and_shape.second; + auto target_and_shapes = GetInputShapes(FLAGS_gts); + for (const auto &ts : target_and_shapes) { + train_config.ground_truth_names.push_back(ts.first); + train_config.ground_truth_shapes.push_back(ts.second); + } } else { LOGE("Loss layer will be created, please provide target for calculating loss!\n"); }