Skip to content

Commit 9a07770

Browse files
jan-wassenbergcopybara-github
authored andcommitted
Move MatMulEnv out of Gemma to enable concurrent calls
Also update benchmark_helper config print: add profiler, remove free mem PiperOrigin-RevId: 773301086
1 parent 7630ec0 commit 9a07770

17 files changed

Lines changed: 99 additions & 92 deletions

File tree

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ cc_library(
551551
"//compression:compress",
552552
"@highway//:hwy",
553553
"@highway//:nanobenchmark",
554+
"@highway//:profiler",
554555
],
555556
)
556557

evals/benchmark.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
7575
std::vector<int> prompt_slice(prompt.begin() + pos,
7676
prompt.begin() + pos + num_tokens);
7777
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
78-
float entropy = ComputeCrossEntropy(
79-
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
78+
float entropy =
79+
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
80+
env.MutableEnv(), env.Verbosity());
8081
total_entropy += entropy;
8182
LogSpeedStats(time_start, pos + num_tokens);
8283
std::string text_slice = env.StringFromTokens(prompt_slice);

evals/benchmark_helper.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "util/threading_context.h"
3333
#include "hwy/highway.h"
3434
#include "hwy/per_target.h" // DispatchedTarget
35+
#include "hwy/profiler.h" // PROFILER_ENABLED
3536
#include "hwy/timer.h"
3637

3738
namespace gcpp {
@@ -50,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
5051
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
5152
const InferenceArgs& inference)
5253
: env_(MakeMatMulEnv(threading, inference)),
53-
gemma_(loader, inference, env_) {
54+
gemma_(loader, inference, env_.ctx.pools) {
5455
const ModelConfig& config = gemma_.GetModelConfig();
5556
// Only allocate one for starters because GenerateBatch might not be called.
5657
kv_caches_.push_back(KVCache(config, inference));
@@ -94,7 +95,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
9495
}
9596
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
9697
runtime_config_.batch_stream_token = batch_stream_token;
97-
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
98+
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
9899
timing_info);
99100
return result;
100101
}
@@ -104,7 +105,7 @@ void GemmaEnv::QueryModel(
104105
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
105106
const StreamFunc previous_stream_token = runtime_config_.stream_token;
106107
runtime_config_.stream_token = stream_token;
107-
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
108+
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
108109
timing_info);
109110
runtime_config_.stream_token = previous_stream_token;
110111
}
@@ -146,7 +147,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
146147

147148
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
148149
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
149-
gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
150+
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
150151
return res;
151152
}
152153

@@ -176,7 +177,7 @@ float GemmaEnv::CrossEntropy(const std::string& input) {
176177
std::vector<int> prompt = Tokenize(input);
177178
prompt.insert(prompt.begin(), BOS_ID);
178179
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
179-
MutableKVCache(),
180+
MutableKVCache(), env_,
180181
/*verbosity=*/0) /
181182
static_cast<int>(input.size());
182183
}
@@ -247,13 +248,13 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
247248
"CPU : %s, bind %d\n"
248249
"CPU topology : %s, %s, %s\n"
249250
"Instruction set : %s (%zu bits)\n"
250-
"Compiled config : %s\n"
251-
"Memory MiB : %4zu, %4zu free\n",
251+
"Compiled config : %s, profiler %d\n"
252+
"Memory MiB : %4zu\n",
252253
dt, cpu100, static_cast<int>(threading.bind),
253254
ctx.topology.TopologyString(), ctx.pools.PinString(),
254255
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
255-
ctx.allocator.VectorBytes() * 8, CompiledConfig(),
256-
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB());
256+
ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
257+
ctx.allocator.TotalMiB());
257258
}
258259
}
259260

evals/benchmark_helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class GemmaEnv {
112112
RuntimeConfig& MutableConfig() { return runtime_config_; }
113113
std::mt19937& MutableGen() { return gen_; }
114114
KVCache& MutableKVCache() { return kv_caches_[0]; }
115+
MatMulEnv& MutableEnv() { return env_; }
115116

116117
private:
117118
MatMulEnv env_;

evals/cross_entropy.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ HWY_EXPORT(CallSoftmax);
9999

100100
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
101101
const std::vector<int>& prompt, KVCache& kv_cache,
102-
int verbosity) {
102+
MatMulEnv& env, int verbosity) {
103103
const StreamFunc stream_token = [](int, float) { return true; };
104104

105105
const int vocab_size = gemma.GetModelConfig().vocab_size;
@@ -145,7 +145,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
145145
};
146146
TimingInfo timing_info;
147147

148-
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
148+
gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);
149149

150150
const float scale = 1.0f / std::log(2.0f);
151151
return cross_entropy * scale;

evals/cross_entropy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace gcpp {
2626

2727
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
2828
const std::vector<int>& prompt, KVCache& kv_cache,
29-
int verbosity);
29+
MatMulEnv& env, int verbosity);
3030

3131
} // namespace gcpp
3232

evals/gemma_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ TEST_F(GemmaTest, Multiturn) {
127127
config.wrapping, abs_pos, mutable_prompt);
128128

129129
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
130-
timing_info);
130+
s_env->MutableEnv(), timing_info);
131131
// Note: we do not rewind any <end_of_turn> tokens here. If the model
132132
// produced one and WrapAndTokenize() inserts another one, it will just be
133133
// duplicated.
@@ -139,7 +139,7 @@ TEST_F(GemmaTest, Multiturn) {
139139
// access to the previous turn by asking to reproduce.
140140
response.clear();
141141
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
142-
timing_info);
142+
s_env->MutableEnv(), timing_info);
143143
fprintf(stderr, "decoded: '%s'\n", response.c_str());
144144
bool remembered_turquoise =
145145
response.find("turquoise") != std::string::npos; // NOLINT

evals/run_mmlu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
131131
.stream_token = stream_token,
132132
};
133133
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
134-
env.MutableKVCache(), timing_info);
134+
env.MutableKVCache(), env.MutableEnv(),
135+
timing_info);
135136

136137
std::string output_string = env.StringFromTokens(predicted_token_ids);
137138
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),

examples/hello_world/run.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
5252

5353
// Instantiate model and KV Cache
5454
gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
55-
gcpp::Gemma gemma(loader, inference, env);
55+
gcpp::Gemma gemma(loader, inference, env.ctx.pools);
5656
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
5757
size_t generated = 0;
5858

@@ -93,5 +93,5 @@ int main(int argc, char** argv) {
9393
return !reject_tokens.contains(token);
9494
},
9595
};
96-
gemma.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
96+
gemma.Generate(runtime_config, tokens, 0, kv_cache, env, timing_info);
9797
}

examples/simplified_gemma/gemma.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class SimplifiedGemma {
3636
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
3737
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
3838
: env_(MakeMatMulEnv(threading, inference)),
39-
gemma_(loader, inference, env_),
39+
gemma_(loader, inference, env_.ctx.pools),
4040
kv_cache_(gemma_.GetModelConfig(), inference) {
4141
// Initialize random number generator
4242
std::random_device rd;
@@ -83,7 +83,7 @@ class SimplifiedGemma {
8383
return !reject_tokens.contains(token);
8484
},
8585
};
86-
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
86+
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
8787
}
8888
~SimplifiedGemma() = default;
8989

0 commit comments

Comments
 (0)