Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ cc_library(
deps = [
":basics",
"//compression:types",
"//io",
"//io:fields",
"@highway//:hwy", # base.h
],
Expand Down
7 changes: 4 additions & 3 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "compression/types.h" // Type
#include "io/fields.h" // IFields
#include "io/io.h" // Path
#include "hwy/base.h"

namespace gcpp {
Expand Down Expand Up @@ -708,7 +709,7 @@ bool ModelConfig::OverwriteWithCanonical() {
return found;
}

Model DeduceModel(size_t layers, int layer_types) {
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
switch (layers) {
case 2:
return Model::GEMMA_TINY;
Expand Down Expand Up @@ -740,8 +741,8 @@ Model DeduceModel(size_t layers, int layer_types) {
return Model::PALIGEMMA2_772M_224;
*/
default:
HWY_WARN("Failed to deduce model type from layer count %zu types %x.",
layers, layer_types);
HWY_WARN("Failed to deduce model type from %s, layer count %zu types %x.",
blob_path.path.c_str(), layers, layer_types);
return Model::UNKNOWN;
}
}
Expand Down
3 changes: 2 additions & 1 deletion gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "compression/types.h" // Type
#include "io/fields.h" // IFieldsVisitor
#include "io/io.h" // Path
#include "util/basics.h"

namespace gcpp {
Expand Down Expand Up @@ -482,7 +483,7 @@ enum DeducedLayerTypes {
};

// layer_types is one or more of `DeducedLayerTypes`.
Model DeduceModel(size_t layers, int layer_types);
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types);

} // namespace gcpp

Expand Down
3 changes: 2 additions & 1 deletion gemma/model_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
// Always deduce so we can verify it against the config we read.
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
const Model deduced_model = DeduceModel(layers, layer_types);
const Model deduced_model =
DeduceModel(reader.blob_path(), layers, layer_types);

ModelConfig config;
// Check first to prevent `CallWithSpan` from printing a warning.
Expand Down
4 changes: 3 additions & 1 deletion io/blob_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ class BlobStore {
}; // BlobStore

BlobReader::BlobReader(const Path& blob_path)
: file_(OpenFileOrAbort(blob_path, "r")), file_bytes_(file_->FileSize()) {
: blob_path_(blob_path),
file_(OpenFileOrAbort(blob_path, "r")),
file_bytes_(file_->FileSize()) {
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());

BlobStore bs(*file_);
Expand Down
3 changes: 3 additions & 0 deletions io/blob_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class BlobReader {
// Aborts on error.
explicit BlobReader(const Path& blob_path);

const Path& blob_path() const { return blob_path_; }

// Non-const version required for File::Map().
File& file() { return *file_; }
const File& file() const { return *file_; }
Expand Down Expand Up @@ -101,6 +103,7 @@ class BlobReader {
}

private:
Path blob_path_;
std::unique_ptr<File> file_;
const uint64_t file_bytes_;

Expand Down
Loading