diff --git a/kernels/portable/cpu/op_var_mean.cpp b/kernels/portable/cpu/op_var_mean.cpp
index 6e72b268610..dbb88f2f32d 100644
--- a/kernels/portable/cpu/op_var_mean.cpp
+++ b/kernels/portable/cpu/op_var_mean.cpp
@@ -35,29 +35,65 @@ void compute_var_mean(
       mean_data[out_ix] = NAN;
     }
   } else if (in.numel() > 0) {
-    MapReduceOverDimListPlan<CTYPE_IN, CTYPE_OUT> plan(in, dim_list);
-    const bool success = parallel_for_each_reduce_over_dim_list_output_index(
-        in, dim_list, var_out, [&](const auto begin, const auto end) {
-          for (const auto out_ix : c10::irange(begin, end)) {
-            // Pass 1: compute sum -> mean
-            CTYPE_OUT sum = plan.execute(
-                [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                out_ix);
-            CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
-            mean_data[out_ix] = mean;
-            // Pass 2: compute sum of squared deviations
-            CTYPE_OUT sum2 = plan.execute(
-                [mean](CTYPE_IN v) {
-                  return (
-                      (static_cast<CTYPE_OUT>(v) - mean) *
-                      (static_cast<CTYPE_OUT>(v) - mean));
-                },
-                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                out_ix);
-            var_data[out_ix] = sum2 / denominator;
-          }
-        });
-    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
+    // Fast path: contiguous tensor, single innermost dim reduction, same dtype.
+    // Rows are then dense in memory, so we can reduce each row with a plain
+    // pointer walk instead of going through the MapReduceOverDimListPlan
+    // index machinery.
+    bool used_fast_path = false;
+    if (dim_list.has_value() && dim_list.value().size() == 1 &&
+        in.scalar_type() == var_out.scalar_type()) {
+      // Normalize a possibly-negative dim (Python-style indexing).
+      const int64_t d = dim_list.value()[0] < 0
+          ? dim_list.value()[0] + in.dim()
+          : dim_list.value()[0];
+      if (d >= 0 && d < in.dim() && d == in.dim() - 1 &&
+          tensor_is_contiguous(in)) {
+        used_fast_path = true;
+        const int64_t reduce_size = in.size(d);
+        const int64_t outer_size = in.numel() / reduce_size;
+        const CTYPE_OUT cnum = static_cast<CTYPE_OUT>(num);
+        const CTYPE_OUT cdenom = static_cast<CTYPE_OUT>(denominator);
+        const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
+        for (int64_t i = 0; i < outer_size; i++) {
+          const CTYPE_IN* row = in_data + i * reduce_size;
+          // Pass 1: compute mean
+          CTYPE_OUT sum = 0;
+          for (int64_t j = 0; j < reduce_size; j++) {
+            sum += static_cast<CTYPE_OUT>(row[j]);
+          }
+          CTYPE_OUT mean = sum / cnum;
+          mean_data[i] = mean;
+          // Pass 2: compute variance
+          CTYPE_OUT sum2 = 0;
+          for (int64_t j = 0; j < reduce_size; j++) {
+            CTYPE_OUT diff = static_cast<CTYPE_OUT>(row[j]) - mean;
+            sum2 += diff * diff;
+          }
+          var_data[i] = sum2 / cdenom;
+        }
+      }
+    }
+    if (!used_fast_path) {
+      MapReduceOverDimListPlan<CTYPE_IN, CTYPE_OUT> plan(in, dim_list);
+      const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+          in, dim_list, var_out, [&](const auto begin, const auto end) {
+            for (const auto out_ix : c10::irange(begin, end)) {
+              // Pass 1: compute sum -> mean
+              CTYPE_OUT sum = plan.execute(
+                  [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                  [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                  out_ix);
+              CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
+              mean_data[out_ix] = mean;
+              // Pass 2: compute sum of squared deviations
+              CTYPE_OUT sum2 = plan.execute(
+                  [mean](CTYPE_IN v) {
+                    return (
+                        (static_cast<CTYPE_OUT>(v) - mean) *
+                        (static_cast<CTYPE_OUT>(v) - mean));
+                  },
+                  [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                  out_ix);
+              var_data[out_ix] = sum2 / denominator;
+            }
+          });
+      ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
+    } // !used_fast_path
   }
 }