@@ -435,62 +435,6 @@ void P2PSync<Dtype>::run(shared_ptr<Solver<Dtype> > root,
435435 }
436436}
437437
438- template <typename Dtype>
439- void P2PSync<Dtype>::divide_batch_size(NetParameter* net) {
440- int solver_count = Caffe::solver_count ();
441- for (int i = 0 ; i < net->layer_size (); ++i) {
442- string m = " Batch size must be divisible by the number of solvers (GPUs)" ;
443- if (net->layer (i).has_data_param ()) {
444- if (net->layer (i).data_param ().has_batch_size ()) {
445- uint32_t total = net->layer (i).data_param ().batch_size ();
446- uint32_t batch = total / solver_count;
447- CHECK (batch * solver_count == total) << m;
448- net->mutable_layer (i)->mutable_data_param ()->set_batch_size (batch);
449-
450- // Also adjust the prefetch count, as it is shared by all solvers
451- uint32_t prefetch = net->layer (i).data_param ().prefetch ();
452- net->mutable_layer (i)->mutable_data_param ()->set_prefetch (
453- prefetch * solver_count);
454- }
455- }
456- if (net->layer (i).has_hdf5_data_param ()) {
457- if (net->layer (i).hdf5_data_param ().has_batch_size ()) {
458- uint32_t total = net->layer (i).hdf5_data_param ().batch_size ();
459- uint32_t batch = total / solver_count;
460- CHECK (batch * solver_count == total) << m;
461- net->mutable_layer (i)->mutable_hdf5_data_param ()->set_batch_size (batch);
462- }
463- }
464- if (net->layer (i).has_image_data_param ()) {
465- if (net->layer (i).image_data_param ().has_batch_size ()) {
466- uint32_t total = net->layer (i).image_data_param ().batch_size ();
467- uint32_t batch = total / solver_count;
468- CHECK (batch * solver_count == total) << m;
469- net->mutable_layer (i)->mutable_image_data_param ()->set_batch_size (
470- batch);
471- }
472- }
473- if (net->layer (i).has_memory_data_param ()) {
474- if (net->layer (i).memory_data_param ().has_batch_size ()) {
475- uint32_t total = net->layer (i).memory_data_param ().batch_size ();
476- uint32_t batch = total / solver_count;
477- CHECK (batch * solver_count == total) << m;
478- net->mutable_layer (i)->mutable_memory_data_param ()->set_batch_size (
479- batch);
480- }
481- }
482- if (net->layer (i).has_window_data_param ()) {
483- if (net->layer (i).window_data_param ().has_batch_size ()) {
484- uint32_t total = net->layer (i).window_data_param ().batch_size ();
485- uint32_t batch = total / solver_count;
486- CHECK (batch * solver_count == total) << m;
487- net->mutable_layer (i)->mutable_window_data_param ()->set_batch_size (
488- batch);
489- }
490- }
491- }
492- }
493-
494438INSTANTIATE_CLASS (Params);
495439INSTANTIATE_CLASS (GPUParams);
496440INSTANTIATE_CLASS (P2PSync);
0 commit comments