Skip to content

Commit 6ea7f7d

Browse files
committed
Removed batch size division by number of devices
1 parent 9b10508 commit 6ea7f7d

4 files changed

Lines changed: 2 additions & 63 deletions

File tree

include/caffe/parallel.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,6 @@ class P2PSync : public GPUParams<Dtype>, public Solver<Dtype>::Callback,
9595

9696
static void run(shared_ptr<Solver<Dtype> > root, const vector<int>& gpus);
9797

98-
// Divide the batch size by the number of solvers
99-
static void divide_batch_size(NetParameter* net);
100-
10198
protected:
10299
void on_start(Timer* timer, ostringstream* timing);
103100
void on_gradients_ready(Timer* timer, ostringstream* timing);

src/caffe/net.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
4040
// the current NetState.
4141
NetParameter filtered_param;
4242
FilterNet(in_param, &filtered_param);
43-
if (phase_ == TRAIN) {
44-
caffe::P2PSync<Dtype>::divide_batch_size(&filtered_param);
45-
}
4643
if (Caffe::root_solver()) {
4744
LOG(INFO) << "Initializing net from parameters: " << std::endl
4845
<< filtered_param.DebugString();

src/caffe/parallel.cpp

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -435,62 +435,6 @@ void P2PSync<Dtype>::run(shared_ptr<Solver<Dtype> > root,
435435
}
436436
}
437437

438-
template<typename Dtype>
439-
void P2PSync<Dtype>::divide_batch_size(NetParameter* net) {
440-
int solver_count = Caffe::solver_count();
441-
for (int i = 0; i < net->layer_size(); ++i) {
442-
string m = "Batch size must be divisible by the number of solvers (GPUs)";
443-
if (net->layer(i).has_data_param()) {
444-
if (net->layer(i).data_param().has_batch_size()) {
445-
uint32_t total = net->layer(i).data_param().batch_size();
446-
uint32_t batch = total / solver_count;
447-
CHECK(batch * solver_count == total) << m;
448-
net->mutable_layer(i)->mutable_data_param()->set_batch_size(batch);
449-
450-
// Also adjust the prefetch count, as it is shared by all solvers
451-
uint32_t prefetch = net->layer(i).data_param().prefetch();
452-
net->mutable_layer(i)->mutable_data_param()->set_prefetch(
453-
prefetch * solver_count);
454-
}
455-
}
456-
if (net->layer(i).has_hdf5_data_param()) {
457-
if (net->layer(i).hdf5_data_param().has_batch_size()) {
458-
uint32_t total = net->layer(i).hdf5_data_param().batch_size();
459-
uint32_t batch = total / solver_count;
460-
CHECK(batch * solver_count == total) << m;
461-
net->mutable_layer(i)->mutable_hdf5_data_param()->set_batch_size(batch);
462-
}
463-
}
464-
if (net->layer(i).has_image_data_param()) {
465-
if (net->layer(i).image_data_param().has_batch_size()) {
466-
uint32_t total = net->layer(i).image_data_param().batch_size();
467-
uint32_t batch = total / solver_count;
468-
CHECK(batch * solver_count == total) << m;
469-
net->mutable_layer(i)->mutable_image_data_param()->set_batch_size(
470-
batch);
471-
}
472-
}
473-
if (net->layer(i).has_memory_data_param()) {
474-
if (net->layer(i).memory_data_param().has_batch_size()) {
475-
uint32_t total = net->layer(i).memory_data_param().batch_size();
476-
uint32_t batch = total / solver_count;
477-
CHECK(batch * solver_count == total) << m;
478-
net->mutable_layer(i)->mutable_memory_data_param()->set_batch_size(
479-
batch);
480-
}
481-
}
482-
if (net->layer(i).has_window_data_param()) {
483-
if (net->layer(i).window_data_param().has_batch_size()) {
484-
uint32_t total = net->layer(i).window_data_param().batch_size();
485-
uint32_t batch = total / solver_count;
486-
CHECK(batch * solver_count == total) << m;
487-
net->mutable_layer(i)->mutable_window_data_param()->set_batch_size(
488-
batch);
489-
}
490-
}
491-
}
492-
}
493-
494438
INSTANTIATE_CLASS(Params);
495439
INSTANTIATE_CLASS(GPUParams);
496440
INSTANTIATE_CLASS(P2PSync);

tools/caffe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ using std::ostringstream;
2121

2222
DEFINE_string(gpu, "",
2323
"Optional; run in GPU mode on given device IDs separated by ','."
24-
"Use '-gpu all' to run on all available GPUs.");
24+
"Use '-gpu all' to run on all available GPUs. The effective training "
25+
"batch size is multiplied by the number of devices.");
2526
DEFINE_string(solver, "",
2627
"The solver definition protocol buffer text file.");
2728
DEFINE_string(model, "",

0 commit comments

Comments
 (0)