Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 86 additions & 16 deletions deepxube/base/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import os
import shutil
import time
import threading


@dataclass
Expand All @@ -40,6 +41,7 @@ class TrainArgs:
:param loss_thresh: Loss threshold for updating.
:param targ_up_searches: If > 0, do a greedy search with updater for minimum given number of searches to test
if target network should be updated. Otherwise, it will be updated automatically.
:param policy_kl: KL divergence when training policy.
:param display: Number of iterations to display progress when training nnet. No display if 0.
:param skip_heur: Skip training of heuristic function
:param skip_policy: Skip training of policy
Expand All @@ -52,6 +54,7 @@ class TrainArgs:
rb: int = 1
loss_thresh: float = np.inf
targ_up_searches: int = 0
policy_kl: float = 0.1
skip_heur: bool = False
skip_policy: bool = False
display: int = 100
Expand Down Expand Up @@ -215,12 +218,18 @@ def __init__(self, nnet: NNet, updater: Up, to_main_q: Queue, from_main_qs: List
db_dtypes: List[np.dtype] = [x[1] for x in shapes_dtypes]
self.db: DataBuffer = DataBuffer(self.train_args.batch_size * self.updater.up_args.get_up_gen_itrs(), db_shapes, db_dtypes)

# async pipeline: second buffer for prefetching next round's data
self._prefetch_db: DataBuffer = DataBuffer(self.train_args.batch_size * self.updater.up_args.get_up_gen_itrs(), db_shapes, db_dtypes)
self._prefetch_thread: Optional[threading.Thread] = None
self._prefetch_done: threading.Event = threading.Event()
self._prefetch_error: Optional[Exception] = None
self._prefetch_itr_init: int = 0

# optimizer and criterion
self.optimizer: Optimizer = optim.Adam(self.nnet.parameters(), lr=self.train_args.lr)
self.train_start_time = time.time()

def update_step(self) -> None:
self.db.clear()
itr_init: int = self.status.itr

# print info
Expand All @@ -233,21 +242,47 @@ def update_step(self) -> None:
print(f"\nGetting Data - {', '.join(start_info_l)}")
times: Times = Times()

# start updater
start_time = time.time()
self.updater.start_update(self.status.step_probs.tolist(), num_gen, self.train_args.batch_size, self.device, self.on_gpu)
times.record_time("up_start", time.time() - start_time)

# do training
self.train_start_time = time.time()
loss: float
if not self.updater.up_args.sync_main:
self._get_update_data(num_gen, times)
self._end_update(itr_init, times)
loss = self._train(times)
else:
if self.updater.up_args.sync_main:
# sync_main path: unchanged
self.db.clear()
start_time = time.time()
self.updater.start_update(self.status.step_probs.tolist(), num_gen, self.train_args.batch_size, self.device, self.on_gpu)
times.record_time("up_start", time.time() - start_time)
self.train_start_time = time.time()
loss = self._train_sync_main(num_gen, times)
self._end_update(itr_init, times)
else:
# async pipeline path
if self._prefetch_thread is not None:
# WARM START: prefetch from previous update_step is running
start_time = time.time()
self._wait_prefetch()
times.record_time("prefetch_wait", time.time() - start_time)
# swap buffers: _prefetch_db has fresh data
self.db, self._prefetch_db = self._prefetch_db, self.db
# end the update that was prefetched
self._end_update(self._prefetch_itr_init, times)
else:
# COLD START: first iteration, no prefetch available
self.db.clear()
start_time = time.time()
self.updater.start_update(self.status.step_probs.tolist(), num_gen, self.train_args.batch_size, self.device, self.on_gpu)
times.record_time("up_start", time.time() - start_time)
self._get_update_data_into(self.db, num_gen, times)
self._end_update(itr_init, times)

# start next round's data generation (overlaps with training below)
if self.status.itr + self.updater.up_args.up_itrs <= self.train_args.max_itrs:
start_time = time.time()
self.updater.start_update(self.status.step_probs.tolist(), num_gen, self.train_args.batch_size, self.device, self.on_gpu)
times.record_time("up_start_next", time.time() - start_time)
self._prefetch_itr_init = self.status.itr
self._start_prefetch(num_gen)

# train on current data (overlaps with prefetch thread)
self.train_start_time = time.time()
loss = self._train(times)

# save nnet
start_time = time.time()
Expand All @@ -273,14 +308,49 @@ def update_step(self) -> None:
print(f"Train - itrs: {self.updater.up_args.up_itrs}, loss: {loss:.2E}, targ_updated: {update_targ}")
print(f"Times - {times.get_time_str()}")

def _get_update_data(self, num_gen: int, times: Times) -> None:
# --- async pipeline helpers ---

def _get_update_data_into(self, db: DataBuffer, num_gen: int, times: Times) -> None:
    """Fill ``db`` with data from the updater until it holds at least ``num_gen`` entries.

    Only the buffer passed as ``db`` is read and written, so the caller decides which
    buffer receives the data.

    :param db: Buffer that receives the data returned by ``self.updater.get_update_data()``.
    :param num_gen: Size ``db`` must reach before this method returns.
    :param times: Timing record; the elapsed wall-clock time is stored under ``"up_data"``.
    """
    start_time = time.time()
    while db.size() < num_gen:
        # each call yields a list of items; every item is added to the buffer on its own
        data_l: List[List[NDArray]] = self.updater.get_update_data()
        for data in data_l:
            db.add(data)
    times.record_time("up_data", time.time() - start_time)

def _prefetch_data(self, num_gen: int) -> None:
    """Fill ``self._prefetch_db`` from the updater until it holds at least ``num_gen`` entries.

    Used as the target of the thread created in ``_start_prefetch``. Exceptions are not
    propagated from here: they are stored in ``self._prefetch_error`` so that
    ``_wait_prefetch`` can re-raise them. ``self._prefetch_done`` is set on every exit path.

    :param num_gen: Size ``self._prefetch_db`` must reach before the loop stops.
    """
    try:
        while self._prefetch_db.size() < num_gen:
            for item in self.updater.get_update_data():
                self._prefetch_db.add(item)
    except Exception as err:
        # keep the failure for the thread that later waits on this prefetch
        self._prefetch_error = err
    finally:
        self._prefetch_done.set()

def _start_prefetch(self, num_gen: int) -> None:
    """Reset the prefetch state and start a daemon thread running ``_prefetch_data``.

    :param num_gen: Passed through to ``_prefetch_data`` as the size the prefetch buffer must reach.
    """
    # drop whatever the previous round left behind before the worker starts writing
    self._prefetch_db.clear()
    self._prefetch_done.clear()
    self._prefetch_error = None

    worker = threading.Thread(target=self._prefetch_data, args=(num_gen,), daemon=True)
    self._prefetch_thread = worker
    worker.start()

def _wait_prefetch(self) -> None:
    """Wait for the active prefetch thread, if any, and surface its failure.

    Joins ``self._prefetch_thread`` when one is set and resets it to ``None``. If an
    exception was stored in ``self._prefetch_error``, it is cleared and re-raised here,
    so a failure is reported exactly once and a later call with no active thread does
    not raise the same stale exception again.

    :raises Exception: The exception stored in ``self._prefetch_error``, if any.
    """
    worker = self._prefetch_thread
    if worker is not None:
        worker.join()
        self._prefetch_thread = None

    # consume the stored error before raising so repeated waits stay quiet
    err, self._prefetch_error = self._prefetch_error, None
    if err is not None:
        raise err

def cleanup_prefetch(self) -> None:
    """Finish a prefetch that is still pending; do nothing when none is active.

    When ``self._prefetch_thread`` is set, waits for it through ``_wait_prefetch`` and
    then calls ``self.updater.end_update()``.

    NOTE(review): if ``_wait_prefetch`` raises, ``end_update`` is not reached — confirm
    that this is the intended behavior for a failed prefetch.
    """
    if self._prefetch_thread is None:
        return
    self._wait_prefetch()
    self.updater.end_update()

def _get_update_data(self, num_gen: int, times: Times) -> None:
    """Fill ``self.db`` by delegating to ``_get_update_data_into``.

    :param num_gen: Forwarded unchanged to ``_get_update_data_into``.
    :param times: Forwarded unchanged to ``_get_update_data_into``.
    """
    target = self.db
    self._get_update_data_into(target, num_gen, times)

def _train(self, times: Times) -> float:
loss: float = np.inf
first_itr_in_update: bool = True
Expand Down
5 changes: 5 additions & 0 deletions deepxube/trainers/utils/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ def train(domain: Domain, heur_nnet_par: Optional[HeurNNetPar], update_heur: Opt
if (test_args is not None) and up_itr_performed:
test(domain, heur_nnet_par, train_heur, policy_nnet_par, train_policy, test_args, writer, curr_itr)

# clean up any active prefetch before stopping procs
for train_obj in [train_heur, train_policy]:
if train_obj is not None:
train_obj.cleanup_prefetch()

# stop procs
for updater in [update_heur, update_policy]:
if updater is not None:
Expand Down