|
| 1 | +// Copyright 2024 Google LLC |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +// you may not use this file except in compliance with the License. |
| 6 | +// You may obtain a copy of the License at |
| 7 | +// |
| 8 | +// https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +// |
| 10 | +// Unless required by applicable law or agreed to in writing, software |
| 11 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +// See the License for the specific language governing permissions and |
| 14 | +// limitations under the License. |
| 15 | + |
| 16 | +// Command line tool to create compressed weights. |
| 17 | + |
| 18 | +#include <iostream> |
| 19 | +#include <string> |
| 20 | + |
| 21 | +// copybara:import_next_line:gemma_cpp |
| 22 | +#include "gemma.h" // Gemma |
| 23 | +// copybara:end |
| 24 | +// copybara:import_next_line:gemma_cpp |
| 25 | +#include "util/args.h" |
| 26 | +// copybara:end |
| 27 | + |
| 28 | +namespace gcpp { |
| 29 | + |
| 30 | +struct Args : public ArgsBase<Args> { |
| 31 | + static constexpr size_t kDefaultNumThreads = ~size_t{0}; |
| 32 | + |
| 33 | + void ChooseNumThreads() { |
| 34 | + if (num_threads == kDefaultNumThreads) { |
| 35 | + // This is a rough heuristic, replace with something better in the future. |
| 36 | + num_threads = static_cast<size_t>(std::clamp( |
| 37 | + static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18)); |
| 38 | + } |
| 39 | + } |
| 40 | + |
| 41 | + public: |
| 42 | + Args(int argc, char* argv[]) { |
| 43 | + InitAndParse(argc, argv); |
| 44 | + ChooseNumThreads(); |
| 45 | + } |
| 46 | + |
| 47 | + static std::string ToLower(const std::string& text) { |
| 48 | + std::string result = text; |
| 49 | + std::transform(begin(result), end(result), begin(result), |
| 50 | + [](unsigned char c) { return std::tolower(c); }); |
| 51 | + return result; |
| 52 | + } |
| 53 | + |
| 54 | + gcpp::Model ModelType() const { |
| 55 | + const std::string model_type_lc = ToLower(model_type); |
| 56 | + if (model_type_lc.substr(0, 2) == "2b") { |
| 57 | + return gcpp::Model::GEMMA_2B; |
| 58 | + } else if (model_type_lc.substr(0, 2) == "7b") { |
| 59 | + return gcpp::Model::GEMMA_7B; |
| 60 | + } else { |
| 61 | + HWY_ABORT("Unknown model type %s", model_type_lc.c_str()); |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + // Returns error string or nullptr if OK. |
| 66 | + const char* Validate() const { |
| 67 | + const std::string model_type_lc = ToLower(model_type); |
| 68 | + if (model_type.empty()) { |
| 69 | + return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " |
| 70 | + "2b-it, 7b-it."; |
| 71 | + } |
| 72 | + if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && |
| 73 | + model_type_lc != "2b-it" && model_type_lc != "7b-it") { |
| 74 | + return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it."; |
| 75 | + } |
| 76 | + if (weights.path.empty()) { |
| 77 | + return "Missing --weights flag, a file for the uncompressed model."; |
| 78 | + } |
| 79 | + if (compressed_weights.path.empty()) { |
| 80 | + return "Missing --compressed_weights flag, a file for the compressed " |
| 81 | + "model."; |
| 82 | + } |
| 83 | + if (!weights.exists()) { |
| 84 | + return "Can't open file specified with --weights flag."; |
| 85 | + } |
| 86 | + return nullptr; |
| 87 | + } |
| 88 | + |
| 89 | + Path weights; // uncompressed weights file location |
| 90 | + Path compressed_weights; // compressed weights file location |
| 91 | + std::string model_type; |
| 92 | + size_t num_threads; |
| 93 | + |
| 94 | + template <class Visitor> |
| 95 | + void ForEach(const Visitor& visitor) { |
| 96 | + visitor(weights, "weights", Path(), |
| 97 | + "Path name of model weights (.sbs) file.\n" |
| 98 | + " Required argument."); |
| 99 | + visitor(model_type, "model", std::string(), |
| 100 | + "Model type\n 2b-it = 2B parameters, instruction-tuned\n " |
| 101 | + "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " |
| 102 | + "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " |
| 103 | + " Required argument."); |
| 104 | + visitor(compressed_weights, "compressed_weights", Path(), |
| 105 | + "Path name where compressed weights file will be written.\n" |
| 106 | + " Required argument."); |
| 107 | + visitor(num_threads, "num_threads", |
| 108 | + kDefaultNumThreads, // see ChooseNumThreads |
| 109 | + "Number of threads to use.\n Default = Estimate of the " |
| 110 | + "number of suupported concurrent threads.", |
| 111 | + 2); |
| 112 | + } |
| 113 | +}; |
| 114 | + |
| 115 | +void ShowHelp(gcpp::Args& args) { |
| 116 | + std::cerr |
| 117 | + << "Usage:\n./compress_weights --weights <path to uncompressed weights> " |
| 118 | + " --model <model type> --compressed_weights <output path>\n"; |
| 119 | + std::cerr << "\n*Arguments*\n\n"; |
| 120 | + args.Help(); |
| 121 | + std::cerr << "\n"; |
| 122 | +} |
| 123 | + |
| 124 | +void Run(Args& args) { |
| 125 | + hwy::ThreadPool pool(args.num_threads); |
| 126 | + gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights, |
| 127 | + pool); |
| 128 | +} |
| 129 | + |
| 130 | +} // namespace gcpp |
| 131 | + |
| 132 | +int main(int argc, char** argv) { |
| 133 | + gcpp::Args args(argc, argv); |
| 134 | + |
| 135 | + if (gcpp::HasHelp(argc, argv)) { |
| 136 | + ShowHelp(args); |
| 137 | + return 0; |
| 138 | + } |
| 139 | + |
| 140 | + if (const char* error = args.Validate()) { |
| 141 | + ShowHelp(args); |
| 142 | + HWY_ABORT("\nInvalid args: %s", error); |
| 143 | + } |
| 144 | + |
| 145 | + gcpp::Run(args); |
| 146 | + |
| 147 | + return 0; |
| 148 | +} |
0 commit comments