Skip to content

Commit 3e2b652

Browse files
authored
intermediate_source/torch_compile_conv_bn_fuser.py ๋ฒˆ์—ญ (#1119)
* feat: translation * Update comments for clarity in torch_compile_conv_bn_fuser * Refine comments in torch_compile_conv_bn_fuser.py Updated comments for clarity and consistency in Korean. * Refine comments in torch_compile_conv_bn_fuser.py Updated comments for consistency and clarity in Korean. * Fix formatting in Korean comments in the script
1 parent 75400ca commit 3e2b652

1 file changed

Lines changed: 85 additions & 94 deletions

File tree

Lines changed: 85 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
Building a Convolution/Batch Norm fuser with torch.compile
4-
===========================================================
3+
torch.compile ๊ธฐ๋ฐ˜ ํ•ฉ์„ฑ๊ณฑยท๋ฐฐ์น˜ ์ •๊ทœํ™” ํ“จ์ €(Convolution/Batch Norm fuser) ๋งŒ๋“ค๊ธฐ
4+
=========================================================================
55
6-
**Author:** `Horace He <https://github.com/chillee>`_, `Will Feng <https://github.com/yf225>`_
6+
**์ €์ž:** `Horace He <https://github.com/chillee>`_, `Will Feng <https://github.com/yf225>`_
7+
**๋ฒˆ์—ญ:** `์‹ฌ๊ธฐํƒ <https://github.com/skt0725>`_
78
89
.. grid:: 2
910
10-
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
11+
.. grid-item-card:: :octicon:`mortar-board;1em;` ๋ฐฐ์šธ ๋‚ด์šฉ
1112
:class-card: card-prerequisites
1213
13-
* How to register custom fusion patterns with torch.compile's pattern matcher
14+
* torch.compile์˜ ํŒจํ„ด ๋งค์ฒ˜์— ์ปค์Šคํ…€ ํ“จ์ „ ํŒจํ„ด์„ ๋“ฑ๋กํ•˜๋Š” ๋ฐฉ๋ฒ•
1415
15-
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
16+
.. grid-item-card:: :octicon:`list-unordered;1em;` ์ „์ œ ์กฐ๊ฑด
1617
:class-card: card-prerequisites
1718
1819
* PyTorch v2.7.0
1920
2021
.. note::
21-
This optimization only works for models in inference mode (i.e. ``model.eval()``).
22-
However, torch.compile's pattern matching system works for both training and inference.
22+
์ด ์ตœ์ ํ™”๋Š” ์ถ”๋ก  ๋ชจ๋“œ์˜ ๋ชจ๋ธ์—๋งŒ ์ ์šฉ๋ฉ๋‹ˆ๋‹ค (์˜ˆ: ``model.eval()``).
23+
ํ•˜์ง€๋งŒ torch.compile์˜ ํŒจํ„ด ๋งค์นญ ์‹œ์Šคํ…œ์€ ํ•™์Šต๊ณผ ์ถ”๋ก  ๋ชจ๋‘์—์„œ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.
2324
2425
"""
2526

2627

2728
######################################################################
28-
# First, let's get some imports out of the way (we will be using all
29-
# of these later in the code).
29+
# ๋จผ์ € ์ดํ›„ ์ฝ”๋“œ์—์„œ ์‚ฌ์šฉํ•  ๋ชจ๋“ˆ๋“ค์„ import ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.
3030

3131
from typing import Type, Dict, Any, Tuple, Iterable
3232
import copy
@@ -36,10 +36,10 @@
3636
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3737

3838
######################################################################
39-
# For this tutorial, we are going to create a model consisting of convolutions
40-
# and batch norms. Note that this model has some tricky components - some of
41-
# the conv/batch norm patterns are hidden within Sequentials and one of the
42-
# ``BatchNorms`` is wrapped in another Module.
39+
# ์ด๋ฒˆ ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” ํ•ฉ์„ฑ๊ณฑ๊ณผ ๋ฐฐ์น˜ ์ •๊ทœํ™”๋กœ ๊ตฌ์„ฑ๋œ ๋ชจ๋ธ์„ ๋งŒ๋“ค์–ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
40+
# ์ด ๋ชจ๋ธ์—๋Š” ๋ช‡ ๊ฐ€์ง€ ๊นŒ๋‹ค๋กœ์šด ์š”์†Œ๊ฐ€ ์žˆ๋‹ค๋Š” ์ ์— ์œ ์˜ํ•˜์„ธ์š”.
41+
# ์ผ๋ถ€ ํ•ฉ์„ฑ๊ณฑยท๋ฐฐ์น˜ ์ •๊ทœํ™” ํŒจํ„ด์€ Sequential ๋‚ด๋ถ€์— ์ˆจ๊ฒจ์ ธ ์žˆ์œผ๋ฉฐ, ``BatchNorms`` ์ค‘ ํ•˜๋‚˜๋Š” ๋˜ ๋‹ค๋ฅธ
42+
# Module๋กœ ๊ฐ์‹ธ์ ธ ์žˆ์Šต๋‹ˆ๋‹ค.
4343

4444
class WrappedBatchNorm(nn.Module):
4545
def __init__(self):
@@ -72,42 +72,37 @@ def forward(self, x):
7272
model.eval()
7373

7474
######################################################################
75-
# Fusing Convolution with Batch Norm
75+
# ํ•ฉ์„ฑ๊ณฑ๊ณผ ๋ฐฐ์น˜ ์ •๊ทœํ™” ํ“จ์ „ํ•˜๊ธฐ
7676
# -----------------------------------------
77-
# One of the primary challenges with trying to automatically fuse convolution
78-
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
79-
# accessing the computational graph. torch.compile resolves this problem by
80-
# capturing the computational graph during compilation, allowing us to apply
81-
# pattern-based optimizations across the entire model, including operations
82-
# nested within Sequential modules or wrapped in custom modules.
77+
# ํ•ฉ์„ฑ๊ณฑ๊ณผ ๋ฐฐ์น˜ ์ •๊ทœํ™”๋ฅผ ์ž๋™์œผ๋กœ ํ“จ์ „ํ•˜๋ ค ํ•  ๋•Œ์˜ ์ฃผ์š” ์–ด๋ ค์›€ ์ค‘ ํ•˜๋‚˜๋Š” PyTorch๊ฐ€ ๊ณ„์‚ฐ
78+
# ๊ทธ๋ž˜ํ”„(computational graph)์— ์‰ฝ๊ฒŒ ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์„ ์ œ๊ณตํ•˜์ง€ ์•Š๋Š”๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.
79+
# torch.compile์€ ์ปดํŒŒ์ผ ๊ณผ์ •์—์„œ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ํ™•๋ณดํ•จ์œผ๋กœ์จ ์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋ฉฐ,
80+
# ์ด๋ฅผ ํ†ตํ•ด Sequential ๋ชจ๋“ˆ ๋‚ด๋ถ€์— ์žˆ๋Š” ์ค‘์ฒฉ๋œ ์—ฐ์‚ฐ์ด๋‚˜ ์‚ฌ์šฉ์ž ์ •์˜ ๋ชจ๋“ˆ๋กœ ๊ฐ์‹ธ์ง„ ์—ฐ์‚ฐ์„ ํฌํ•จํ•œ
81+
# ๋ชจ๋ธ ์ „์ฒด์— ํŒจํ„ด ๊ธฐ๋ฐ˜ ์ตœ์ ํ™”๋ฅผ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
8382
import torch._inductor.pattern_matcher as pm
8483
from torch._inductor.pattern_matcher import register_replacement
8584

8685
######################################################################
87-
# torch.compile will capture a graph representation of our model. During
88-
# compilation, modules hidden within Sequential containers and wrapped
89-
# modules are all inlined into the graph, making them available for
90-
# pattern matching and optimization.
86+
# torch.compile์€ ๋ชจ๋ธ์˜ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ํ™•๋ณดํ•ฉ๋‹ˆ๋‹ค.
87+
# ์ปดํŒŒ์ผ ๊ณผ์ •์—์„œ Sequential ์ปจํ…Œ์ด๋„ˆ์— ์ˆจ๊ฒจ์ง„ ๋ชจ๋“ˆ๊ณผ ๋‹ค๋ฅธ ๋ชจ๋“ˆ๋กœ ๊ฐ์‹ธ์ง„ ๋ชจ๋“ˆ๋“ค์€ ๋ชจ๋‘ ๊ทธ๋ž˜ํ”„์—
88+
# ์ง์ ‘ ํฌํ•จ๋˜์–ด ํŒจํ„ด ๋งค์นญ๊ณผ ์ตœ์ ํ™”์˜ ๋Œ€์ƒ์ด ๋ฉ๋‹ˆ๋‹ค.
9189

9290

93-
####################################
94-
# Fusing Convolution with Batch Norm
91+
######################################################################
92+
# ํ•ฉ์„ฑ๊ณฑ๊ณผ ๋ฐฐ์น˜ ์ •๊ทœํ™” ํ“จ์ „ํ•˜๊ธฐ
9593
# ----------------------------------
96-
# Unlike some other fusions, fusion of convolution with batch norm does not
97-
# require any new operators. Instead, as batch norm during inference
98-
# consists of a pointwise add and multiply, these operations can be "baked"
99-
# into the preceding convolution's weights. This allows us to remove the batch
100-
# norm entirely from our model! Read
101-
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
102-
# code here is copied from
103-
# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
104-
# clarity purposes.
94+
# ๋‹ค๋ฅธ ์ผ๋ถ€ ํ“จ์ „๊ณผ ๋‹ฌ๋ฆฌ, ํ•ฉ์„ฑ๊ณฑ๊ณผ ๋ฐฐ์น˜ ์ •๊ทœํ™”์˜ ํ“จ์ „์—๋Š” ์ƒˆ๋กœ์šด ์—ฐ์‚ฐ์ž๊ฐ€ ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
95+
# ์ถ”๋ก  ๊ณผ์ •์—์„œ ๋ฐฐ์น˜ ์ •๊ทœํ™”๋Š” ์š”์†Œ๋ณ„ ๋ง์…ˆ๊ณผ ๊ณฑ์…ˆ์œผ๋กœ ์ด๋ฃจ์–ด์ง€๋ฏ€๋กœ ์ด๋Ÿฌํ•œ ์—ฐ์‚ฐ๋“ค์„ ์•ž์„  ํ•ฉ์„ฑ๊ณฑ์˜ ๊ฐ€์ค‘์น˜์—
96+
# ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ๋ชจ๋ธ์—์„œ ๋ฐฐ์น˜ ์ •๊ทœํ™”๋ฅผ ์™„์ „ํžˆ ์ œ๊ฑฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!
97+
# ์ž์„ธํ•œ ๋‚ด์šฉ์€ ์ด ๊ธ€์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.
98+
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
99+
# ์—ฌ๊ธฐ์„œ ์‚ฌ์šฉํ•œ ์ฝ”๋“œ๋Š” ์„ค๋ช…์˜ ๋ช…ํ™•์„ฑ์„ ์œ„ํ•ด ๋‹ค์Œ์˜ ๊ตฌํ˜„์„ ๊ฐ€์ ธ์˜จ ๊ฒƒ์ž…๋‹ˆ๋‹ค. https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
105100
def fuse_conv_bn_eval(conv, bn):
106101
"""
107-
Given a conv Module `A` and an batch_norm module `B`, returns a conv
108-
module `C` such that C(x) == B(A(x)) in inference mode.
102+
ํ•ฉ์„ฑ๊ณฑ ๋ชจ๋“ˆ A์™€ ๋ฐฐ์น˜ ์ •๊ทœํ™” ๋ชจ๋“ˆ B๊ฐ€ ์ฃผ์–ด์กŒ์„ ๋•Œ, ์ถ”๋ก  ๋ชจ๋“œ์—์„œ C(x) == B(A(x))๋ฅผ ๋งŒ์กฑํ•˜๋Š”
103+
ํ•ฉ์„ฑ๊ณฑ ๋ชจ๋“ˆ C๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
109104
"""
110-
assert(not (conv.training or bn.training)), "Fusion only for eval!"
105+
assert(not (conv.training or bn.training)), "์ถ”๋ก  ๋ชจ๋“œ์—์„œ๋งŒ ํ“จ์ „ํ•ฉ๋‹ˆ๋‹ค!"
111106
fused_conv = copy.deepcopy(conv)
112107

113108
fused_conv.weight, fused_conv.bias = \
@@ -131,14 +126,13 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
131126
return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
132127

133128

134-
####################################
135-
# Pattern Matching with torch.compile
129+
######################################################################
130+
# torch.compile ๊ธฐ๋ฐ˜ ํŒจํ„ด ๋งค์นญ
136131
# ------------------------------------
137-
# Now that we have our fusion logic, we need to register a pattern that
138-
# torch.compile's pattern matcher will recognize and replace during
139-
# compilation.
132+
# ์ด์ œ ํ“จ์ „ ๋กœ์ง์„ ๊ตฌํ˜„ํ–ˆ์œผ๋ฏ€๋กœ ์ปดํŒŒ์ผ ๊ณผ์ •์—์„œ torch.compile์˜ ํŒจํ„ด ๋งค์ฒ˜๊ฐ€ ์ธ์‹ํ•˜๊ณ  ์น˜ํ™˜ํ•  ์ˆ˜ ์žˆ๋Š”
133+
# ํŒจํ„ด์„ ๋“ฑ๋กํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
140134

141-
# Define the pattern we want to match: conv2d followed by batch_norm
135+
# ๋งค์นญํ•˜๋ ค๋Š” ํŒจํ„ด์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. conv2d ๋‹ค์Œ์— batch_norm์ด ์˜ค๋Š” ํŒจํ„ด์ž…๋‹ˆ๋‹ค.
142136
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
143137
conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
144138
bn_out = torch.nn.functional.batch_norm(
@@ -153,30 +147,29 @@ def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, b
153147
)
154148
return torch.nn.functional.conv2d(x, fused_weight, fused_bias)
155149

156-
# Example inputs are needed to trace the pattern functions.
157-
# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
158-
# These are used to trace the pattern functions to create the match template.
159-
# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
160-
# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
161-
# will be matched regardless of channels, kernel size, or spatial dimensions.
162-
# - x: input tensor (batch_size, channels, height, width)
163-
# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
164-
# - conv_bias: (out_channels,)
165-
# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
150+
# ํŒจํ„ด ํ•จ์ˆ˜๋“ค์„ ์ถ”์ ํ•˜๋ ค๋ฉด ์˜ˆ์‹œ ์ž…๋ ฅ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
151+
# ์ด ์ž…๋ ฅ๋“ค์€ conv_bn_pattern ๋ฐ conv_bn_replacement์˜ ํ•จ์ˆ˜ ์‹œ๊ทธ๋‹ˆ์ฒ˜์™€ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
152+
# ์ด๋“ค์€ ํŒจํ„ด ํ•จ์ˆ˜๋ฅผ ์ถ”์ ํ•˜์—ฌ ๋งค์น˜ ํ…œํ”Œ๋ฆฟ์„ ๋งŒ๋“œ๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
153+
# ์ค‘์š”: ํŒจํ„ด ๋งค์ฒ˜๋Š” ์ž…๋ ฅ ํ˜•ํƒœ์— ๊ตฌ์• ๋ฐ›์ง€ ์•Š์Šต๋‹ˆ๋‹ค! ์—ฌ๊ธฐ์„œ ์‚ฌ์šฉํ•˜๋Š” ํŠน์ • ํ˜•ํƒœ๊ฐ€ ๋งค์นญ๋  ํ˜•ํƒœ๋ฅผ ์ œํ•œํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
154+
# ์ฑ„๋„, ์ปค๋„ ํฌ๊ธฐ, ๊ณต๊ฐ„ ์ฐจ์›์— ๊ด€๊ณ„์—†์ด ์œ ํšจํ•œ conv2d -> batch_norm ์‹œํ€€์Šค๋ผ๋ฉด ๋ชจ๋‘ ๋งค์นญ๋ฉ๋‹ˆ๋‹ค.
155+
# - x: ์ž…๋ ฅ tensor (๋ฐฐ์น˜ ํฌ๊ธฐ, ์ฑ„๋„, ๋†’์ด, ๋„ˆ๋น„)
156+
# - conv_weight: (์ถœ๋ ฅ ์ฑ„๋„, ์ž…๋ ฅ ์ฑ„๋„, ์ปค๋„ ๋†’์ด, ์ปค๋„ ๋„ˆ๋น„)
157+
# - conv_bias: (์ถœ๋ ฅ ์ฑ„๋„,)
158+
# - bn_mean, bn_var, bn_weight, bn_bias: ๋ชจ๋‘ ์ถœ๋ ฅ ์ฑ„๋„๊ณผ ์ผ์น˜ํ•˜๋Š” ํ˜•ํƒœ(num_features,)๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค.
166159
example_inputs = [
167-
torch.randn(1, 1, 4, 4).to(device), # x: input tensor
168-
torch.randn(1, 1, 1, 1).to(device), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
169-
torch.randn(1).to(device), # conv_bias: 1 output channel
170-
torch.randn(1).to(device), # bn_mean: batch norm running mean
171-
torch.randn(1).to(device), # bn_var: batch norm running variance
172-
torch.randn(1).to(device), # bn_weight: batch norm weight (gamma)
173-
torch.randn(1).to(device), # bn_bias: batch norm bias (beta)
160+
torch.randn(1, 1, 4, 4).to(device), # x: ์ž…๋ ฅ tensor
161+
torch.randn(1, 1, 1, 1).to(device), # conv_weight: ์ถœ๋ ฅ ์ฑ„๋„ 1, ์ž…๋ ฅ ์ฑ„๋„ 1, 1x1 ์ปค๋„
162+
torch.randn(1).to(device), # conv_bias: ์ถœ๋ ฅ ์ฑ„๋„ 1
163+
torch.randn(1).to(device), # bn_mean: ๋ฐฐ์น˜ ์ •๊ทœํ™” ์ด๋™ ํ‰๊ท 
164+
torch.randn(1).to(device), # bn_var: ๋ฐฐ์น˜ ์ •๊ทœํ™” ์ด๋™ ๋ถ„์‚ฐ
165+
torch.randn(1).to(device), # bn_weight: ๋ฐฐ์น˜ ์ •๊ทœํ™” ๊ฐ€์ค‘์น˜ (๊ฐ๋งˆ)
166+
torch.randn(1).to(device), # bn_bias: ๋ฐฐ์น˜ ์ •๊ทœํ™” ํŽธํ–ฅ (๋ฒ ํƒ€)
174167
]
175168

176169
from torch._inductor.pattern_matcher import PatternMatcherPass
177170
from torch._inductor import config
178171

179-
# Create a pattern matcher pass and register our pattern
172+
# ํŒจํ„ด ๋งค์ฒ˜ ํŒจ์Šค๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ํŒจํ„ด์„ ๋“ฑ๋กํ•ฉ๋‹ˆ๋‹ค.
180173
patterns = PatternMatcherPass()
181174

182175
register_replacement(
@@ -187,48 +180,47 @@ def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, b
187180
patterns,
188181
)
189182

190-
# Create a custom pass function that applies our patterns
183+
# ๋“ฑ๋ก๋œ ํŒจํ„ด์„ ์ ์šฉํ•˜๋Š” ์ปค์Šคํ…€ ํŒจ์Šค ํ•จ์ˆ˜๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
191184
def conv_bn_fusion_pass(graph):
192185
return patterns.apply(graph)
193186

194-
# Set our custom pass in the config
187+
# ์„ค์ •์— ์ปค์Šคํ…€ ํŒจ์Šค๋ฅผ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.
195188
config.post_grad_custom_post_pass = conv_bn_fusion_pass
196189

197190

198191
######################################################################
199-
# .. note::
200-
# We make some simplifications here for demonstration purposes, such as only
201-
# matching 2D convolutions. The pattern matcher in torch.compile
202-
# can handle more complex patterns.
192+
# .. ์ฐธ๊ณ ::
193+
# ์„ค๋ช…์„ ๋•๊ธฐ ์œ„ํ•ด 2D ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ๋งŒ ๋งค์นญํ•˜๋Š” ๋“ฑ ์ผ๋ถ€ ๋‹จ์ˆœํ™”๋ฅผ ์ ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค.
194+
# torch.compile์˜ ํŒจํ„ด ๋งค์ฒ˜๋Š” ์ด๋ณด๋‹ค ํ›จ์”ฌ ๋” ๋ณต์žกํ•œ ํŒจํ„ด๋„ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
203195

204196
######################################################################
205-
# Testing out our Fusion Pass
197+
# ํ“จ์ „ ํŒจ์Šค ํ…Œ์ŠคํŠธํ•˜๊ธฐ
206198
# -----------------------------------------
207-
# We can now run this fusion pass on our initial toy model and verify that our
208-
# results are identical. In addition, we can print out the code for our fused
209-
# model and verify that there are no more batch norms.
199+
# ์•ž์„œ ๋งŒ๋“  ํ† ์ด ๋ชจ๋ธ์— ์ด ํ“จ์ „ ํŒจ์Šค๋ฅผ ์‹คํ–‰ํ•˜์—ฌ ๊ฒฐ๊ณผ๊ฐ€ ๊ธฐ์กด๊ณผ ์™„๋ฒฝํžˆ ๋™์ผํ•œ์ง€ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
200+
# ๋˜ํ•œ, ํ“จ์ „์ด ์™„๋ฃŒ๋œ ๋ชจ๋ธ์˜ ์ฝ”๋“œ๋ฅผ ์ง์ ‘ ์ถœ๋ ฅํ•ด ๋ด„์œผ๋กœ์จ ๋ฐฐ์น˜ ์ •๊ทœํ™” ์—ฐ์‚ฐ์ด ์ •๋ง๋กœ ๋ชจ๋‘ ์ œ๊ฑฐ๋˜์—ˆ๋Š”์ง€
201+
# ๊ฒ€์ฆํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
210202

211203
from torch._dynamo.utils import counters
212204

213-
# Clear the counters before compilation
205+
# ์ปดํŒŒ์ผํ•˜๊ธฐ ์ „์— ์นด์šดํ„ฐ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
214206
counters.clear()
215207

216-
# Ensure pattern matcher is enabled
208+
# ํŒจํ„ด ๋งค์ฒ˜๊ฐ€ ํ™œ์„ฑํ™”๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
217209
config.pattern_matcher = True
218210

219211
fused_model = torch.compile(model, backend="inductor")
220212
inp = torch.randn(5, 1, 1, 1).to(device)
221213

222-
# Run the model to trigger compilation and pattern matching
214+
# ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜์—ฌ ์ปดํŒŒ์ผ๊ณผ ํŒจํ„ด ๋งค์นญ ๊ณผ์ •์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
223215
with torch.no_grad():
224216
output = fused_model(inp)
225217
expected = model(inp)
226218
torch.testing.assert_close(output, expected)
227219

228-
# Check how many patterns were matched
229-
assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched"
220+
# ๋ช‡ ๊ฐœ์˜ ํŒจํ„ด์ด ๋งค์นญ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
221+
assert counters['inductor']['pattern_matcher_count'] == 3, "3๊ฐœ์˜ conv-bn ํŒจํ„ด์ด ๋งค์นญ๋  ๊ฒƒ์œผ๋กœ ์˜ˆ์ƒ๋ฉ๋‹ˆ๋‹ค."
230222

231-
# Create a model with different shapes than our example_inputs
223+
# ์•ž์„  ์˜ˆ์‹œ ์ž…๋ ฅ๊ณผ๋Š” ๋‹ค๋ฅธ ํ˜•ํƒœ๋ฅผ ๊ฐ€์ง„ ๋ชจ๋ธ์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
232224
test_model_diff_shape = nn.Sequential(
233225
nn.Conv2d(3, 16, 5),
234226
nn.BatchNorm2d(16),
@@ -243,15 +235,15 @@ def conv_bn_fusion_pass(graph):
243235
with torch.no_grad():
244236
compiled_diff_shape(test_input_diff_shape)
245237

246-
# Check how many patterns were matched
247-
assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched"
238+
# ๋ช‡ ๊ฐœ์˜ ํŒจํ„ด์ด ๋งค์นญ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
239+
assert counters['inductor']['pattern_matcher_count'] == 2, "2๊ฐœ์˜ conv-bn ํŒจํ„ด์ด ๋งค์นญ๋  ๊ฒƒ์œผ๋กœ ์˜ˆ์ƒ๋ฉ๋‹ˆ๋‹ค."
248240

249241

250242
######################################################################
251-
# Benchmarking our Fusion on ResNet18
243+
# ResNet18 ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ ํ“จ์ „ ์„ฑ๋Šฅ ์ธก์ •
252244
# -----------------------------------
253-
# We can test our fusion pass on a larger model like ResNet18 and see how much
254-
# this pass improves inference performance.
245+
# ResNet18๊ณผ ๊ฐ™์€ ๋” ํฐ ๋ชจ๋ธ์— ํ“จ์ „ ํŒจ์Šค๋ฅผ ํ…Œ์ŠคํŠธํ•˜์—ฌ
246+
# ์ด ๋‹จ๊ณ„๊ฐ€ ์ถ”๋ก  ์„ฑ๋Šฅ์„ ์–ผ๋งˆ๋‚˜ ํ–ฅ์ƒ์‹œํ‚ค๋Š”์ง€ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
255247
import torchvision.models as models
256248
import time
257249

@@ -270,23 +262,22 @@ def benchmark(model, iters=20):
270262
model(inp)
271263
return str(time.time()-begin)
272264

273-
# Benchmark original model
265+
# ์›๋ณธ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.
274266
print("Original model time: ", benchmark(rn18))
275267

276-
# Compile with our custom pattern
268+
# ์•ž์„œ ์ •์˜ํ•œ ์ปค์Šคํ…€ ํŒจํ„ด์„ ์ ์šฉํ•˜์—ฌ ์ปดํŒŒ์ผํ•ฉ๋‹ˆ๋‹ค.
277269
compiled_with_pattern_matching = torch.compile(rn18, backend="inductor")
278270

279-
# Benchmark compiled model
271+
# ์ปดํŒŒ์ผ๋œ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.
280272
print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching))
281273

282274

283275
############
284-
# Conclusion
276+
# ๊ฒฐ๋ก 
285277
# ----------
286-
# As we can see, torch.compile provides a powerful way to implement
287-
# graph transformations and optimizations through pattern matching.
288-
# By registering custom patterns, we can extend torch.compile's
289-
# optimization capabilities to handle domain-specific transformations.
278+
# ๋ณด์‹œ๋‹ค์‹œํ”ผ torch.compile์€ ํŒจํ„ด ๋งค์นญ์„ ํ†ตํ•ด ๊ทธ๋ž˜ํ”„ ๋ณ€ํ™˜ ๋ฐ ์ตœ์ ํ™”๋ฅผ ๊ตฌํ˜„ํ•˜๋Š” ๋งค์šฐ ๊ฐ•๋ ฅํ•œ ๋ฐฉ๋ฒ•์„
279+
# ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. ์ปค์Šคํ…€ ํŒจํ„ด์„ ๋“ฑ๋กํ•จ์œผ๋กœ์จ torch.compile์˜ ์ตœ์ ํ™” ๊ธฐ๋Šฅ์„ ๋”์šฑ ํ™•์žฅํ•˜์—ฌ ํŠน์ • ๋„๋ฉ”์ธ์—
280+
# ํŠนํ™”๋œ ๋ณ€ํ™˜๊นŒ์ง€ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
290281
#
291-
# The conv-bn fusion demonstrated here is just one example of what's
292-
# possible with torch.compile's pattern matching system.
282+
# ์—ฌ๊ธฐ์„œ ๋ณด์—ฌ๋“œ๋ฆฐ conv-bn ํ“จ์ „์€ torch.compile์˜ ํŒจํ„ด ๋งค์นญ ์‹œ์Šคํ…œ์œผ๋กœ ํ•  ์ˆ˜ ์žˆ๋Š”
283+
# ๋งŽ์€ ์ผ๋“ค ์ค‘ ํ•˜๋‚˜์˜ ์˜ˆ์‹œ์ผ ๋ฟ์ž…๋‹ˆ๋‹ค.

0 commit comments

Comments
ย (0)