Skip to content

Commit ede337f

Browse files
Merge pull request #125 from szabadka:gemma1
PiperOrigin-RevId: 621549709
2 parents 93a6489 + b670d43 commit ede337f

6 files changed

Lines changed: 252 additions & 8 deletions

File tree

BUILD.bazel

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,20 @@ cc_binary(
115115
"@hwy//:thread_pool",
116116
],
117117
)
118+
119+
cc_binary(
120+
name = "compress_weights",
121+
srcs = [
122+
"compress_weights.cc",
123+
],
124+
deps = [
125+
":args",
126+
":gemma_lib",
127+
# "//base",
128+
"//compression:compress",
129+
"@hwy//:hwy",
130+
"@hwy//:nanobenchmark",
131+
"@hwy//:profiler",
132+
"@hwy//:thread_pool",
133+
],
134+
)

CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ set_target_properties(libgemma PROPERTIES PREFIX "")
8080
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
8181
target_include_directories(libgemma PUBLIC ./)
8282
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
83-
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
83+
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
8484
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
8585
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
8686

@@ -115,3 +115,8 @@ foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
115115
gtest_discover_tests(${TESTNAME})
116116
endforeach ()
117117
endif() # GEMMA_ENABLE_TESTS
118+
119+
## Tools
120+
121+
add_executable(compress_weights compress_weights.cc)
122+
target_link_libraries(compress_weights libgemma hwy hwy_contrib)

compress_weights.cc

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// Copyright 2024 Google LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// https://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
// Command line tool to create compressed weights.
17+
18+
#include <iostream>
19+
#include <string>
20+
21+
// copybara:import_next_line:gemma_cpp
22+
#include "gemma.h" // Gemma
23+
// copybara:end
24+
// copybara:import_next_line:gemma_cpp
25+
#include "util/args.h"
26+
// copybara:end
27+
28+
namespace gcpp {
29+
30+
struct Args : public ArgsBase<Args> {
31+
static constexpr size_t kDefaultNumThreads = ~size_t{0};
32+
33+
void ChooseNumThreads() {
34+
if (num_threads == kDefaultNumThreads) {
35+
// This is a rough heuristic, replace with something better in the future.
36+
num_threads = static_cast<size_t>(std::clamp(
37+
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
38+
}
39+
}
40+
41+
public:
42+
Args(int argc, char* argv[]) {
43+
InitAndParse(argc, argv);
44+
ChooseNumThreads();
45+
}
46+
47+
static std::string ToLower(const std::string& text) {
48+
std::string result = text;
49+
std::transform(begin(result), end(result), begin(result),
50+
[](unsigned char c) { return std::tolower(c); });
51+
return result;
52+
}
53+
54+
gcpp::Model ModelType() const {
55+
const std::string model_type_lc = ToLower(model_type);
56+
if (model_type_lc.substr(0, 2) == "2b") {
57+
return gcpp::Model::GEMMA_2B;
58+
} else if (model_type_lc.substr(0, 2) == "7b") {
59+
return gcpp::Model::GEMMA_7B;
60+
} else {
61+
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
62+
}
63+
}
64+
65+
// Returns error string or nullptr if OK.
66+
const char* Validate() const {
67+
const std::string model_type_lc = ToLower(model_type);
68+
if (model_type.empty()) {
69+
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
70+
"2b-it, 7b-it.";
71+
}
72+
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
73+
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
74+
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
75+
}
76+
if (weights.path.empty()) {
77+
return "Missing --weights flag, a file for the uncompressed model.";
78+
}
79+
if (compressed_weights.path.empty()) {
80+
return "Missing --compressed_weights flag, a file for the compressed "
81+
"model.";
82+
}
83+
if (!weights.exists()) {
84+
return "Can't open file specified with --weights flag.";
85+
}
86+
return nullptr;
87+
}
88+
89+
Path weights; // uncompressed weights file location
90+
Path compressed_weights; // compressed weights file location
91+
std::string model_type;
92+
size_t num_threads;
93+
94+
template <class Visitor>
95+
void ForEach(const Visitor& visitor) {
96+
visitor(weights, "weights", Path(),
97+
"Path name of model weights (.sbs) file.\n"
98+
" Required argument.");
99+
visitor(model_type, "model", std::string(),
100+
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
101+
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
102+
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
103+
" Required argument.");
104+
visitor(compressed_weights, "compressed_weights", Path(),
105+
"Path name where compressed weights file will be written.\n"
106+
" Required argument.");
107+
visitor(num_threads, "num_threads",
108+
kDefaultNumThreads, // see ChooseNumThreads
109+
"Number of threads to use.\n Default = Estimate of the "
110+
"number of suupported concurrent threads.",
111+
2);
112+
}
113+
};
114+
115+
void ShowHelp(gcpp::Args& args) {
116+
std::cerr
117+
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
118+
" --model <model type> --compressed_weights <output path>\n";
119+
std::cerr << "\n*Arguments*\n\n";
120+
args.Help();
121+
std::cerr << "\n";
122+
}
123+
124+
void Run(Args& args) {
125+
hwy::ThreadPool pool(args.num_threads);
126+
gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights,
127+
pool);
128+
}
129+
130+
} // namespace gcpp
131+
132+
int main(int argc, char** argv) {
133+
gcpp::Args args(argc, argv);
134+
135+
if (gcpp::HasHelp(argc, argv)) {
136+
ShowHelp(args);
137+
return 0;
138+
}
139+
140+
if (const char* error = args.Validate()) {
141+
ShowHelp(args);
142+
HWY_ABORT("\nInvalid args: %s", error);
143+
}
144+
145+
gcpp::Run(args);
146+
147+
return 0;
148+
}

