-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhook_utils.py
More file actions
52 lines (43 loc) · 1.86 KB
/
hook_utils.py
File metadata and controls
52 lines (43 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import torch
import torch.nn as nn
import numpy as np
import random
import logging
import functools
import sys
from typing import Callable, Any
def get_output(module, input, output):
    """Hook extractor that keeps only what the hooked module produced.

    Suitable as the ``hook_fn`` argument of the ``register_*`` helpers in this
    file. ``module`` and ``input`` are accepted only to match the
    ``(module, input, output)`` hook signature and are ignored.
    """
    produced = output
    return produced
def get_input(module, input, output):
    """Hook extractor that keeps only what was fed into the hooked module.

    For a PyTorch forward hook, ``input`` is the tuple of positional arguments
    passed to ``forward``; it is returned unchanged. ``module`` and ``output``
    are accepted only to match the ``(module, input, output)`` hook signature.

    Previously this returned ``output``, making it a duplicate of
    ``get_output`` and contradicting its name.
    """
    return input
def get_input_output(module, input, output):
    """Hook extractor returning the ``(input, output)`` pair.

    ``module`` is accepted only to match the ``(module, input, output)`` hook
    signature and is ignored.
    """
    pair = (input, output)
    return pair
def get_module_input_output(module, input, output):
    """Hook extractor returning everything the hook was called with.

    The result is the 3-tuple ``(module, input, output)``. Storing it keeps a
    reference to ``module`` alive for as long as the result is held.
    """
    captured = module, input, output
    return captured
def register_hook_fn_to_module(model: nn.Module, module_name: str, hook_fn: Callable[[nn.Module, torch.Tensor, torch.Tensor], Any]):
    """Attach a forward hook to the submodule of *model* named *module_name*.

    Every time that submodule runs ``forward``, ``hook_fn(module, input, output)``
    is evaluated and its result is stored in the returned dict under
    *module_name*, replacing the value from the previous call.

    Args:
        model: Model whose ``named_modules()`` are searched.
        module_name: Exact name as yielded by ``model.named_modules()``.
        hook_fn: Extractor called with ``(module, input, output)``.

    Returns:
        ``(handle, results_dict)``: the hook handle (call ``handle.remove()`` to
        detach the hook) and the dict the hook writes into.

    Raises:
        ValueError: if *model* has no submodule named *module_name*.
    """
    results_dict = {}
    for name, m in model.named_modules():
        if name == module_name:
            handle = m.register_forward_hook(_hook_fn_cntr(name, results_dict, hook_fn))
            # Module names are unique, so there is nothing left to search.
            return handle, results_dict
    # Without this, an unknown name surfaced as an opaque UnboundLocalError.
    raise ValueError(f"No module named {module_name!r} in {type(model).__name__}")
def register_bkw_hook_fn_to_module(model: nn.Module, module_name: str, hook_fn: Callable[[nn.Module, torch.Tensor, torch.Tensor], Any]):
    """Attach a full backward hook to the submodule of *model* named *module_name*.

    Whenever gradients are computed for that submodule, ``hook_fn`` is called
    with ``(module, grad_input, grad_output)``. Its result is stored in the
    returned dict under *module_name*, replacing the value from the previous call.

    Args:
        model: Model whose ``named_modules()`` are searched.
        module_name: Exact name as yielded by ``model.named_modules()``.
        hook_fn: Extractor called with ``(module, grad_input, grad_output)``.

    Returns:
        ``(handle, results_dict)``: the hook handle (call ``handle.remove()`` to
        detach the hook) and the dict the hook writes into.

    Raises:
        ValueError: if *model* has no submodule named *module_name*.
    """
    results_dict = {}
    for name, m in model.named_modules():
        if name == module_name:
            handle = m.register_full_backward_hook(_hook_fn_cntr(name, results_dict, hook_fn))
            # Module names are unique, so there is nothing left to search.
            return handle, results_dict
    # Without this, an unknown name surfaced as an opaque UnboundLocalError.
    raise ValueError(f"No module named {module_name!r} in {type(model).__name__}")
def register_hook_fn_to_all_modules(model: nn.Module, hook_fn: Callable[[nn.Module, torch.Tensor, torch.Tensor], Any], *, return_handles: bool = False):
    """Attach the same forward hook to every module yielded by ``model.named_modules()``.

    After each module's ``forward``, ``hook_fn(module, input, output)`` is
    evaluated. Its result is stored in the shared dict under that module's
    name, replacing the previous value.

    Args:
        model: Model whose modules (as listed by ``named_modules()``) get hooked.
        hook_fn: Extractor called with ``(module, input, output)``.
        return_handles: If true, also return the hook handles so the caller can
            detach them with ``handle.remove()``. Previously the handles were
            discarded, so the hooks could never be removed.

    Returns:
        ``results_dict`` by default. If *return_handles* is true, returns
        ``(handles, results_dict)``, where ``handles`` is a list in
        ``named_modules()`` order.
    """
    results_dict = {}
    handles = [
        m.register_forward_hook(_hook_fn_cntr(name, results_dict, hook_fn))
        for name, m in model.named_modules()
    ]
    if return_handles:
        return handles, results_dict
    return results_dict
def register_bkw_hook_fn_to_all_modules(model: nn.Module, hook_fn: Callable[[nn.Module, torch.Tensor, torch.Tensor], Any], *, return_handles: bool = False):
    """Attach the same full backward hook to every module yielded by ``model.named_modules()``.

    Whenever gradients are computed for a module, ``hook_fn`` is called with
    ``(module, grad_input, grad_output)``. Its result is stored in the shared
    dict under that module's name, replacing the previous value.

    Args:
        model: Model whose modules (as listed by ``named_modules()``) get hooked.
        hook_fn: Extractor called with ``(module, grad_input, grad_output)``.
        return_handles: If true, also return the hook handles so the caller can
            detach them with ``handle.remove()``. Previously the handles were
            discarded, so the hooks could never be removed.

    Returns:
        ``results_dict`` by default. If *return_handles* is true, returns
        ``(handles, results_dict)``, where ``handles`` is a list in
        ``named_modules()`` order.
    """
    results_dict = {}
    handles = [
        m.register_full_backward_hook(_hook_fn_cntr(name, results_dict, hook_fn))
        for name, m in model.named_modules()
    ]
    if return_handles:
        return handles, results_dict
    return results_dict
def _hook_fn_cntr(name, activation_dict, hook_fn):
def hook(model, input, output):
activation_dict[name] = hook_fn(model, input, output)
return hook