|
7 | 7 | # 1. 使用 PyTorch 定义并保存模型 |
8 | 8 | # ============================================================ |
9 | 9 |
|
10 | | -class TorchMLP(nn.Module): |
11 | | - def __init__(self, in_dim=4, hidden_dim=8, out_dim=2): |
| 10 | +class TorchConvNet(nn.Module): |
| 11 | + def __init__(self, in_ch=3, hidden_ch=8, out_ch=3): |
12 | 12 | super().__init__() |
13 | | - self.fc1 = nn.Linear(in_dim, hidden_dim) |
14 | | - self.fc2 = nn.Linear(hidden_dim, out_dim) |
| 13 | + # 主体网络 |
| 14 | + self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1) |
| 15 | + self.bn1 = nn.BatchNorm2d(hidden_ch) |
| 16 | + self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1) |
| 17 | + self.bn2 = nn.BatchNorm2d(hidden_ch) |
| 18 | + self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1) |
| 19 | + self.relu = nn.ReLU() |
| 20 | + |
| 21 | + # 自定义 Parameter(例如一个可学习缩放因子) |
| 22 | + self.scale = nn.Parameter(torch.ones(1) * 0.5) |
| 23 | + |
| 24 | + # 注册一个 buffer(非参数,例如推理时的固定偏移) |
| 25 | + self.register_buffer("offset", torch.tensor(0.1)) |
15 | 26 |
|
16 | 27 | def forward(self, x): |
17 | | - x = torch.relu(self.fc1(x)) |
18 | | - x = self.fc2(x) |
| 28 | + x = self.relu(self.bn1(self.conv1(x))) |
| 29 | + x = self.relu(self.bn2(self.conv2(x))) |
| 30 | + x = self.conv3(x) |
| 31 | + # 应用自定义参数和 buffer |
| 32 | + x = x * self.scale + self.offset |
19 | 33 | return x |
20 | 34 |
|
21 | | -torch_model = TorchMLP() |
22 | | -torch_state_dict = torch_model.state_dict() |
23 | 35 |
|
24 | | -safetensors.torch.save_file(torch_state_dict, "torch_model.safetensors") |
| 36 | +# ===== 保存 Torch 模型 ===== |
| 37 | +torch_model = TorchConvNet() |
| 38 | +torch_state_dict = torch_model.state_dict() |
| 39 | +safetensors.torch.save_file(torch_state_dict, "torch_convnet_with_param.safetensors") |
25 | 40 |
|
26 | 41 | # ============================================================ |
27 | 42 | # 2. 使用 torch 方式加载并推理 |
28 | 43 | # ============================================================ |
29 | 44 |
|
30 | | -torch_model_infer = TorchMLP() |
31 | | -torch_model_infer.load_state_dict(safetensors.torch.load_file("torch_model.safetensors")) |
| 45 | +torch_model_infer = TorchConvNet() |
| 46 | +torch_model_infer.load_state_dict(safetensors.torch.load_file("torch_convnet_with_param.safetensors")) |
32 | 47 | torch_model_infer.eval() |
33 | 48 |
|
34 | | -input = torch.rand(1, 4) |
| 49 | +input = torch.rand(1, 3, 8, 8) |
35 | 50 | torch_model_out = torch_model_infer(input) |
36 | | -print("Torch 输出:", torch_model_out.detach().numpy()) |
| 51 | +print("Torch 输出:", torch_model_out.detach().numpy().mean()) |
37 | 52 |
|
38 | 53 | # ============================================================ |
39 | 54 | # 3. 使用 InfiniCore.Module 系统加载并推理 |
40 | 55 | # ============================================================ |
41 | 56 |
|
42 | | -# ===== 下面定义一个与 TorchMLP 对应的 InfiniCoreModule类 ===== |
43 | 57 | from python.infinicore.nn.modules import InfiniCoreModule |
44 | 58 |
|
45 | | -class InfiniCoreMLP(InfiniCoreModule): |
46 | | - def __init__(self, in_dim=4, hidden_dim=8, out_dim=2): |
| 59 | +class InfiniCoreConvNet(InfiniCoreModule): |
| 60 | + def __init__(self, in_ch=3, hidden_ch=8, out_ch=3): |
47 | 61 | super().__init__() |
48 | | - self.fc1 = nn.Linear(in_dim, hidden_dim) |
49 | | - self.fc2 = nn.Linear(hidden_dim, out_dim) |
| 62 | + self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1) |
| 63 | + self.bn1 = nn.BatchNorm2d(hidden_ch) |
| 64 | + self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1) |
| 65 | + self.bn2 = nn.BatchNorm2d(hidden_ch) |
| 66 | + self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1) |
| 67 | + self.relu = nn.ReLU() |
| 68 | + |
| 69 | + # 保持与 Torch 模型一致的自定义参数和 buffer |
| 70 | + self.scale = nn.Parameter(torch.ones(1) * 0.5) |
| 71 | + self.register_buffer("offset", torch.tensor(0.1)) |
50 | 72 |
|
51 | 73 | def forward(self, x): |
52 | | - x = torch.relu(self.fc1(x)) |
53 | | - x = self.fc2(x) |
| 74 | + x = self.relu(self.bn1(self.conv1(x))) |
| 75 | + x = self.relu(self.bn2(self.conv2(x))) |
| 76 | + x = self.conv3(x) |
| 77 | + x = x * self.scale + self.offset |
54 | 78 | return x |
55 | 79 |
|
56 | | -# ===== 使用 InfiniCoreMLP 读取 safetensors 并推理 ===== |
57 | | -infinicore_model_infer = InfiniCoreMLP() |
58 | | -infinicore_model_infer.load_state_dict(safetensors.torch.load_file("torch_model.safetensors")) |
59 | | -infinicore_model_out = infinicore_model_infer.forward(input) |
| 80 | +# ===== 使用 InfiniCoreConvNet 读取 safetensors 并推理 ===== |
| 81 | +infinicore_model_infer = InfiniCoreConvNet() |
| 82 | +infinicore_model_infer.load_state_dict(safetensors.torch.load_file("torch_convnet_with_param.safetensors")) |
| 83 | +infinicore_model_infer.eval() |
60 | 84 |
|
61 | | -print("InfiniCore 输出:", infinicore_model_out.detach().numpy()) |
| 85 | +infinicore_model_out = infinicore_model_infer.forward(input) |
| 86 | +print("InfiniCore 输出:", infinicore_model_out.detach().numpy().mean()) |
62 | 87 |
|
63 | 88 | # ============================================================ |
64 | 89 | # 4. 对比结果 |
65 | 90 | # ============================================================ |
66 | 91 |
|
67 | 92 | diff = (infinicore_model_out - torch_model_out).abs().max().item() |
68 | | -print(f"InfiniCoreModule 与 Torch 最大误差: {diff:.6f}") |
| 93 | +print(f"InfiniCoreModule 与 Torch 最大误差: {diff:.8f}") |
69 | 94 | if diff < 1e-6: |
70 | 95 | print("✅ InfiniCoreModule 与 Torch 精度一致!") |
71 | 96 | else: |
|
0 commit comments