Skip to content

Commit 799c264

Browse files
jan-wassenberg and copybara-github
authored and committed
Pre-tune thread pool before matmul
Also improve profiler annotations - remove near-zero ones and add more for startup.
PiperOrigin-RevId: 789352414
1 parent 50ee1a3 commit 799c264

11 files changed

Lines changed: 93 additions & 20 deletions

File tree

BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,7 @@ cc_library(
9393
":threading",
9494
":topology",
9595
"@highway//:hwy",
96+
"@highway//:hwy_test_util",
9697
"@highway//:profiler",
9798
],
9899
)
@@ -205,6 +206,7 @@ cc_library(
205206
"//io:blob_store",
206207
"//io:fields",
207208
"@highway//:hwy",
209+
"@highway//:profiler",
208210
"@highway//:thread_pool",
209211
],
210212
)

gemma/gemma.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -609,6 +609,7 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
609609
weights_(model_.Config()),
610610
chat_template_(model_.Tokenizer(), model_.Config().model),
611611
inference_(inference) {
612+
// Negligible CPU time in the ctor body (except ReadFromBlobs).
612613
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
613614
mat_owners_, ctx);
614615
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.

gemma/model_store.cc

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
#include "util/basics.h"
3737
#include "util/threading_context.h"
3838
#include "hwy/base.h"
39+
#include "hwy/profiler.h"
3940

4041
namespace gcpp {
4142

@@ -60,6 +61,8 @@ static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
6061
// Reads it from a blob or from a separate file if pre-2025.
6162
static std::string ReadTokenizer(BlobReader& reader,
6263
const Path& tokenizer_path) {
64+
PROFILER_ZONE("Startup.ReadTokenizer");
65+
6366
std::string tokenizer;
6467
// Check prevents `CallWithSpan` from printing a warning.
6568
if (reader.Find(kTokenizerName)) {
@@ -306,6 +309,8 @@ bool ModelStore::ReadMatPtrs(BlobReader& reader) {
306309
// Check first to prevent `CallWithSpan` from printing a warning.
307310
if (!reader.Find(kMatPtrsName)) return false;
308311

312+
PROFILER_ZONE("Startup.ReadMatPtrs");
313+
309314
// For verifying `config_.weight`.
310315
size_t min_bits = ~size_t{0};
311316
Type weight_type = Type::kUnknown;

gemma/run.cc

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -78,15 +78,14 @@ std::string GetPromptFromStream(std::istream& input, int verbosity,
7878

7979
// Get prompt either from interactive input or command line
8080
std::string GetPrompt(const InferenceArgs& inference) {
81-
PROFILER_ZONE("Gen.input");
8281
// If prompt is provided via command line, use that
83-
if (!inference.prompt.empty()) {
84-
return inference.prompt;
85-
}
82+
if (!inference.prompt.empty()) return inference.prompt;
8683
if (!inference.prompt_file.Empty()) {
84+
PROFILER_ZONE("Gen.ReadPrompt");
8785
return ReadFileToString(inference.prompt_file);
8886
}
8987

88+
PROFILER_ZONE("Gen.input");
9089
return GetPromptFromStream(std::cin, inference.verbosity, inference.eot_line);
9190
}
9291

@@ -299,8 +298,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
299298
int main(int argc, char** argv) {
300299
gcpp::InternalInit();
301300
{
302-
PROFILER_ZONE("Startup.misc");
303-
301+
// Negligible CPU time.
304302
gcpp::LoaderArgs loader(argc, argv);
305303
gcpp::ThreadingArgs threading(argc, argv);
306304
gcpp::InferenceArgs inference(argc, argv);

gemma/tokenizer.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
104104
return impl_->Decode(ids, detokenized);
105105
}
106106

107+
// Negligible CPU time in the ctor body.
107108
GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer,
108109
Model model) {
109110
sot_user_.reserve(3);

gemma/weights.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -344,10 +344,9 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
344344
});
345345
}
346346

347-
// Mode == kMap
347+
// Mode == kMap. CPU time is negligible.
348348
static void MapAll(const std::vector<TensorToRead>& tensors,
349349
const MapPtr& mapped, uint64_t file_bytes) {
350-
PROFILER_ZONE("Startup.Weights.Map");
351350
for (size_t i = 0; i < tensors.size(); ++i) {
352351
// SetPtr does not change the stride, but it is expected to be packed
353352
// because that is what Compress() writes to the file.
@@ -521,6 +520,8 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
521520
const InferenceArgs& inference,
522521
std::vector<MatOwner>& mat_owners,
523522
ThreadingContext& ctx) {
523+
PROFILER_ZONE("Startup.ReadFromBlobs");
524+
524525
// List of tensors to read/map, and where from.
525526
std::vector<TensorToRead> tensors;
526527

io/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,7 @@ cc_library(
7474
"//:basics",
7575
"//:threading_context",
7676
"@highway//:hwy",
77+
"@highway//:profiler",
7778
"@highway//:thread_pool",
7879
],
7980
)

io/blob_store.cc

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#include "hwy/base.h"
3131
#include "hwy/contrib/thread_pool/thread_pool.h"
3232
#include "hwy/detect_compiler_arch.h"
33+
#include "hwy/profiler.h"
3334

3435
namespace gcpp {
3536

@@ -413,10 +414,11 @@ class BlobStore {
413414
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
414415
}; // BlobStore
415416

416-
BlobReader::BlobReader(const Path& blob_path)
417-
: blob_path_(blob_path),
418-
file_(OpenFileOrAbort(blob_path, "r")),
419-
file_bytes_(file_->FileSize()) {
417+
BlobReader::BlobReader(const Path& blob_path) : blob_path_(blob_path) {
418+
PROFILER_ZONE("Startup.BlobReader");
419+
420+
file_ = OpenFileOrAbort(blob_path, "r");
421+
file_bytes_ = file_->FileSize();
420422
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
421423

422424
BlobStore bs(*file_);

io/blob_store.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ class BlobReader {
104104
private:
105105
Path blob_path_;
106106
std::unique_ptr<File> file_;
107-
const uint64_t file_bytes_;
107+
uint64_t file_bytes_; // const after ctor
108108

109109
std::vector<std::string> keys_;
110110
std::vector<BlobRange> ranges_;

paligemma/BUILD.bazel

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -34,16 +34,13 @@ cc_library(
3434
srcs = ["paligemma_helper.cc"],
3535
hdrs = ["paligemma_helper.h"],
3636
deps = [
37-
":image",
3837
"//:allocator",
3938
"//:benchmark_helper",
4039
"//:configs",
4140
"//:gemma_args",
4241
"//:gemma_lib",
4342
"//compression:types",
44-
"//io",
4543
"@highway//:hwy",
46-
"@highway//:profiler",
4744
],
4845
)
4946

@@ -59,15 +56,12 @@ cc_test(
5956
],
6057
deps = [
6158
":paligemma_helper",
62-
"//devtools/build/runtime:get_runfiles_dir",
6359
"@googletest//:gtest_main", # buildcleaner: keep
6460
"//:allocator",
6561
"//:benchmark_helper",
6662
"//:configs",
6763
"//:gemma_lib",
68-
"//compression:types",
6964
"//io",
70-
"@highway//:hwy",
7165
"@highway//:hwy_test_util",
7266
],
7367
)

0 commit comments

Comments
 (0)