Skip to content

Commit deac1c0

Browse files
committed
feat: add eval() in InfiniCoreModule
1 parent b72e6c7 commit deac1c0

4 files changed

Lines changed: 123 additions & 26 deletions

File tree

python/infinicore/nn/modules/module.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
2121

22+
T = TypeVar('T', bound='InfiniCoreModule')
2223

2324
class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
2425
def __repr__(self):
@@ -514,6 +515,77 @@ def load(module, local_state_dict, prefix=''):
514515
self.__class__.__name__, "\n\t".join(error_msgs)))
515516
return _IncompatibleKeys(missing_keys, unexpected_keys)
516517

518+
def children(self) -> Iterator['InfiniCoreModule']:
519+
r"""Returns an iterator over immediate children modules.
520+
521+
Yields:
522+
Module: a child module
523+
"""
524+
for name, module in self.named_children():
525+
yield module
526+
527+
def named_children(self) -> Iterator[Tuple[str, 'InfiniCoreModule']]:
528+
r"""Returns an iterator over immediate children modules, yielding both
529+
the name of the module as well as the module itself.
530+
531+
Yields:
532+
(str, Module): Tuple containing a name and child module
533+
534+
Example::
535+
536+
>>> # xdoctest: +SKIP("undefined vars")
537+
>>> for name, module in model.named_children():
538+
>>> if name in ['conv4', 'conv5']:
539+
>>> print(module)
540+
541+
"""
542+
memo = set()
543+
for name, module in self._modules.items():
544+
if module is not None and module not in memo:
545+
memo.add(module)
546+
yield name, module
547+
548+
549+
def train(self: T, mode: bool = True) -> T:
550+
r"""Sets the module in training mode.
551+
552+
This has any effect only on certain modules. See documentations of
553+
particular modules for details of their behaviors in training/evaluation
554+
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
555+
etc.
556+
557+
Args:
558+
mode (bool): whether to set training mode (``True``) or evaluation
559+
mode (``False``). Default: ``True``.
560+
561+
Returns:
562+
Module: self
563+
"""
564+
if not isinstance(mode, bool):
565+
raise ValueError("training mode is expected to be boolean")
566+
self.training = mode
567+
for module in self.children():
568+
module.train(mode)
569+
return self
570+
571+
def eval(self: T) -> T:
572+
r"""Sets the module in evaluation mode.
573+
574+
This has any effect only on certain modules. See documentations of
575+
particular modules for details of their behaviors in training/evaluation
576+
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
577+
etc.
578+
579+
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
580+
581+
See :ref:`locally-disable-grad-doc` for a comparison between
582+
`.eval()` and several similar mechanisms that may be confused with it.
583+
584+
Returns:
585+
Module: self
586+
"""
587+
return self.train(False)
588+
517589

518590
def to(self, device: torch.device) -> "InfiniCoreModule":
519591
for name, param in self._parameters.items():

test/infinicore/infinicore_nn_test.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,65 +7,90 @@
77
# 1. 使用 PyTorch 定义并保存模型
88
# ============================================================
99

10-
class TorchMLP(nn.Module):
11-
def __init__(self, in_dim=4, hidden_dim=8, out_dim=2):
10+
class TorchConvNet(nn.Module):
11+
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
1212
super().__init__()
13-
self.fc1 = nn.Linear(in_dim, hidden_dim)
14-
self.fc2 = nn.Linear(hidden_dim, out_dim)
13+
# 主体网络
14+
self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
15+
self.bn1 = nn.BatchNorm2d(hidden_ch)
16+
self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1)
17+
self.bn2 = nn.BatchNorm2d(hidden_ch)
18+
self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
19+
self.relu = nn.ReLU()
20+
21+
# 自定义 Parameter(例如一个可学习缩放因子)
22+
self.scale = nn.Parameter(torch.ones(1) * 0.5)
23+
24+
# 注册一个 buffer(非参数,例如推理时的固定偏移)
25+
self.register_buffer("offset", torch.tensor(0.1))
1526

1627
def forward(self, x):
17-
x = torch.relu(self.fc1(x))
18-
x = self.fc2(x)
28+
x = self.relu(self.bn1(self.conv1(x)))
29+
x = self.relu(self.bn2(self.conv2(x)))
30+
x = self.conv3(x)
31+
# 应用自定义参数和 buffer
32+
x = x * self.scale + self.offset
1933
return x
2034

