Skip to content

Commit 0221956

Browse files
Merge pull request #87 from google:refactor-tidy
PiperOrigin-RevId: 615204427
2 parents a9aa63f + 7224761 commit 0221956

6 files changed

Lines changed: 146 additions & 131 deletions

File tree

BUILD.bazel

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,6 @@ cc_library(
4646
],
4747
)
4848

49-
cc_library(
50-
name = "app",
51-
hdrs = [
52-
"util/app.h",
53-
],
54-
deps = [
55-
":args",
56-
"@hwy//:hwy",
57-
],
58-
)
59-
6049
cc_library(
6150
name = "gemma_lib",
6251
srcs = [
@@ -80,6 +69,18 @@ cc_library(
8069
],
8170
)
8271

72+
cc_library(
73+
name = "app",
74+
hdrs = [
75+
"util/app.h",
76+
],
77+
deps = [
78+
":args",
79+
":gemma_lib",
80+
"@hwy//:hwy",
81+
],
82+
)
83+
8384
cc_binary(
8485
name = "gemma",
8586
srcs = [

examples/hello_world/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ if (BUILD_MODE STREQUAL "local")
3333
# Relative path to gemma.cpp from examples/hello_world/build/
3434
FetchContent_Declare(gemma SOURCE_DIR ../../..)
3535
else()
36-
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
36+
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
3737
endif()
3838
FetchContent_MakeAvailable(gemma)
3939

examples/hello_world/run.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
// copybara:import_next_line:gemma_cpp
2222
#include "util/args.h"
2323
// copybara:end
24+
// copybara:import_next_line:gemma_cpp
25+
#include "util/app.h" // LoaderArgs
26+
// copybara:end
2427
#include "hwy/contrib/thread_pool/thread_pool.h"
2528

2629
std::vector<int> tokenize(

gemma.h

Lines changed: 0 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -71,124 +71,6 @@ struct RuntimeConfig {
7171
int verbosity;
7272
};
7373

74-
struct LoaderArgs : public ArgsBase<LoaderArgs> {
75-
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
76-
77-
static std::string ToLower(const std::string& text) {
78-
std::string result = text;
79-
std::transform(begin(result), end(result), begin(result),
80-
[](unsigned char c) { return std::tolower(c); });
81-
return result;
82-
}
83-
84-
gcpp::Model ModelType() const {
85-
const std::string model_type_lc = ToLower(model_type);
86-
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
87-
return gcpp::Model::GEMMA_2B;
88-
} else {
89-
return gcpp::Model::GEMMA_7B;
90-
}
91-
}
92-
93-
gcpp::ModelTraining ModelTraining() const {
94-
const std::string model_type_lc = ToLower(model_type);
95-
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
96-
return gcpp::ModelTraining::GEMMA_PT;
97-
} else {
98-
return gcpp::ModelTraining::GEMMA_IT;
99-
}
100-
}
101-
102-
// Returns error string or nullptr if OK.
103-
const char* Validate() const {
104-
const std::string model_type_lc = ToLower(model_type);
105-
if (model_type.empty()) {
106-
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
107-
"2b-it, or 7b-it.";
108-
}
109-
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
110-
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
111-
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
112-
"7b-it.";
113-
}
114-
if (tokenizer.path.empty()) {
115-
return "Missing --tokenizer flag, a file for the tokenizer is required.";
116-
}
117-
if (compressed_weights.path.empty()) {
118-
return "Missing --compressed_weights flag, a file for the compressed "
119-
"model.";
120-
}
121-
return nullptr;
122-
}
123-
124-
Path tokenizer;
125-
Path weights; // uncompressed weights file location
126-
Path compressed_weights; // compressed weights file location
127-
std::string model_type;
128-
129-
template <class Visitor>
130-
void ForEach(const Visitor& visitor) {
131-
visitor(tokenizer, "tokenizer", Path(),
132-
"Path name of tokenizer model file.\n Required argument.");
133-
visitor(
134-
compressed_weights, "compressed_weights", Path(),
135-
"Path name of compressed weights file, regenerated from `--weights` "
136-
"file if "
137-
"the compressed weights file does not exist.\n Required argument.");
138-
visitor(model_type, "model", std::string(),
139-
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
140-
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
141-
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
142-
" Required argument.");
143-
visitor(weights, "weights", Path(),
144-
"Path name of model weights (.sbs) file. Only required if "
145-
"compressed_weights file is not present and needs to be "
146-
"regenerated. This parameter is only required for compressing"
147-
"new model weight exports, otherwise it is not needed.");
148-
}
149-
};
150-
151-
struct InferenceArgs : public ArgsBase<InferenceArgs> {
152-
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
153-
154-
size_t max_tokens;
155-
size_t max_generated_tokens;
156-
157-
float temperature;
158-
bool deterministic;
159-
bool multiturn;
160-
161-
// Returns error string or nullptr if OK.
162-
const char* Validate() const {
163-
if (max_tokens > gcpp::kSeqLen) {
164-
return "max_tokens is larger than the maximum sequence length (see "
165-
"configs.h).";
166-
}
167-
if (max_generated_tokens > max_tokens) {
168-
return "Maximum number of generated tokens is larger than the maximum "
169-
"total tokens.";
170-
}
171-
return nullptr;
172-
}
173-
174-
template <class Visitor>
175-
void ForEach(const Visitor& visitor) {
176-
visitor(max_tokens, "max_tokens", size_t{3072},
177-
"Maximum number of tokens in prompt + generation.");
178-
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
179-
"Maximum number of tokens to generate.");
180-
181-
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
182-
visitor(deterministic, "deterministic", false,
183-
"Make top-k sampling deterministic", 2);
184-
visitor(multiturn, "multiturn", false,
185-
"Multiturn mode\n 0 = clear KV cache after every "
186-
"interaction\n 1 = continue KV cache after every interaction\n "
187-
" Default : 0 (conversation "
188-
"resets every turn)");
189-
}
190-
};
191-
19274
struct GemmaInterface;
19375

