Skip to content

Commit 4f5785b

Browse files
jan-wassenbergcopybara-github
authored andcommitted
Update instrumentation for new Highway wall-time profiler
Pass the thread index through and use new zone_id. PiperOrigin-RevId: 773344242
1 parent 1665ecc commit 4f5785b

8 files changed

Lines changed: 97 additions & 57 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
2222
set(CMAKE_CXX_STANDARD_REQUIRED ON)
2323
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
2424

25-
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6 EXCLUDE_FROM_ALL)
25+
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901 EXCLUDE_FROM_ALL)
2626
FetchContent_MakeAvailable(highway)
2727

2828
## Note: absl needs to be installed by sentencepiece. This will only happen if

MODULE.bazel

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
1818
# Require a more recent version.
1919
git_override(
2020
module_name = "highway",
21-
commit = "12d9fa908e0c1d3346c298d472584687a24e4ce6",
21+
commit = "01019e979cd098f2ee618f39bb6718f1b4a3d901",
2222
remote = "https://github.com/google/highway",
2323
)
2424

@@ -71,6 +71,7 @@ pip.parse(
7171
requirements_lock = "//compression/python:requirements.txt",
7272
)
7373
use_repo(pip, "compression_deps")
74+
7475
pip.parse(
7576
hub_name = "python_deps",
7677
python_version = "3.11",

examples/hello_world/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
1818
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1919

2020
include(FetchContent)
21-
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
21+
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901)
2222
FetchContent_MakeAvailable(highway)
2323
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
2424
FetchContent_MakeAvailable(sentencepiece)
@@ -32,7 +32,7 @@ if (NOT BUILD_MODE)
3232
endif()
3333
if (BUILD_MODE STREQUAL "local")
3434
# Relative path to gemma.cpp from examples/hello_world/build/
35-
FetchContent_Declare(gemma SOURCE_DIR ../../..)
35+
FetchContent_Declare(gemma SOURCE_DIR ../../..)
3636
else()
3737
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
3838
endif()

examples/simplified_gemma/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
1818
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1919

2020
include(FetchContent)
21-
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
21+
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 01019e979cd098f2ee618f39bb6718f1b4a3d901)
2222
FetchContent_MakeAvailable(highway)
2323
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
2424
FetchContent_MakeAvailable(sentencepiece)
@@ -32,7 +32,7 @@ if (NOT BUILD_MODE)
3232
endif()
3333
if (BUILD_MODE STREQUAL "local")
3434
# Relative path to gemma.cpp from examples/simplified_gemma/build/
35-
FetchContent_Declare(gemma SOURCE_DIR ../../..)
35+
FetchContent_Declare(gemma SOURCE_DIR ../../..)
3636
else()
3737
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
3838
endif()