21-
torch_model = TorchMLP()
22-
torch_state_dict = torch_model.state_dict()
2335

24-
safetensors.torch.save_file(torch_state_dict, "torch_model.safetensors")
36+
# ===== 保存 Torch 模型 =====
37+
torch_model = TorchConvNet()
38+
torch_state_dict = torch_model.state_dict()
39+
safetensors.torch.save_file(torch_state_dict, "torch_convnet_with_param.safetensors")
2540

2641
# ============================================================
2742
# 2. 使用 torch 方式加载并推理
2843
# ============================================================
2944

30-
torch_model_infer = TorchMLP()
31-
torch_model_infer.load_state_dict(safetensors.torch.load_file("torch_model.safetensors"))
45+
torch_model_infer = TorchConvNet()
46+
torch_model_infer.load_state_dict(safetensors.torch.load_file("torch_convnet_with_param.safetensors"))
3247
torch_model_infer.eval()
3348

34-
input = torch.rand(1, 4)
49+
input = torch.rand(1, 3, 8, 8)
3550
torch_model_out = torch_model_infer(input)
36-
print("Torch 输出:", torch_model_out.detach().numpy())
51+
print("Torch 输出:", torch_model_out.detach().numpy().mean())
3752

3853
# ============================================================
3954
# 3. 使用 InfiniCore.Module 系统加载并推理
4055
# ============================================================
4156

42-
# ===== 下面定义一个与 TorchMLP 对应的 InfiniCoreModule类 =====
4357
from python.infinicore.nn.modules import InfiniCoreModule
4458

45-
class InfiniCoreMLP(InfiniCoreModule):
46-
def __init__(self, in_dim=4, hidden_dim=8, out_dim=2):
59+
class InfiniCoreConvNet(InfiniCoreModule):
60+
def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
4761
super().__init__()
48-
self.fc1 = nn.Linear(in_dim, hidden_dim)
49-
self.fc2 = nn.Linear(hidden_dim, out_dim)
62+
self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
63+
self.bn1 = nn.BatchNorm2d(hidden_ch)
64+
self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1)
65+
self.bn2 = nn.BatchNorm2d(hidden_ch)
66+
self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
67+
self.relu = nn.ReLU()
68+
69+
# 保持与 Torch 模型一致的自定义参数和 buffer
70+
self.scale = nn.Parameter(torch.ones(1) * 0.5)
71+
self.register_buffer("offset", torch.tensor(0.1))
5072

5173
def forward(self, x):
52-
x = torch.relu(self.fc1(x))
53-
x = self.fc2(x)
74+
x = self.relu(self.bn1(self.conv1(x)))
75+
x = self.relu(self.bn2(self.conv2(x)))
76+
x = self.conv3(x)
77+
x = x * self.scale + self.offset
5478
return x
5579

56-
# ===== 使用 InfiniCoreMLP 读取 safetensors 并推理 =====
57-
infinicore_model_infer = InfiniCoreMLP()
58-
infinicore_model_infer.load_state_dict(safetensors.torch.load_file("torch_model.safetensors"))
59-
infinicore_model_out = infinicore_model_infer.forward(input)
80+
# ===== 使用 InfiniCoreConvNet 读取 safetensors 并推理 =====
81+
infinicore_model_infer = InfiniCoreConvNet()
82+
infinicore_model_infer.load_state_dict(safetensors.torch.load_file("torch_convnet_with_param.safetensors"))
83+
infinicore_model_infer.eval()
6084

61-
print("InfiniCore 输出:", infinicore_model_out.detach().numpy())
85+
infinicore_model_out = infinicore_model_infer.forward(input)
86+
print("InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
6287

6388
# ============================================================
6489
# 4. 对比结果
6590
# ============================================================
6691

6792
diff = (infinicore_model_out - torch_model_out).abs().max().item()
68-
print(f"InfiniCoreModule 与 Torch 最大误差: {diff:.6f}")
93+
print(f"InfiniCoreModule 与 Torch 最大误差: {diff:.8f}")
6994
if diff < 1e-6:
7095
print("✅ InfiniCoreModule 与 Torch 精度一致!")
7196
else:
4.76 KB
Binary file not shown.

torch_model.safetensors

-504 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)