19476
struct Gemma {

run.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
119119
verbosity](int token, float) {
120120
++abs_pos;
121121
++current_pos;
122-
if (current_pos <= prompt_size) {
122+
if (current_pos < prompt_size) {
123123
std::cerr << "." << std::flush;
124124
} else if (token == gcpp::EOS_ID) {
125125
if (!args.multiturn) {

util/app.h

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,28 @@
1818
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
1919
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
2020

21+
#include <iterator>
2122
#if HWY_OS_LINUX
2223
#include <sched.h>
2324

25+
#include <cctype>
2426
#include <cerrno> // IDE does not recognize errno.h as providing errno.
27+
#include <string>
2528
#endif
2629
#include <stddef.h>
2730
#include <stdio.h>
2831

2932
#include <algorithm> // std::clamp
3033
#include <thread> // NOLINT>
3134

35+
// copybara:import_next_line:gemma_cpp
36+
#include "configs.h"
37+
// copybara:end
38+
39+
// copybara:import_next_line:gemma_cpp
40+
#include "gemma.h"
41+
// copybara:end
42+
3243
// copybara:import_next_line:gemma_cpp
3344
#include "util/args.h"
3445
// copybara:end
@@ -116,6 +127,124 @@ class AppArgs : public ArgsBase<AppArgs> {
116127
}
117128
};
118129

130+
struct LoaderArgs : public ArgsBase<LoaderArgs> {
131+
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
132+
133+
static std::string ToLower(const std::string& text) {
134+
std::string result = text;
135+
std::transform(begin(result), end(result), begin(result),
136+
[](unsigned char c) { return std::tolower(c); });
137+
return result;
138+
}
139+
140+
gcpp::Model ModelType() const {
141+
const std::string model_type_lc = ToLower(model_type);
142+
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
143+
return gcpp::Model::GEMMA_2B;
144+
} else {
145+
return gcpp::Model::GEMMA_7B;
146+
}
147+
}
148+
149+
gcpp::ModelTraining ModelTraining() const {
150+
const std::string model_type_lc = ToLower(model_type);
151+
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
152+
return gcpp::ModelTraining::GEMMA_PT;
153+
} else {
154+
return gcpp::ModelTraining::GEMMA_IT;
155+
}
156+
}
157+
158+
// Returns error string or nullptr if OK.
159+
const char* Validate() const {
160+
const std::string model_type_lc = ToLower(model_type);
161+
if (model_type.empty()) {
162+
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
163+
"2b-it, or 7b-it.";
164+
}
165+
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
166+
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
167+
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
168+
"7b-it.";
169+
}
170+
if (tokenizer.path.empty()) {
171+
return "Missing --tokenizer flag, a file for the tokenizer is required.";
172+
}
173+
if (compressed_weights.path.empty()) {
174+
return "Missing --compressed_weights flag, a file for the compressed "
175+
"model.";
176+
}
177+
return nullptr;
178+
}
179+
180+
Path tokenizer;
181+
Path weights; // uncompressed weights file location
182+
Path compressed_weights; // compressed weights file location
183+
std::string model_type;
184+
185+
template <class Visitor>
186+
void ForEach(const Visitor& visitor) {
187+
visitor(tokenizer, "tokenizer", Path(),
188+
"Path name of tokenizer model file.\n Required argument.");
189+
visitor(
190+
compressed_weights, "compressed_weights", Path(),
191+
"Path name of compressed weights file, regenerated from `--weights` "
192+
"file if "
193+
"the compressed weights file does not exist.\n Required argument.");
194+
visitor(model_type, "model", std::string(),
195+
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
196+
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
197+
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
198+
" Required argument.");
199+
visitor(weights, "weights", Path(),
200+
"Path name of model weights (.sbs) file. Only required if "
201+
"compressed_weights file is not present and needs to be "
202+
"regenerated. This parameter is only required for compressing"
203+
"new model weight exports, otherwise it is not needed.");
204+
}
205+
};
206+
207+
struct InferenceArgs : public ArgsBase<InferenceArgs> {
208+
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
209+
210+
size_t max_tokens;
211+
size_t max_generated_tokens;
212+
213+
float temperature;
214+
bool deterministic;
215+
bool multiturn;
216+
217+
// Returns error string or nullptr if OK.
218+
const char* Validate() const {
219+
if (max_tokens > gcpp::kSeqLen) {
220+
return "max_tokens is larger than the maximum sequence length (see "
221+
"configs.h).";
222+
}
223+
if (max_generated_tokens > max_tokens) {
224+
return "Maximum number of generated tokens is larger than the maximum "
225+
"total tokens.";
226+
}
227+
return nullptr;
228+
}
229+
230+
template <class Visitor>
231+
void ForEach(const Visitor& visitor) {
232+
visitor(max_tokens, "max_tokens", size_t{3072},
233+
"Maximum number of tokens in prompt + generation.");
234+
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
235+
"Maximum number of tokens to generate.");
236+
237+
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
238+
visitor(deterministic, "deterministic", false,
239+
"Make top-k sampling deterministic", 2);
240+
visitor(multiturn, "multiturn", false,
241+
"Multiturn mode\n 0 = clear KV cache after every "
242+
"interaction\n 1 = continue KV cache after every interaction\n "
243+
" Default : 0 (conversation "
244+
"resets every turn)");
245+
}
246+
};
247+
119248
} // namespace gcpp
120249

121250
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

0 commit comments

Comments
 (0)