Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ cc_library(
"//compression:compress",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
)

Expand Down
5 changes: 3 additions & 2 deletions evals/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
float entropy = ComputeCrossEntropy(
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
float entropy =
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
env.MutableEnv(), env.Verbosity());
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice = env.StringFromTokens(prompt_slice);
Expand Down
19 changes: 10 additions & 9 deletions evals/benchmark_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/per_target.h" // DispatchedTarget
#include "hwy/profiler.h" // PROFILER_ENABLED
#include "hwy/timer.h"

namespace gcpp {
Expand All @@ -50,7 +51,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading, inference)),
gemma_(loader, inference, env_) {
gemma_(loader, inference, env_.ctx.pools) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference));
Expand Down Expand Up @@ -94,7 +95,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
return result;
}
Expand All @@ -104,7 +105,7 @@ void GemmaEnv::QueryModel(
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
const StreamFunc previous_stream_token = runtime_config_.stream_token;
runtime_config_.stream_token = stream_token;
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
timing_info);
runtime_config_.stream_token = previous_stream_token;
}
Expand Down Expand Up @@ -146,7 +147,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(

gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
return res;
}

Expand Down Expand Up @@ -176,7 +177,7 @@ float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
MutableKVCache(), env_,
/*verbosity=*/0) /
static_cast<int>(input.size());
}
Expand Down Expand Up @@ -247,13 +248,13 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
"CPU : %s, bind %d\n"
"CPU topology : %s, %s, %s\n"
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s\n"
"Memory MiB : %4zu, %4zu free\n",
"Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.allocator.VectorBytes() * 8, CompiledConfig(),
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB());
ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
ctx.allocator.TotalMiB());
}
}

Expand Down
1 change: 1 addition & 0 deletions evals/benchmark_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class GemmaEnv {
RuntimeConfig& MutableConfig() { return runtime_config_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_caches_[0]; }
MatMulEnv& MutableEnv() { return env_; }

private:
MatMulEnv env_;
Expand Down
4 changes: 2 additions & 2 deletions evals/cross_entropy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ HWY_EXPORT(CallSoftmax);

float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity) {
MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; };

const int vocab_size = gemma.GetModelConfig().vocab_size;
Expand Down Expand Up @@ -145,7 +145,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
};
TimingInfo timing_info;

gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);

const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale;
Expand Down
2 changes: 1 addition & 1 deletion evals/cross_entropy.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace gcpp {

float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);
MatMulEnv& env, int verbosity);

} // namespace gcpp

Expand Down
4 changes: 2 additions & 2 deletions evals/gemma_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ TEST_F(GemmaTest, Multiturn) {
config.wrapping, abs_pos, mutable_prompt);

model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
s_env->MutableEnv(), timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
Expand All @@ -139,7 +139,7 @@ TEST_F(GemmaTest, Multiturn) {
// access to the previous turn by asking to reproduce.
response.clear();
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
s_env->MutableEnv(), timing_info);
fprintf(stderr, "decoded: '%s'\n", response.c_str());
bool remembered_turquoise =
response.find("turquoise") != std::string::npos; // NOLINT
Expand Down
3 changes: 2 additions & 1 deletion evals/run_mmlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
.stream_token = stream_token,
};
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
env.MutableKVCache(), env.MutableEnv(),
timing_info);

std::string output_string = env.StringFromTokens(predicted_token_ids);
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
Expand Down
4 changes: 2 additions & 2 deletions examples/hello_world/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int main(int argc, char** argv) {

// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
gcpp::Gemma gemma(loader, inference, env);
gcpp::Gemma gemma(loader, inference, env.ctx.pools);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
size_t generated = 0;

Expand Down Expand Up @@ -93,5 +93,5 @@ int main(int argc, char** argv) {
return !reject_tokens.contains(token);
},
};
gemma.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
gemma.Generate(runtime_config, tokens, 0, kv_cache, env, timing_info);
}
4 changes: 2 additions & 2 deletions examples/simplified_gemma/gemma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SimplifiedGemma {
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: env_(MakeMatMulEnv(threading, inference)),
gemma_(loader, inference, env_),
gemma_(loader, inference, env_.ctx.pools),
kv_cache_(gemma_.GetModelConfig(), inference) {
// Initialize random number generator
std::random_device rd;
Expand Down Expand Up @@ -83,7 +83,7 @@ class SimplifiedGemma {
return !reject_tokens.contains(token);
},
};
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
}
~SimplifiedGemma() = default;

