3232#include " util/threading_context.h"
3333#include " hwy/highway.h"
3434#include " hwy/per_target.h" // DispatchedTarget
35+ #include " hwy/profiler.h" // PROFILER_ENABLED
3536#include " hwy/timer.h"
3637
3738namespace gcpp {
@@ -50,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
5051GemmaEnv::GemmaEnv (const LoaderArgs& loader, const ThreadingArgs& threading,
5152 const InferenceArgs& inference)
5253 : env_(MakeMatMulEnv(threading, inference)),
53- gemma_ (loader, inference, env_) {
54+ gemma_ (loader, inference, env_.ctx.pools ) {
5455 const ModelConfig& config = gemma_.GetModelConfig ();
5556 // Only allocate one for starters because GenerateBatch might not be called.
5657 kv_caches_.push_back (KVCache (config, inference));
@@ -94,7 +95,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
9495 }
9596 gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
9697 runtime_config_.batch_stream_token = batch_stream_token;
97- gemma_.Generate (runtime_config_, tokens, /* start_pos=*/ 0 , kv_caches_[0 ],
98+ gemma_.Generate (runtime_config_, tokens, /* start_pos=*/ 0 , kv_caches_[0 ], env_,
9899 timing_info);
99100 return result;
100101}
@@ -104,7 +105,7 @@ void GemmaEnv::QueryModel(
104105 gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
105106 const StreamFunc previous_stream_token = runtime_config_.stream_token ;
106107 runtime_config_.stream_token = stream_token;
107- gemma_.Generate (runtime_config_, tokens, /* start_pos=*/ 0 , kv_caches_[0 ],
108+ gemma_.Generate (runtime_config_, tokens, /* start_pos=*/ 0 , kv_caches_[0 ], env_,
108109 timing_info);
109110 runtime_config_.stream_token = previous_stream_token;
110111}
@@ -146,7 +147,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
146147
147148 gcpp::AllQueries all_queries (queries_prompt, kv_caches, prefix_end);
148149 gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity };
149- gemma_.GenerateBatch (runtime_config_, all_queries, timing_info);
150+ gemma_.GenerateBatch (runtime_config_, all_queries, env_, timing_info);
150151 return res;
151152}
152153
@@ -176,7 +177,7 @@ float GemmaEnv::CrossEntropy(const std::string& input) {
176177 std::vector<int > prompt = Tokenize (input);
177178 prompt.insert (prompt.begin (), BOS_ID );
178179 return ComputeCrossEntropy (*GetGemma (), /* max_generated_tokens=*/ 3072 , prompt,
179- MutableKVCache (),
180+ MutableKVCache (), env_,
180181 /* verbosity=*/ 0 ) /
181182 static_cast <int >(input.size ());
182183}
@@ -247,13 +248,13 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
247248 " CPU : %s, bind %d\n "
248249 " CPU topology : %s, %s, %s\n "
249250 " Instruction set : %s (%zu bits)\n "
250- " Compiled config : %s\n "
251- " Memory MiB : %4zu, %4zu free \n " ,
251+ " Compiled config : %s, profiler %d \n "
252+ " Memory MiB : %4zu\n " ,
252253 dt, cpu100, static_cast <int >(threading.bind ),
253254 ctx.topology .TopologyString (), ctx.pools .PinString (),
254255 CacheString ().c_str (), hwy::TargetName (hwy::DispatchedTarget ()),
255- ctx.allocator .VectorBytes () * 8 , CompiledConfig (),
256- ctx.allocator .TotalMiB (), ctx. allocator . FreeMiB () );
256+ ctx.allocator .VectorBytes () * 8 , CompiledConfig (), PROFILER_ENABLED ,
257+ ctx.allocator .TotalMiB ());
257258 }
258259}
259260
0 commit comments