Skip to content

Commit 0935d5a

Browse files
authored
3429 Enhance the scalar write logic of TensorBoardStatsHandler (#3431)
* [DLMED] extract write logic Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent f6a0c87 commit 0935d5a

4 files changed

Lines changed: 49 additions & 12 deletions

File tree

monai/handlers/classification_saver.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,14 @@ def attach(self, engine: Engine) -> None:
9898
if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
9999
engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)
100100

101-
def _started(self, engine: Engine) -> None:
101+
def _started(self, _engine: Engine) -> None:
102+
"""
103+
Initialize internal buffers.
104+
105+
Args:
106+
_engine: Ignite Engine, unused argument.
107+
108+
"""
102109
self._outputs = []
103110
self._filenames = []
104111

@@ -120,12 +127,12 @@ def __call__(self, engine: Engine) -> None:
120127
o = o.detach()
121128
self._outputs.append(o)
122129

123-
def _finalize(self, engine: Engine) -> None:
130+
def _finalize(self, _engine: Engine) -> None:
124131
"""
125132
All gather classification results from ranks and save to CSV file.
126133
127134
Args:
128-
engine: Ignite Engine, it can be a trainer, validator or evaluator.
135+
_engine: Ignite Engine, unused argument.
129136
"""
130137
ws = idist.get_world_size()
131138
if self.save_rank >= ws:

monai/handlers/metrics_saver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,14 @@ def attach(self, engine: Engine) -> None:
105105
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
106106
engine.add_event_handler(Events.EPOCH_COMPLETED, self)
107107

108-
def _started(self, engine: Engine) -> None:
108+
def _started(self, _engine: Engine) -> None:
109+
"""
110+
Initialize internal buffers.
111+
112+
Args:
113+
_engine: Ignite Engine, unused argument.
114+
115+
"""
109116
self._filenames = []
110117

111118
def _get_filenames(self, engine: Engine) -> None:

monai/handlers/stats_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,14 @@ def iteration_completed(self, engine: Engine) -> None:
143143
else:
144144
self._default_iteration_print(engine)
145145

146-
def exception_raised(self, engine: Engine, e: Exception) -> None:
146+
def exception_raised(self, _engine: Engine, e: Exception) -> None:
147147
"""
148148
Handler for train or validation/evaluation exception raised Event.
149149
Print the exception information and traceback. This callback may be skipped because the logic
150150
with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.
151151
152152
Args:
153-
engine: Ignite Engine, it can be a trainer, validator or evaluator.
153+
_engine: Ignite Engine, unused argument.
154154
e: the exception caught in Ignite during engine.run().
155155
156156
"""

monai/handlers/tensorboard_handlers.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ def iteration_completed(self, engine: Engine) -> None:
173173
else:
174174
self._default_iteration_writer(engine, self._writer)
175175

176+
def _write_scalar(self, _engine: Engine, writer: SummaryWriter, tag: str, value: Any, step: int) -> None:
177+
"""
178+
Write scale value into TensorBoard.
179+
Default to call `SummaryWriter.add_scalar()`.
180+
181+
Args:
182+
_engine: Ignite Engine, unused argument.
183+
writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.
184+
tag: tag name in the TensorBoard.
185+
value: value of the scalar data for current step.
186+
step: index of current step.
187+
188+
"""
189+
writer.add_scalar(tag, value, step)
190+
176191
def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
177192
"""
178193
Execute epoch level event write operation.
@@ -188,11 +203,11 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
188203
summary_dict = engine.state.metrics
189204
for name, value in summary_dict.items():
190205
if is_scalar(value):
191-
writer.add_scalar(name, value, current_epoch)
206+
self._write_scalar(engine, writer, name, value, current_epoch)
192207

193208
if self.state_attributes is not None:
194209
for attr in self.state_attributes:
195-
writer.add_scalar(attr, getattr(engine.state, attr, None), current_epoch)
210+
self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch)
196211
writer.flush()
197212

198213
def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None:
@@ -221,12 +236,20 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
221236
" {}:{}".format(name, type(value))
222237
)
223238
continue # not plot multi dimensional output
224-
writer.add_scalar(
225-
name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration
239+
self._write_scalar(
240+
_engine=engine,
241+
writer=writer,
242+
tag=name,
243+
value=value.item() if isinstance(value, torch.Tensor) else value,
244+
step=engine.state.iteration,
226245
)
227246
elif is_scalar(loss): # not printing multi dimensional output
228-
writer.add_scalar(
229-
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration
247+
self._write_scalar(
248+
_engine=engine,
249+
writer=writer,
250+
tag=self.tag_name,
251+
value=loss.item() if isinstance(loss, torch.Tensor) else loss,
252+
step=engine.state.iteration,
230253
)
231254
else:
232255
warnings.warn(

0 commit comments

Comments
 (0)