gemma.cc

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,13 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
116116
checkpoint.path.c_str());
117117
}
118118
bool ok = true;
119+
uint64_t total_size = 0;
119120
ok &= 1 == fread(&(weights->embedder_input_embedding),
120121
sizeof(weights->embedder_input_embedding), 1, fptr);
121122
ok &= 1 == fread(&(weights->final_norm_scale),
122123
sizeof(weights->final_norm_scale), 1, fptr);
124+
total_size += sizeof(weights->embedder_input_embedding) +
125+
sizeof(weights->final_norm_scale);
123126
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
124127
Layer<TConfig>* layer_view = &weights->layers[layer];
125128
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
@@ -134,10 +137,12 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
134137
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
135138
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
136139
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
140+
total_size += sizeof(*layer_view);
137141
}
138142
if (!ok) {
139-
HWY_ABORT("Failed to read from %s - might be a directory, or too small?",
140-
checkpoint.path.c_str());
143+
HWY_ABORT("Failed to read from %s - might be a directory, or too small? "
144+
"expected size: %d kB", checkpoint.path.c_str(),
145+
static_cast<uint32_t>(total_size >> 10));
141146
}
142147
HWY_ASSERT(0 == fclose(fptr));
143148
return weights;
@@ -813,6 +818,47 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
813818
}
814819
}
815820

821+
template <class TConfig>
822+
void CompressWeights(const Path& weights_path,
823+
const Path& compressed_weights_path,
824+
hwy::ThreadPool& pool) {
825+
if (!std::filesystem::exists(weights_path.path)) {
826+
HWY_ABORT("The model weights file '%s' does not exist.",
827+
weights_path.path.c_str());
828+
}
829+
830+
// Allocate compressed weights.
831+
using CWeights = CompressedWeights<TConfig>;
832+
hwy::AlignedFreeUniquePtr<uint8_t[]> c_weights_u8 =
833+
hwy::AllocateAligned<uint8_t>(sizeof(CWeights));
834+
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
835+
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
836+
837+
// Get weights, compress, and store.
838+
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
839+
LoadWeights<TConfig>(weights_path);
840+
Compressor compressor(pool);
841+
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
842+
compressor.WriteAll(pool, compressed_weights_path.path.c_str());
843+
844+
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
845+
}
846+
847+
void CompressWeightsT(gcpp::Model model, const Path& weights,
848+
const Path& compressed_weights,
849+
hwy::ThreadPool& pool) {
850+
switch (model) {
851+
case Model::GEMMA_2B:
852+
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
853+
break;
854+
case Model::GEMMA_7B:
855+
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
856+
break;
857+
default:
858+
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
859+
}
860+
}
861+
816862
} // namespace HWY_NAMESPACE
817863
} // namespace gcpp
818864
HWY_AFTER_NAMESPACE();
@@ -821,6 +867,7 @@ HWY_AFTER_NAMESPACE();
821867
namespace gcpp {
822868

823869
HWY_EXPORT(GetCompressedWeightsT);
870+
HWY_EXPORT(CompressWeightsT);
824871
HWY_EXPORT(Generate2B);
825872
HWY_EXPORT(Generate7B);
826873

@@ -922,5 +969,12 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
922969
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
923970
}
924971

972+
void CompressWeights(gcpp::Model model, const Path& weights,
973+
const Path& compressed_weights,
974+
hwy::ThreadPool& pool) {
975+
HWY_DYNAMIC_DISPATCH(CompressWeightsT)(
976+
model, weights, compressed_weights, pool);
977+
}
978+
925979
} // namespace gcpp
926980
#endif // HWY_ONCE

gemma.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
9898
KVCache& kv_cache, hwy::ThreadPool& pool,
9999
const StreamFunc& stream_token, std::mt19937& gen);
100100

101+
void CompressWeights(gcpp::Model model, const Path& weights,
102+
const Path& compressed_weights,
103+
hwy::ThreadPool& pool);
104+
101105
constexpr int EOS_ID = 1;
102106

103107
} // namespace gcpp

util/args.h

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,43 @@
2525

2626
#include "hwy/base.h" // HWY_ABORT
2727

28+
#if defined(_WIN32)
29+
#include <io.h>
30+
#define F_OK 0
31+
#define access _access
32+
#else
33+
#include <unistd.h>
34+
#endif
35+
2836
namespace gcpp {
2937

3038
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
3139
// strings and supports shortening for display purposes.
3240
struct Path {
41+
Path() {}
42+
explicit Path(const char* p) : path(p) {}
43+
3344
Path& operator=(const char* other) {
3445
path = other;
3546
return *this;
3647
}
3748

3849
std::string Shortened() const {
39-
constexpr size_t max_len = 48;
40-
constexpr size_t cut_point = max_len / 2 - 5;
41-
if (path.size() > max_len) {
42-
return std::string(begin(path), begin(path) + cut_point) + " ... " +
43-
std::string(end(path) - cut_point, end(path));
50+
constexpr size_t kMaxLen = 48;
51+
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
52+
if (path.size() > kMaxLen) {
53+
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
54+
std::string(end(path) - kCutPoint, end(path));
4455
}
4556
if (path.empty()) return "[no path specified]";
4657
return path;
4758
}
4859

60+
// Beware, TOCTOU.
61+
bool exists() const {
62+
return (access(path.c_str(), F_OK) == 0);
63+
}
64+
4965
std::string path;
5066
};
5167

0 commit comments

Comments
 (0)