Skip to content

Commit 4c06d71

Browse files
committed
fix: replace torch.inference_mode with torch.no_grad to prevent model pollution
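Calling model.to(device) inside torch.inference_mode() creates inference tensors that replace the model's parameters. Inference tensors cannot be recorded by autograd, so the model stays tainted after the benchmark returns and any downstream autograd operation on it breaks (for example, torch-pruning's dependency graph tracing). torch.no_grad() also disables gradient tracking and gives comparable performance, but it does not taint model state. Affected: compute_speed, compute_energy, _profile_layers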
1 parent e22d46c commit 4c06d71

7 files changed

Lines changed: 676 additions & 131 deletions

File tree

fasterbench/energy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _clear_stale_codecarbon_lock() -> None:
6565

6666

6767
#| export
68-
@torch.inference_mode()
68+
@torch.no_grad()
6969
def compute_energy(
7070
model: torch.nn.Module, # model to benchmark
7171
sample: torch.Tensor, # input tensor (with batch dimension)

fasterbench/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def hook(mod, inp, output):
157157
}
158158

159159

160-
@torch.inference_mode()
160+
@torch.no_grad()
161161
def _profile_layers(
162162
model: nn.Module, # model to profile
163163
sample: torch.Tensor, # input tensor (with batch dimension)

fasterbench/speed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _forward_latencies(
9898

9999

100100
#| export
101-
@torch.inference_mode()
101+
@torch.no_grad()
102102
def compute_speed(
103103
model: nn.Module, # model to benchmark
104104
sample: torch.Tensor, # input tensor (with batch dimension)

nbs/analysis/profiling.ipynb

Lines changed: 223 additions & 2 deletions
Large diffs are not rendered by default.

nbs/metrics/energy.ipynb

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,130 @@
4242
"id": "58dcf143",
4343
"metadata": {},
4444
"outputs": [],
45-
"source": "#| export\n@dataclass(slots=True)\nclass EnergyMetrics:\n \"\"\"Energy consumption and carbon footprint metrics.\"\"\"\n mean_watts: float # average power during measurement\n energy_wh: float # Wh per inference\n co2_eq_g: float # g CO₂-eq per inference\n\n def as_dict(self) -> dict[str, float]:\n return asdict(self)\n\n\n#| export\ndef _nan_energy_metrics(device: str) -> EnergyMetrics: # device string (unused, for consistent signature)\n \"\"\"Create EnergyMetrics with NaN values for failed benchmarks.\"\"\"\n nan = float(\"nan\")\n return EnergyMetrics(nan, nan, nan)\n\n\n#| export\ndef _clear_stale_codecarbon_lock() -> None:\n \"\"\"Remove stale codecarbon lock file if the owning process no longer exists.\"\"\"\n import tempfile\n lock_path = os.path.join(tempfile.gettempdir(), \".codecarbon.lock\")\n if not os.path.exists(lock_path):\n return\n try:\n # Read the PID from the lock file (codecarbon writes its PID there)\n with open(lock_path) as f:\n content = f.read().strip()\n if content:\n pid = int(content)\n os.kill(pid, 0) # Check if process exists (signal 0 = no-op)\n # Process exists — lock is valid, don't remove\n return\n except (ValueError, ProcessLookupError, PermissionError, OSError):\n pass # PID invalid or process dead — lock is stale\n try:\n os.remove(lock_path)\n except OSError:\n pass\n\n\n#| export\n@torch.inference_mode()\ndef compute_energy(\n model: torch.nn.Module, # model to benchmark\n sample: torch.Tensor, # input tensor (with batch dimension)\n *,\n device: str | torch.device = \"cpu\", # device to run on\n warmup: int = 20, # warmup iterations\n steps: int = 100, # measurement iterations\n offline: bool = True, # use offline emissions tracker\n country_iso: str | None = None, # country ISO code for carbon intensity\n measure_secs: int = 1, # power sampling interval\n) -> EnergyMetrics:\n \"\"\"Measure power consumption and carbon footprint using codecarbon.\"\"\"\n if EmissionsTracker is None:\n warnings.warn(\"codecarbon 
not installed – returning NaNs\")\n return _nan_energy_metrics(str(device))\n\n _clear_stale_codecarbon_lock()\n\n Tracker = OfflineEmissionsTracker if offline else EmissionsTracker\n tracker = Tracker(\n project_name=\"fasterbench\",\n country_iso_code=(country_iso or os.getenv(\"NNBENCH_ISO\", \"USA\")),\n measure_power_secs=measure_secs,\n save_to_file=False,\n log_level=\"critical\",\n )\n\n with _device_ctx(device) as dev:\n model = model.eval().to(dev)\n sample = sample.to(dev, non_blocking=True)\n\n for _ in range(warmup):\n model(sample)\n _sync(dev)\n\n tracker.start()\n try:\n t0 = time.perf_counter()\n for _ in range(steps):\n model(sample)\n _sync(dev)\n finally:\n tracker.stop()\n dur_s = time.perf_counter() - t0\n\n # codecarbon silently fails if another instance is running,\n # leaving final_emissions_data as None\n if tracker.final_emissions_data is None:\n warnings.warn(\"codecarbon tracker did not collect data (another instance may be running)\")\n return _nan_energy_metrics(str(device))\n\n ene_kwh = tracker.final_emissions_data.energy_consumed\n co2_kg = tracker.final_emissions\n mean_w = (ene_kwh * 3600_000) / dur_s\n\n return EnergyMetrics(\n mean_watts=mean_w,\n energy_wh=(ene_kwh * 1_000) / steps,\n co2_eq_g=(co2_kg * 1_000) / steps,\n )\n\n\n#| export\ndef compute_energy_multi(\n model: torch.nn.Module, # model to benchmark\n sample: torch.Tensor, # input tensor (with batch dimension)\n *,\n devices: Sequence[str | torch.device] | None = None, # devices to benchmark (default: cpu + cuda)\n **kwargs,\n) -> dict[str, EnergyMetrics]:\n \"\"\"Measure energy on multiple devices.\"\"\"\n return _run_on_devices(\n compute_energy, model, sample, devices,\n nan_factory=_nan_energy_metrics,\n metric_name=\"Energy\",\n **kwargs\n )"
45+
"source": [
46+
"#| export\n",
47+
"@dataclass(slots=True)\n",
48+
"class EnergyMetrics:\n",
49+
" \"\"\"Energy consumption and carbon footprint metrics.\"\"\"\n",
50+
" mean_watts: float # average power during measurement\n",
51+
" energy_wh: float # Wh per inference\n",
52+
" co2_eq_g: float # g CO₂-eq per inference\n",
53+
"\n",
54+
" def as_dict(self) -> dict[str, float]:\n",
55+
" return asdict(self)\n",
56+
"\n",
57+
"\n",
58+
"#| export\n",
59+
"def _nan_energy_metrics(device: str) -> EnergyMetrics: # device string (unused, for consistent signature)\n",
60+
" \"\"\"Create EnergyMetrics with NaN values for failed benchmarks.\"\"\"\n",
61+
" nan = float(\"nan\")\n",
62+
" return EnergyMetrics(nan, nan, nan)\n",
63+
"\n",
64+
"\n",
65+
"#| export\n",
66+
"def _clear_stale_codecarbon_lock() -> None:\n",
67+
" \"\"\"Remove stale codecarbon lock file if the owning process no longer exists.\"\"\"\n",
68+
" import tempfile\n",
69+
" lock_path = os.path.join(tempfile.gettempdir(), \".codecarbon.lock\")\n",
70+
" if not os.path.exists(lock_path):\n",
71+
" return\n",
72+
" try:\n",
73+
" # Read the PID from the lock file (codecarbon writes its PID there)\n",
74+
" with open(lock_path) as f:\n",
75+
" content = f.read().strip()\n",
76+
" if content:\n",
77+
" pid = int(content)\n",
78+
" os.kill(pid, 0) # Check if process exists (signal 0 = no-op)\n",
79+
" # Process exists — lock is valid, don't remove\n",
80+
" return\n",
81+
" except (ValueError, ProcessLookupError, PermissionError, OSError):\n",
82+
" pass # PID invalid or process dead — lock is stale\n",
83+
" try:\n",
84+
" os.remove(lock_path)\n",
85+
" except OSError:\n",
86+
" pass\n",
87+
"\n",
88+
"\n",
89+
"#| export\n",
90+
"@torch.no_grad()\n",
91+
"def compute_energy(\n",
92+
" model: torch.nn.Module, # model to benchmark\n",
93+
" sample: torch.Tensor, # input tensor (with batch dimension)\n",
94+
" *,\n",
95+
" device: str | torch.device = \"cpu\", # device to run on\n",
96+
" warmup: int = 20, # warmup iterations\n",
97+
" steps: int = 100, # measurement iterations\n",
98+
" offline: bool = True, # use offline emissions tracker\n",
99+
" country_iso: str | None = None, # country ISO code for carbon intensity\n",
100+
" measure_secs: int = 1, # power sampling interval\n",
101+
") -> EnergyMetrics:\n",
102+
" \"\"\"Measure power consumption and carbon footprint using codecarbon.\"\"\"\n",
103+
" if EmissionsTracker is None:\n",
104+
" warnings.warn(\"codecarbon not installed – returning NaNs\")\n",
105+
" return _nan_energy_metrics(str(device))\n",
106+
"\n",
107+
" _clear_stale_codecarbon_lock()\n",
108+
"\n",
109+
" Tracker = OfflineEmissionsTracker if offline else EmissionsTracker\n",
110+
" tracker = Tracker(\n",
111+
" project_name=\"fasterbench\",\n",
112+
" country_iso_code=(country_iso or os.getenv(\"NNBENCH_ISO\", \"USA\")),\n",
113+
" measure_power_secs=measure_secs,\n",
114+
" save_to_file=False,\n",
115+
" log_level=\"critical\",\n",
116+
" )\n",
117+
"\n",
118+
" with _device_ctx(device) as dev:\n",
119+
" model = model.eval().to(dev)\n",
120+
" sample = sample.to(dev, non_blocking=True)\n",
121+
"\n",
122+
" for _ in range(warmup):\n",
123+
" model(sample)\n",
124+
" _sync(dev)\n",
125+
"\n",
126+
" tracker.start()\n",
127+
" try:\n",
128+
" t0 = time.perf_counter()\n",
129+
" for _ in range(steps):\n",
130+
" model(sample)\n",
131+
" _sync(dev)\n",
132+
" finally:\n",
133+
" tracker.stop()\n",
134+
" dur_s = time.perf_counter() - t0\n",
135+
"\n",
136+
" # codecarbon silently fails if another instance is running,\n",
137+
" # leaving final_emissions_data as None\n",
138+
" if tracker.final_emissions_data is None:\n",
139+
" warnings.warn(\"codecarbon tracker did not collect data (another instance may be running)\")\n",
140+
" return _nan_energy_metrics(str(device))\n",
141+
"\n",
142+
" ene_kwh = tracker.final_emissions_data.energy_consumed\n",
143+
" co2_kg = tracker.final_emissions\n",
144+
" mean_w = (ene_kwh * 3600_000) / dur_s\n",
145+
"\n",
146+
" return EnergyMetrics(\n",
147+
" mean_watts=mean_w,\n",
148+
" energy_wh=(ene_kwh * 1_000) / steps,\n",
149+
" co2_eq_g=(co2_kg * 1_000) / steps,\n",
150+
" )\n",
151+
"\n",
152+
"\n",
153+
"#| export\n",
154+
"def compute_energy_multi(\n",
155+
" model: torch.nn.Module, # model to benchmark\n",
156+
" sample: torch.Tensor, # input tensor (with batch dimension)\n",
157+
" *,\n",
158+
" devices: Sequence[str | torch.device] | None = None, # devices to benchmark (default: cpu + cuda)\n",
159+
" **kwargs,\n",
160+
") -> dict[str, EnergyMetrics]:\n",
161+
" \"\"\"Measure energy on multiple devices.\"\"\"\n",
162+
" return _run_on_devices(\n",
163+
" compute_energy, model, sample, devices,\n",
164+
" nan_factory=_nan_energy_metrics,\n",
165+
" metric_name=\"Energy\",\n",
166+
" **kwargs\n",
167+
" )"
168+
]
46169
},
47170
{
48171
"cell_type": "code",
@@ -94,4 +217,4 @@
94217
"metadata": {},
95218
"nbformat": 4,
96219
"nbformat_minor": 5
97-
}
220+
}

0 commit comments

Comments
 (0)