Skip to content

Commit fe4eb0e

Browse files
committed
Initial VIN implementation
1 parent 40dbbee commit fe4eb0e

64 files changed

Lines changed: 5996 additions & 32 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
13.5 KB
Binary file not shown.
13.5 KB
Binary file not shown.
13.5 KB
Binary file not shown.

src/algorithms/algorithm_manager.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
from algorithms.classic.sample_based.rrt_star import RRT_Star
3131
from algorithms.classic.sample_based.rrt_connect import RRT_Connect
3232
from algorithms.classic.graph_based.wavefront import Wavefront
33-
from algorithms.lstm.LSTM_tile_by_tile import OnlineLSTM
34-
from algorithms.lstm.a_star_waypoint import WayPointNavigation
35-
from algorithms.lstm.combined_online_LSTM import CombinedOnlineLSTM
33+
from algorithms.learning.LSTM_tile_by_tile import OnlineLSTM
34+
from algorithms.learning.a_star_waypoint import WayPointNavigation
35+
from algorithms.learning.combined_online_LSTM import CombinedOnlineLSTM
36+
from algorithms.learning.VIN.VIN import VINAlgorithm
3637

3738
if HAS_OMPL:
3839
from algorithms.classic.sample_based.ompl_rrt import OMPL_RRT
@@ -103,7 +104,8 @@ def _static_init_(cls):
103104
"Dijkstra": (Dijkstra, DijkstraTesting, ([], {})),
104105
"Bug1": (Bug1, BasicTesting, ([], {})),
105106
"Bug2": (Bug2, BasicTesting, ([], {})),
106-
"Potential Field": (PotentialField, BasicTesting, ([], {}))
107+
"Potential Field": (PotentialField, BasicTesting, ([], {})),
108+
"VIN": (VINAlgorithm, BasicTesting, ([], {"load_name": "vin_pretrained"}))
107109
}
108110

109111
if HAS_OMPL:

src/algorithms/classic/testing/way_point_navigation_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from algorithms.basic_testing import BasicTesting
6-
from algorithms.lstm.combined_online_LSTM import CombinedOnlineLSTM
6+
from algorithms.learning.combined_online_LSTM import CombinedOnlineLSTM
77
from simulator.services.debug import DebugLevel
88

99

src/algorithms/configuration/configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from algorithms.algorithm import Algorithm
55
from algorithms.basic_testing import BasicTesting
66
from algorithms.configuration.maps.map import Map
7-
from algorithms.lstm.LSTM_tile_by_tile import BasicLSTMModule
8-
from algorithms.lstm.ML_model import MLModel
7+
from algorithms.learning.LSTM_tile_by_tile import BasicLSTMModule
8+
from algorithms.learning.ML_model import MLModel
99
from simulator.services.debug import DebugLevel
1010

1111
from structures import Point

src/algorithms/lstm/LSTM_CAE_tile_by_tile.py renamed to src/algorithms/learning/LSTM_CAE_tile_by_tile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
from algorithms.basic_testing import BasicTesting
1717
from algorithms.configuration.maps.map import Map
18-
from algorithms.lstm.LSTM_tile_by_tile import BasicLSTMModule, OnlineLSTM
19-
from algorithms.lstm.ML_model import MLModel, EvaluationResults
20-
from algorithms.lstm.map_processing import MapProcessing
18+
from algorithms.learning.LSTM_tile_by_tile import BasicLSTMModule, OnlineLSTM
19+
from algorithms.learning.ML_model import MLModel, EvaluationResults
20+
from algorithms.learning.map_processing import MapProcessing
2121
from simulator.services.services import Services
2222
from utility.constants import DATA_PATH
2323

src/algorithms/lstm/LSTM_CNN_tile_by_tile_obsolete.py renamed to src/algorithms/learning/LSTM_CNN_tile_by_tile_obsolete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import nn
1111
1212
from algorithms.basic_testing import BasicTesting
13-
from algorithms.lstm.online_lstm import BasicLSTMModule, OnlineLSTM
13+
from algorithms.learning.online_lstm import BasicLSTMModule, OnlineLSTM
1414
from simulator.services.services import Services
1515
1616

src/algorithms/lstm/LSTM_tile_by_tile.py renamed to src/algorithms/learning/LSTM_tile_by_tile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from algorithms.basic_testing import BasicTesting
1111
from algorithms.configuration.entities.goal import Goal
1212
from algorithms.configuration.maps.map import Map
13-
from algorithms.lstm.ML_model import MLModel, SingleTensorDataset, PackedDataset
14-
from algorithms.lstm.map_processing import MapProcessing
13+
from algorithms.learning.ML_model import MLModel, SingleTensorDataset, PackedDataset
14+
from algorithms.learning.map_processing import MapProcessing
1515
from simulator.services.services import Services
1616
from simulator.views.map.display.entities_map_display import EntitiesMapDisplay
1717
from simulator.views.map.display.online_lstm_map_display import OnlineLSTMMapDisplay
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence, pack_sequence, PackedSequence
1313
from torch.utils import data
1414
from torch.utils.data import DataLoader, TensorDataset, Dataset, Subset
15-
from algorithms.lstm.map_processing import MapProcessing
15+
from algorithms.learning.map_processing import MapProcessing
1616
from simulator.services.debug import DebugLevel
1717
from simulator.services.services import Services
1818

@@ -154,7 +154,7 @@ class PackedDataset(Dataset):
154154
lengths: torch.Tensor
155155

156156
def __init__(self, seq: List[torch.Tensor]) -> None:
157-
from algorithms.lstm.LSTM_tile_by_tile import BasicLSTMModule
157+
from algorithms.learning.LSTM_tile_by_tile import BasicLSTMModule
158158

159159
ls = list(map(lambda el: el.shape[0], seq))
160160
self.perm = BasicLSTMModule.get_sort_by_lengths_indices(ls)

0 commit comments

Comments
 (0)