Skip to content

Commit 33028da

Browse files
committed
πŸ”€ [Merge] branch 'SETUP' into TEST
2 parents c0e778f + a33e03b commit 33028da

10 files changed

Lines changed: 64 additions & 55 deletions

File tree

β€Žexamples/lazy.pyβ€Ž

Lines changed: 0 additions & 37 deletions
This file was deleted.

β€Žexamples/notebook_colab.ipynbβ€Ž

Whitespace-only changes.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": []
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": []
16+
}
17+
],
18+
"metadata": {
19+
"language_info": {
20+
"name": "python"
21+
}
22+
},
23+
"nbformat": 4,
24+
"nbformat_minor": 2
25+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": []
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": []
16+
}
17+
],
18+
"metadata": {
19+
"language_info": {
20+
"name": "python"
21+
}
22+
},
23+
"nbformat": 4,
24+
"nbformat_minor": 2
25+
}
Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
sys.path.append(str(project_root))
99

1010
from yolo.config.config import Config
11-
from yolo.model.yolo import get_model
11+
from yolo.model.yolo import create_model
1212
from yolo.tools.data_loader import create_dataloader
1313
from yolo.tools.solver import ModelTester
1414
from yolo.utils.logging_utils import custom_logger, validate_log_directory
@@ -17,15 +17,11 @@
1717
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
1818
def main(cfg: Config):
1919
custom_logger()
20-
save_path = validate_log_directory(cfg, cfg.name)
21-
22-
device = torch.device(cfg.device)
23-
model = get_model(cfg).to(device)
24-
2520
save_path = validate_log_directory(cfg, cfg.name)
2621
dataloader = create_dataloader(cfg)
22+
2723
device = torch.device(cfg.device)
28-
model = get_model(cfg).to(device)
24+
model = create_model(cfg).to(device)
2925

3026
tester = ModelTester(cfg, model, save_path, device)
3127
tester.solve(dataloader)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
sys.path.append(str(project_root))
99

1010
from yolo.config.config import Config
11-
from yolo.model.yolo import get_model
11+
from yolo.model.yolo import create_model
1212
from yolo.tools.data_loader import create_dataloader
1313
from yolo.tools.solver import ModelTrainer
1414
from yolo.utils.logging_utils import custom_logger, validate_log_directory
@@ -21,7 +21,7 @@ def main(cfg: Config):
2121
dataloader = create_dataloader(cfg)
2222
# TODO: get_device or rank, for DDP mode
2323
device = torch.device(cfg.device)
24-
model = get_model(cfg).to(device)
24+
model = create_model(cfg).to(device)
2525

2626
trainer = ModelTrainer(cfg, model, save_path, device)
2727
trainer.solve(dataloader, cfg.task.epoch)

β€Žtests/test_model/test_yolo.pyβ€Ž

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
project_root = Path(__file__).resolve().parent.parent.parent
99
sys.path.append(str(project_root))
1010

11-
from yolo.model.yolo import YOLO, get_model
11+
from yolo.model.yolo import YOLO, create_model
1212

1313
config_path = "../../yolo/config"
1414
config_name = "config"
@@ -24,18 +24,18 @@ def test_build_model():
2424
assert len(model.model) == 38
2525

2626

27-
def test_get_model():
27+
def test_create_model():
2828
with initialize(config_path=config_path, version_base=None):
2929
cfg = compose(config_name=config_name)
3030
cfg.weight = None
31-
model = get_model(cfg)
31+
model = create_model(cfg)
3232
assert isinstance(model, YOLO)
3333

3434

3535
def test_yolo_forward_output_shape():
3636
with initialize(config_path=config_path, version_base=None):
3737
cfg = compose(config_name=config_name)
38-
model = get_model(cfg)
38+
model = create_model(cfg)
3939
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
4040
dummy_input = torch.rand(2, 3, 640, 640)
4141

β€Žyolo/lazy.pyβ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
sys.path.append(str(project_root))
99

1010
from yolo.config.config import Config
11-
from yolo.model.yolo import get_model
11+
from yolo.model.yolo import create_model
1212
from yolo.tools.data_loader import create_dataloader
1313
from yolo.tools.solver import ModelTester, ModelTrainer
1414
from yolo.utils.logging_utils import custom_logger, validate_log_directory
@@ -20,7 +20,7 @@ def main(cfg: Config):
2020
save_path = validate_log_directory(cfg, cfg.name)
2121
dataloader = create_dataloader(cfg)
2222
device = torch.device(cfg.device)
23-
model = get_model(cfg).to(device)
23+
model = create_model(cfg).to(device)
2424

2525
if cfg.task.task == "train":
2626
trainer = ModelTrainer(cfg, model, save_path, device)

β€Žyolo/model/yolo.pyβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Di
116116
raise ValueError(f"Unsupported layer type: {layer_type}")
117117

118118

119-
def get_model(cfg: Config) -> YOLO:
119+
def create_model(cfg: Config) -> YOLO:
120120
"""Constructs and returns a model from a Dictionary configuration file.
121121
122122
Args:

β€Žyolo/tools/drawer.pyβ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def draw_model(*, model_cfg=None, model=None, v7_base=False):
5757
from graphviz import Digraph
5858

5959
if model_cfg:
60-
from yolo.model.yolo import get_model
60+
from yolo.model.yolo import create_model
6161

62-
model = get_model(model_cfg)
62+
model = create_model(model_cfg)
6363
elif model is None:
6464
raise ValueError("Drawing Object is None")
6565

0 commit comments

Comments
Β (0)