gemma/attention.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
164164
const LayerWeightsPtrs& layer,
165165
AttentionActivations& activations, QBatch& qbatch,
166166
NestedPools& pools) {
167-
PROFILER_ZONE("Gen.Attention.DotSoftmax");
167+
PROFILER_ZONE("Gen.Attention.DotSoftmax.misc");
168+
static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
169+
PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
170+
168171
const hwy::Divisor div_qbatch(qbatch.Size());
169172
const LayerConfig& layer_config = layer.layer_config;
170173
const size_t qkv_dim = layer_config.qkv_dim;
@@ -186,9 +189,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
186189
ParallelizeOneRange(
187190
tq_ranges, pools.AllPackages(),
188191
[&](const IndexRange& tq_range, const size_t pkg_idx) {
192+
const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
189193
pools.AllClusters(pkg_idx).Run(
190194
tq_range.begin(), tq_range.end(),
191195
[&](const size_t tq_idx, const size_t cluster_idx) {
196+
const HWY_MAYBE_UNUSED size_t cluster_base =
197+
pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
192198
const size_t qi = div_qbatch.Remainder(tq_idx);
193199
const size_t batch_idx = div_qbatch.Divide(tq_idx);
194200
auto& kv_cache = qbatch.KV(qi).kv_cache;
@@ -209,6 +215,11 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
209215
.Run(
210216
0, layer_config.heads,
211217
[&](const size_t head, size_t thread) HWY_ATTR {
218+
#if PROFILER_ENABLED
219+
const hwy::Zone zone(cluster_base + thread,
220+
zone_id_par);
221+
#endif
222+
212223
const size_t head_offset =
213224
(head / kHeadGroups) * qkv_dim * 2;
214225

gemma/weights.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,8 @@ static void DecompressToBF16(MatPtr& mat,
385385

386386
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
387387
const BlobReader& reader, hwy::ThreadPool& pool) {
388-
PROFILER_ZONE("Startup.Weights.ReadBF16");
389-
390-
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
388+
pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
389+
PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16");
391390
const TensorToRead& tensor = tensors[task];
392391
MatPtr& mat = *tensor.mat;
393392

@@ -465,9 +464,9 @@ static std::vector<IOBatch> MakeBatches(
465464
static void ReadBatches(const BlobReader& reader,
466465
const std::vector<IOBatch>& batches,
467466
hwy::ThreadPool& pool) {
468-
PROFILER_ZONE("Startup.Weights.Read");
469467
// >5x speedup from parallel reads when cached.
470-
pool.Run(0, batches.size(), [&](uint64_t i, size_t /*thread*/) {
468+
pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) {
469+
PROFILER_ZONE2(thread, "Startup.Weights.Read");
471470
const IOBatch& batch = batches[i];
472471
const std::string& key = reader.Keys()[batch.KeyIdx()];
473472
const uint64_t bytes_read = batch.Read(reader.file());

ops/matmul-inl.h

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,9 @@ class MMPerPackage {
875875
inner_tasks_(config.InnerTasks()),
876876
out_(config.Out()),
877877
line_bytes_(args.env->ctx.allocator.LineBytes()) {
878+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
878879
MMZone zone;
879-
zone.MaybeEnter("MM.DecompressA", args_);
880+
zone.MaybeEnter(pkg_idx, zone_id, args_);
880881
A_ = DecompressA(A);
881882
}
882883

@@ -914,8 +915,7 @@ class MMPerPackage {
914915
// Single M and K ranges, parallel N. Fills all of C directly.
915916
template <typename TB, typename TC>
916917
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
917-
MMZone zone;
918-
zone.MaybeEnter("MM.NT", args_);
918+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT");
919919
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
920920
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
921921
const IndexRange& range_M = ranges_mc_.Range(0);
@@ -928,7 +928,10 @@ class MMPerPackage {
928928
// Similar to `loop_nc` below, but here we hoisted `A_view`.
929929
args_.env->parallel.ForNP(
930930
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
931-
[&](const IndexRange& range_nc) HWY_ATTR {
931+
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
932+
MMZone zone;
933+
zone.MaybeEnter(worker, zone_id, args_);
934+
932935
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
933936
const StridedViewBF B_storage_view(B_storage, K, B_stride);
934937

@@ -947,8 +950,7 @@ class MMPerPackage {
947950
// Single M range, parallel N, sequential K. Fills all of partial.
948951
template <typename TB, typename TC>
949952
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
950-
MMZone zone;
951-
zone.MaybeEnter("MM.NT_K", args_);
953+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K");
952954
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
953955
const IndexRange& range_mc = ranges_mc_.Range(0);
954956

@@ -975,7 +977,10 @@ class MMPerPackage {
975977

976978
args_.env->parallel.ForNP(
977979
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
978-
[&](const IndexRange& range_nc) HWY_ATTR {
980+
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
981+
MMZone zone;
982+
zone.MaybeEnter(worker, zone_id, args_);
983+
979984
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
980985

981986
// Peel off the first iteration of the kc loop: avoid
@@ -988,14 +993,17 @@ class MMPerPackage {
988993
});
989994
});
990995

991-
MMZone fill_zone;
992996
if (out_ == MMOut::kCopy) {
993-
fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
997+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.Copy");
998+
MMZone fill_zone;
999+
fill_zone.MaybeEnter(0, zone_id, args_);
9941000
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
9951001
} else if (out_ == MMOut::kParM) {
996-
fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
1002+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_K.FillC.ParM");
9971003
args_.env->parallel.ForRangeMC(
998-
range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
1004+
range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR {
1005+
MMZone fill_zone;
1006+
fill_zone.MaybeEnter(worker, zone_id, args_);
9991007
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
10001008
args_, C_rows);
10011009
});
@@ -1008,8 +1016,7 @@ class MMPerPackage {
10081016
// Fills `mc x nc` sections of C directly, in parallel.
10091017
template <typename TB, typename TC>
10101018
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
1011-
MMZone zone;
1012-
zone.MaybeEnter("MM.NT_MT", args_);
1019+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT");
10131020
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
10141021
const IndexRange& range_K = ranges_kc_.Range(0);
10151022
const size_t K = range_K.Num();
@@ -1020,7 +1027,11 @@ class MMPerPackage {
10201027
// except for the profiler strings and `out_tag`.
10211028
args_.env->parallel.ForRangesMC_NC(
10221029
ranges_mc_, ranges_nc_, pkg_idx_,
1023-
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
1030+
[&](const IndexRange& range_mc, const IndexRange& range_nc,
1031+
size_t worker) HWY_ATTR {
1032+
MMZone zone;
1033+
zone.MaybeEnter(worker, zone_id, args_);
1034+
10241035
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
10251036
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
10261037
const StridedViewBF B_storage_view(B_storage, K, B_stride);
@@ -1041,8 +1052,8 @@ class MMPerPackage {
10411052
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
10421053
template <typename TB, typename TC>
10431054
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
1044-
MMZone zone;
1045-
zone.MaybeEnter("MM.NT_MT_K", args_);
1055+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K");
1056+
static const uint32_t fill_zone_id = PROFILER_ADD_ZONE("MM.NT_MT_K.FillC");
10461057
const size_t kc_max = ranges_kc_.TaskSize();
10471058
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
10481059
const size_t B_stride =
@@ -1068,7 +1079,11 @@ class MMPerPackage {
10681079
}; // loop_nc
10691080
args_.env->parallel.ForRangesMC_NC(
10701081
ranges_mc_, ranges_nc_, pkg_idx_,
1071-
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
1082+
[&](const IndexRange& range_mc, const IndexRange& range_nc,
1083+
size_t worker) HWY_ATTR {
1084+
MMZone zone;
1085+
zone.MaybeEnter(worker, zone_id, args_);
1086+
10721087
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
10731088
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
10741089

@@ -1087,7 +1102,7 @@ class MMPerPackage {
10871102
// `kDirect` is only used with `kNT_MT`.
10881103
HWY_DASSERT(out_ == MMOut::kCopy);
10891104
MMZone fill_zone;
1090-
fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
1105+
fill_zone.MaybeEnter(worker, fill_zone_id, args_);
10911106
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
10921107
});
10931108
}
@@ -1139,13 +1154,16 @@ class MMPerPackage {
11391154

11401155
args_.env->parallel.ForNP(
11411156
all_K, multiple_K, inner_tasks, pkg_idx_,
1142-
[&](const IndexRange& range_K) { do_range(all_M, range_K); });
1157+
[&](const IndexRange& range_K, size_t /*worker*/) {
1158+
do_range(all_M, range_K);
1159+
});
11431160
break;
11441161
}
11451162
case MMParA::kM:
1146-
args_.env->parallel.ForRangeMC(all_M, pkg_idx_, [&](size_t row_a) {
1147-
do_range(IndexRange(row_a, row_a + 1), all_K);
1148-
});
1163+
args_.env->parallel.ForRangeMC(
1164+
all_M, pkg_idx_, [&](size_t row_a, size_t /*worker*/) {
1165+
do_range(IndexRange(row_a, row_a + 1), all_K);
1166+
});
11491167
break;
11501168
}
11511169
}
@@ -1261,12 +1279,13 @@ struct MMImpl {
12611279
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
12621280
RowPtrs<TC> C_rows, const MMArgs& args,
12631281
const MMConfig& config) {
1264-
MMZone matmul_zone;
1265-
matmul_zone.MaybeEnter("MM.DoMatMul", args);
1282+
static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul");
12661283

12671284
// Outermost loop: static NUMA-aware partition of B rows across packages.
12681285
args.env->parallel.ForPkg(
12691286
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
1287+
MMZone matmul_zone;
1288+
matmul_zone.MaybeEnter(pkg_idx, zone_id, args);
12701289
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
12711290
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
12721291
});

0 commit comments

Comments
 (0)