forked from google/gemma.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor_info_test.cc
More file actions
71 lines (62 loc) · 2.73 KB
/
tensor_info_test.cc
File metadata and controls
71 lines (62 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "gemma/tensor_info.h"

#include <stdio.h>

#include <string>

#include "gtest/gtest.h"
#include "compression/types.h" // SfpStream
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/base.h" // HWY_ASSERT_M
namespace gcpp {
namespace {
// For every model: each tensor visited by `WeightsPtrs::ForEachTensor` must be
// registered in `TensorInfoRegistry`, and a `MatPtr` constructed from the
// registry's extents must report the same name, rows and cols as the tensor.
TEST(TensorInfoRegistryTest, Find) {
  ForEachModel([&](Model model) {
    const ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
    fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
            config.Specifier().c_str());
    const TensorInfoRegistry tensors(config);
    // Each tensor in the model should be known/found.
    WeightsPtrs weights(config);
    weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
      const TensorInfo* info = tensors.Find(t.mat.Name());
      // Aborts on a missing entry, so `info` is non-null below.
      HWY_ASSERT_M(info, t.mat.Name());
      // Build the `MatPtr` from the same `info` that was just checked, rather
      // than performing a second (redundant) registry lookup.
      const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown, ExtentsFromInfo(info));
      EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name();
      EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name();
      EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name();
    });
  });
}
// Gemma 3 LM variants must not request any ViT tensors: their `vit_config`
// stays empty so `WeightsPtrs::ForEachTensor` skips the whole block. Also
// checks that these configs end up with `PromptWrapping::GEMMA_IT` wrapping.
TEST(TensorInfoRegistryTest, LmConfigsHaveNoVit) {
  for (Model model :
       {Model::GEMMA3_4B_LM, Model::GEMMA3_12B_LM, Model::GEMMA3_27B_LM}) {
    const ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
    EXPECT_TRUE(config.vit_config.layer_configs.empty())
        << config.display_name;
    EXPECT_EQ(config.wrapping, PromptWrapping::GEMMA_IT) << config.display_name;
    WeightsPtrs weights(config);
    // Count visited tensors so the per-name checks cannot pass vacuously if
    // ForEachTensor were to visit nothing at all.
    size_t num_tensors = 0;
    weights.ForEachTensor(nullptr, nullptr,
                          [&num_tensors](const TensorArgs& t) {
                            ++num_tensors;
                            const std::string name = t.mat.Name();
                            // Name fragments this test treats as belonging to
                            // vision-related tensors.
                            EXPECT_EQ(name.find("enc_norm_"),
                                      std::string::npos)
                                << name;
                            EXPECT_EQ(name.find("img_"), std::string::npos)
                                << name;
                            EXPECT_EQ(name.find("mm_embed_norm"),
                                      std::string::npos)
                                << name;
                          });
    EXPECT_GT(num_tensors, size_t{0}) << config.display_name;
  }
}
// `FindModel` resolves a specifier by the longest matching model prefix: a
// "gemma3-4b-lm-..." specifier must select the LM variant, while a plain
// "gemma3-4b-..." specifier must still select GEMMA3_4B.
TEST(TensorInfoRegistryTest, FindModelLongestMatch) {
  // The specifier-string constructor of `ModelConfig` goes through FindModel.
  const ModelConfig lm_config("gemma3-4b-lm-sfp-it");
  EXPECT_EQ(lm_config.model, Model::GEMMA3_4B_LM);

  const ModelConfig vlm_config("gemma3-4b-sfp");
  EXPECT_EQ(vlm_config.model, Model::GEMMA3_4B);
}
} // namespace
} // namespace gcpp