Expand Down
7 changes: 4 additions & 3 deletions gemma/bindings/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
threading_args(threading_args),
matmul_env(MakeMatMulEnv(threading_args, inference_args)),
active_conversation_name("default"),
model(loader, inference_args, matmul_env) {
model(loader, inference_args, matmul_env.ctx.pools) {
std::stringstream ss;

LogDebug("Creating initial ConversationData");
Expand Down Expand Up @@ -207,7 +207,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// Pass the populated image object to GenerateImageTokens
model.GenerateImageTokens(runtime_config,
active_conversation->kv_cache->SeqLen(), image,
image_tokens);
image_tokens, matmul_env);
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;

ss.str("");
Expand Down Expand Up @@ -244,7 +244,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,

// Pass the KVCache object by reference from the active conversation
model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
prefix_end, *(active_conversation->kv_cache), timing_info);
prefix_end, *active_conversation->kv_cache, matmul_env,
timing_info);

// prepare for next turn
if (!inference_args.multiturn ||
Expand Down
38 changes: 19 additions & 19 deletions gemma/gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,62 +610,62 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
}

Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
MatMulEnv& env)
: env_(env),
reader_(loader.weights),
NestedPools& pools)
: reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference) {
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
env.ctx.pools.Pool());
pools.Pool());
reader_.CloseFile();
}

Gemma::~Gemma() = default;

void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
BlobWriter writer;
const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
writer, env_.ctx.pools.Pool(), weights_path);
writer, pools.Pool(), weights_path);
}

void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, TimingInfo& timing_info) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
model_.Config(), runtime_config,
weights_, kv_cache, env_, timing_info);
weights_, kv_cache, env, timing_info);

env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
AllQueries& all_queries,
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
weights_, all_queries, env_,
timing_info);
weights_, all_queries, env, timing_info);

env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
size_t seq_len, const Image& image,
ImageTokens& image_tokens) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
ImageTokens& image_tokens,
MatMulEnv& env) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
seq_len, weights_, image,
image_tokens, env_);
image_tokens, env);

env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

} // namespace gcpp
Expand Down
24 changes: 13 additions & 11 deletions gemma/gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,46 +229,48 @@ struct TimingInfo {
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
const InferenceArgs& inference_args);

// After construction, all methods are const and thread-compatible if using
// separate MatMulEnv for each thread.
class Gemma {
public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// `env` must remain valid for the lifetime of this Gemma.
// `pools` are used to parallelize loading.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
MatMulEnv& env);

NestedPools& pools);
~Gemma();

MatMulEnv& Env() const { return env_; }
// TODO: rename to Config()
const ModelConfig& GetModelConfig() const { return model_.Config(); }
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
const ModelWeightsPtrs& Weights() const { return weights_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; }

void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
void Save(const Path& weights_path, NestedPools& pools) const;

// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) const {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
size_t pos, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env,
timing_info);
}
// For prefix-LM style attention, we can pass the end of the prefix.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, size_t prefix_end, KVCache& kv_cache,
TimingInfo& timing_info) const;
MatMulEnv& env, TimingInfo& timing_info) const;

void GenerateBatch(const RuntimeConfig& runtime_config,
AllQueries& all_queries, TimingInfo& timing_info) const;
AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) const;

// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
const Image& image, ImageTokens& image_tokens) const;
const Image& image, ImageTokens& image_tokens,
MatMulEnv& env) const;

private:
MatMulEnv& env_;
BlobReader reader_;
ModelStore model_;
std::vector<MatOwner> mat_owners_;
Expand Down
Loading
Loading