Skip to content

Commit 9ae3eb5

Browse files
committed
🔥 [Remove] layer_helper, same func as module_helper
1 parent 7d976be commit 9ae3eb5

3 files changed

Lines changed: 17 additions & 18 deletions

File tree

‎yolo/model/yolo.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from omegaconf import ListConfig, OmegaConf
66

77
from yolo.config.config import Config, Model, YOLOLayer
8-
from yolo.tools.layer_helper import get_layer_map
98
from yolo.tools.log_helper import log_model
9+
from yolo.tools.module_helper import get_layer_map
1010
from yolo.utils.drawer import draw_model
1111

1212

‎yolo/tools/layer_helper.py‎

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,3 @@
33
import torch.nn as nn
44

55
from yolo.model import module
6-
7-
8-
def auto_pad():
9-
raise NotImplementedError
10-
11-
12-
def get_layer_map():
13-
"""
14-
Dynamically generates a dictionary mapping class names to classes,
15-
filtering to include only those that are subclasses of nn.Module,
16-
ensuring they are relevant neural network layers.
17-
"""
18-
layer_map = {}
19-
for name, obj in inspect.getmembers(module, inspect.isclass):
20-
if issubclass(obj, nn.Module) and obj is not nn.Module:
21-
layer_map[name] = obj
22-
return layer_map

‎yolo/tools/module_helper.py‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from typing import Tuple, Union
23

34
from torch import Tensor, nn
@@ -18,6 +19,21 @@ def auto_pad(kernel_size: _size_2_t, dilation: _size_2_t = 1, **kwargs) -> Tuple
1819
return (pad_h, pad_w)
1920

2021

22+
def get_layer_map():
23+
"""
24+
Dynamically generates a dictionary mapping class names to classes,
25+
filtering to include only those that are subclasses of nn.Module,
26+
ensuring they are relevant neural network layers.
27+
"""
28+
layer_map = {}
29+
from yolo.model import module
30+
31+
for name, obj in inspect.getmembers(module, inspect.isclass):
32+
if issubclass(obj, nn.Module) and obj is not nn.Module:
33+
layer_map[name] = obj
34+
return layer_map
35+
36+
2137
def get_activation(activation: str) -> nn.Module:
2238
"""
2339
Retrieves an activation function from the PyTorch nn module based on its name, case-insensitively.

0 commit comments

Comments
 (0)