11# -*- coding: utf-8 -*-
22"""
3- Building a Convolution/Batch Norm fuser with torch.compile
4- ===========================================================
3+ Building a Convolution/Batch Norm fuser with torch.compile
4+ =========================================================================
55
6- **Author:** `Horace He <https://github.com/chillee>`_, `Will Feng <https://github.com/yf225>`_
6+ **Author:** `Horace He <https://github.com/chillee>`_, `Will Feng <https://github.com/yf225>`_
7+ **Translator:** `skt0725 <https://github.com/skt0725>`_
78
89.. grid:: 2
910
10- .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
11+ .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
1112 :class-card: card-prerequisites
1213
13- * How to register custom fusion patterns with torch.compile's pattern matcher
14+ * How to register custom fusion patterns with torch.compile's pattern matcher
1415
15- .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
16+ .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
1617 :class-card: card-prerequisites
1718
1819 * PyTorch v2.7.0
1920
2021.. note::
21- This optimization only works for models in inference mode (i.e. ``model.eval()``).
22- However, torch.compile's pattern matching system works for both training and inference .
22+ This optimization applies only to models in inference mode (e.g. ``model.eval()``).
23+ However, torch.compile's pattern matching system works in both training and inference.
2324
2425"""
2526
2627
2728######################################################################
28- # First, let's get some imports out of the way (we will be using all
29- # of these later in the code).
29+ # First, let's import the modules used in the code below.
3030
3131from typing import Type , Dict , Any , Tuple , Iterable
3232import copy
3636device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
3737
3838######################################################################
39- # For this tutorial, we are going to create a model consisting of convolutions
40- # and batch norms. Note that this model has some tricky components - some of
41- # the conv/batch norm patterns are hidden within Sequentials and one of the
42- # ``BatchNorms`` is wrapped in another Module .
39+ # In this tutorial, we will build a model made up of convolutions and batch norms.
40+ # Note that this model has a few tricky components.
41+ # Some conv/batch norm patterns are hidden inside Sequentials, and one of the ``BatchNorms`` is wrapped
42+ # in another Module.
4343
4444class WrappedBatchNorm (nn .Module ):
4545 def __init__ (self ):
@@ -72,42 +72,37 @@ def forward(self, x):
7272model .eval ()
7373
7474######################################################################
75- # Fusing Convolution with Batch Norm
75+ # Fusing Convolution with Batch Norm
7676# -----------------------------------------
77- # One of the primary challenges with trying to automatically fuse convolution
78- # and batch norm in PyTorch is that PyTorch does not provide an easy way of
79- # accessing the computational graph. torch.compile resolves this problem by
80- # capturing the computational graph during compilation, allowing us to apply
81- # pattern-based optimizations across the entire model, including operations
82- # nested within Sequential modules or wrapped in custom modules.
77+ # One of the main challenges in automatically fusing convolution and batch norm is that PyTorch
78+ # does not provide an easy way to access the computational graph.
79+ # torch.compile solves this problem by capturing the computational graph during compilation,
80+ # which lets us apply pattern-based optimizations across the entire model, including operations nested
81+ # inside Sequential modules or wrapped in custom modules.
8382import torch ._inductor .pattern_matcher as pm
8483from torch ._inductor .pattern_matcher import register_replacement
8584
8685######################################################################
87- # torch.compile will capture a graph representation of our model. During
88- # compilation, modules hidden within Sequential containers and wrapped
89- # modules are all inlined into the graph, making them available for
90- # pattern matching and optimization.
86+ # torch.compile captures the computational graph of the model.
87+ # During compilation, modules hidden inside Sequential containers and modules wrapped in other modules are all
88+ # inlined into the graph, so they become available for pattern matching and optimization.
9189
9290
93- ####################################
94- # Fusing Convolution with Batch Norm
91+ ######################################################################
92+ # Fusing Convolution with Batch Norm
9593# ----------------------------------
96- # Unlike some other fusions, fusion of convolution with batch norm does not
97- # require any new operators. Instead, as batch norm during inference
98- # consists of a pointwise add and multiply, these operations can be "baked"
99- # into the preceding convolution's weights. This allows us to remove the batch
100- # norm entirely from our model! Read
101- # https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
102- # code here is copied from
103- # https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
104- # clarity purposes.
94+ # Unlike some other fusions, fusing convolution with batch norm does not require any new operators.
95+ # Because batch norm at inference time consists of a pointwise add and multiply, these operations can be folded
96+ # into the weights of the preceding convolution. This lets us remove batch norm from the model entirely!
97+ # See this article for details:
98+ # https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
99+ # For clarity, the code used here is taken from the following implementation: https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
105100def fuse_conv_bn_eval (conv , bn ):
106101 """
107- Given a conv Module `A` and an batch_norm module `B`, returns a conv
108- module `C` such that C(x) == B(A(x)) in inference mode .
102+ Given a conv module A and a batch norm module B, returns a conv module C such that
103+ C(x) == B(A(x)) in inference mode.
109104 """
110- assert (not (conv .training or bn .training )), "Fusion only for eval !"
105+ assert (not (conv .training or bn .training )), "Fusion only in inference mode!"
111106 fused_conv = copy .deepcopy (conv )
112107
113108 fused_conv .weight , fused_conv .bias = \
@@ -131,14 +126,13 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
131126 return torch .nn .Parameter (conv_w ), torch .nn .Parameter (conv_b )
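The folding arithmetic behind the fusion helpers can be checked by hand on the simplest possible case. The sketch below is not the tutorial's code: it assumes a single-channel 1x1 convolution, so the weight and bias are plain scalars, and the names `fold_bn_into_conv_scalar` and `conv_then_bn` are hypothetical helpers introduced only for illustration.

```python
import math


def fold_bn_into_conv_scalar(w, b, mean, var, eps, gamma, beta):
    """Fold inference-mode batch norm into a scalar (1x1, single-channel) conv.

    Returns the fused weight and bias. Hypothetical helper for illustration.
    """
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


def conv_then_bn(x, w, b, mean, var, eps, gamma, beta):
    """Reference path: scalar conv followed by inference-mode batch norm."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


# Arbitrary illustrative values, not taken from the tutorial.
w, b = 0.5, 0.1
mean, var, eps = 0.2, 4.0, 1e-5
gamma, beta = 1.5, -0.3

fused_w, fused_b = fold_bn_into_conv_scalar(w, b, mean, var, eps, gamma, beta)

# The fused conv alone should reproduce conv followed by batch norm.
for x in (-2.0, 0.0, 3.0):
    reference = conv_then_bn(x, w, b, mean, var, eps, gamma, beta)
    fused = fused_w * x + fused_b
    assert math.isclose(reference, fused, rel_tol=1e-9, abs_tol=1e-12)
```

In the multi-channel case the same scale factor is applied per output channel, which is why the fused weights can be computed once ahead of time instead of running batch norm on every forward pass.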
132127
133128
134- ####################################
135- # Pattern Matching with torch.compile
129+ ######################################################################
130+ # Pattern Matching with torch.compile
136131# ------------------------------------
137- # Now that we have our fusion logic, we need to register a pattern that
138- # torch.compile's pattern matcher will recognize and replace during
139- # compilation.
132+ # Now that we have implemented the fusion logic, we need to register a pattern that torch.compile's pattern matcher
133+ # can recognize and replace during compilation.
140134
141- # Define the pattern we want to match: conv2d followed by batch_norm
135+ # Define the pattern we want to match: conv2d followed by batch_norm.
142136def conv_bn_pattern (x , conv_weight , conv_bias , bn_mean , bn_var , bn_weight , bn_bias ):
143137 conv_out = torch .nn .functional .conv2d (x , conv_weight , conv_bias )
144138 bn_out = torch .nn .functional .batch_norm (
@@ -153,30 +147,29 @@ def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, b
153147 )
154148 return torch .nn .functional .conv2d (x , fused_weight , fused_bias )
155149
156- # Example inputs are needed to trace the pattern functions.
157- # The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
158- # These are used to trace the pattern functions to create the match template.
159- # IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
160- # don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
161- # will be matched regardless of channels, kernel size, or spatial dimensions.
162- # - x: input tensor (batch_size, channels, height, width)
163- # - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
164- # - conv_bias: (out_channels,)
165- # - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
150+ # Example inputs are needed to trace the pattern functions.
151+ # These inputs must match the function signatures of conv_bn_pattern and conv_bn_replacement.
152+ # They are used to trace the pattern functions and build the match template.
153+ # IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes used here do not limit which shapes are matched.
154+ # Any valid conv2d -> batch_norm sequence is matched regardless of channels, kernel size, or spatial dimensions.
155+ # - x: input tensor (batch size, channels, height, width)
156+ # - conv_weight: (output channels, input channels, kernel height, kernel width)
157+ # - conv_bias: (output channels,)
158+ # - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,), matching the output channels.
166159example_inputs = [
167- torch .randn (1 , 1 , 4 , 4 ).to (device ), # x: input tensor
168- torch .randn (1 , 1 , 1 , 1 ).to (device ), # conv_weight: 1 output channel, 1 input channel , 1x1 kernel
169- torch .randn (1 ).to (device ), # conv_bias: 1 output channel
170- torch .randn (1 ).to (device ), # bn_mean: batch norm running mean
171- torch .randn (1 ).to (device ), # bn_var: batch norm running variance
172- torch .randn (1 ).to (device ), # bn_weight: batch norm weight (gamma )
173- torch .randn (1 ).to (device ), # bn_bias: batch norm bias (beta )
160+ torch .randn (1 , 1 , 4 , 4 ).to (device ), # x: input tensor
161+ torch .randn (1 , 1 , 1 , 1 ).to (device ), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
162+ torch .randn (1 ).to (device ), # conv_bias: 1 output channel
163+ torch .randn (1 ).to (device ), # bn_mean: batch norm running mean
164+ torch .randn (1 ).to (device ), # bn_var: batch norm running variance
165+ torch .randn (1 ).to (device ), # bn_weight: batch norm weight (gamma)
166+ torch .randn (1 ).to (device ), # bn_bias: batch norm bias (beta)
174167]
175168
176169from torch ._inductor .pattern_matcher import PatternMatcherPass
177170from torch ._inductor import config
178171
179- # Create a pattern matcher pass and register our pattern
172+ # Create a pattern matcher pass and register our pattern.
180173patterns = PatternMatcherPass ()
181174
182175register_replacement (
@@ -187,48 +180,47 @@ def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, b
187180 patterns ,
188181)
189182
190- # Create a custom pass function that applies our patterns
183+ # Create a custom pass function that applies the registered patterns.
191184def conv_bn_fusion_pass (graph ):
192185 return patterns .apply (graph )
193186
194- # Set our custom pass in the config
187+ # Set the custom pass in the config.
195188config .post_grad_custom_post_pass = conv_bn_fusion_pass
196189
197190
198191######################################################################
199- # .. note::
200- # We make some simplifications here for demonstration purposes, such as only
201- # matching 2D convolutions. The pattern matcher in torch.compile
202- # can handle more complex patterns.
192+ # .. note::
193+ # For demonstration purposes, we make some simplifications here, such as matching only 2D convolutions.
194+ # The pattern matcher in torch.compile can handle far more complex patterns.
203195
204196######################################################################
205- # Testing out our Fusion Pass
197+ # Testing out our Fusion Pass
206198# -----------------------------------------
207- # We can now run this fusion pass on our initial toy model and verify that our
208- # results are identical. In addition, we can print out the code for our fused
209- # model and verify that there are no more batch norms .
199+ # We can now run this fusion pass on the toy model built above and verify that the results match the original.
200+ # We can also print the code of the fused model to verify that the batch norm operations have all
201+ # been removed.
211203from torch ._dynamo .utils import counters
212204
213- # Clear the counters before compilation
205+ # Clear the counters before compiling.
214206counters .clear ()
215207
216- # Ensure pattern matcher is enabled
208+ # Make sure the pattern matcher is enabled.
217209config .pattern_matcher = True
218210
219211fused_model = torch .compile (model , backend = "inductor" )
220212inp = torch .randn (5 , 1 , 1 , 1 ).to (device )
221213
222- # Run the model to trigger compilation and pattern matching
214+ # Run the model to trigger compilation and pattern matching.
223215with torch .no_grad ():
224216 output = fused_model (inp )
225217 expected = model (inp )
226218 torch .testing .assert_close (output , expected )
227219
228- # Check how many patterns were matched
229- assert counters ['inductor' ]['pattern_matcher_count' ] == 3 , "Expected 3 conv-bn patterns to be matched "
220+ # Check how many patterns were matched.
221+ assert counters ['inductor' ]['pattern_matcher_count' ] == 3 , "Expected 3 conv-bn patterns to be matched"
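The counter check above relies on how the counter container behaves between compilations. The following sketch uses a stand-in built from the standard library, on the assumption that `torch._dynamo.utils.counters` behaves like a `collections.defaultdict` of `collections.Counter` objects; it does not import PyTorch, and the increments are simulated by hand rather than produced by a real compilation.

```python
import collections

# Stand-in for torch._dynamo.utils.counters (assumption: a defaultdict of Counters).
counters = collections.defaultdict(collections.Counter)

# Simulate three pattern matches recorded during a first compilation.
counters["inductor"]["pattern_matcher_count"] += 3
assert counters["inductor"]["pattern_matcher_count"] == 3

# Clearing resets every group; a missing key then reads as 0 instead of raising.
counters.clear()
assert counters["inductor"]["pattern_matcher_count"] == 0

# Simulate two matches recorded during a second compilation.
counters["inductor"]["pattern_matcher_count"] += 2
assert counters["inductor"]["pattern_matcher_count"] == 2
```

Under this assumption, counts accumulate across compilations until the container is cleared, so an exact-count assertion is only meaningful when the counters were reset right before the compilation being checked.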
230222
231- # Create a model with different shapes than our example_inputs
223+ # Create a model with shapes different from the example inputs above.
232224test_model_diff_shape = nn .Sequential (
233225 nn .Conv2d (3 , 16 , 5 ),
234226 nn .BatchNorm2d (16 ),
@@ -243,15 +235,15 @@ def conv_bn_fusion_pass(graph):
243235with torch .no_grad ():
244236 compiled_diff_shape (test_input_diff_shape )
245237
246- # Check how many patterns were matched
247- assert counters ['inductor' ]['pattern_matcher_count' ] == 2 , "Expected 2 conv-bn patterns to be matched "
238+ # Check how many patterns were matched.
239+ assert counters ['inductor' ]['pattern_matcher_count' ] == 2 , "Expected 2 conv-bn patterns to be matched"
248240
249241
250242######################################################################
251- # Benchmarking our Fusion on ResNet18
243+ # Benchmarking our Fusion on ResNet18
252244# -----------------------------------
253- # We can test our fusion pass on a larger model like ResNet18 and see how much
254- # this pass improves inference performance .
245+ # We can test the fusion pass on a larger model such as ResNet18
246+ # and see how much this pass improves inference performance.
255247import torchvision .models as models
256248import time
257249
@@ -270,23 +262,22 @@ def benchmark(model, iters=20):
270262 model (inp )
271263 return str (time .time ()- begin )
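Timing a compiled model is sensitive to one-time costs, because the first call to a compiled module triggers compilation. The sketch below is a generic timing helper, not the tutorial's `benchmark` function: the name `benchmark_callable` and its `warmup` parameter are hypothetical, and it times any zero-argument callable using only the standard library.

```python
import time


def benchmark_callable(fn, iters=20, warmup=3):
    """Return the average wall-clock seconds per call of `fn`.

    Warm-up calls run first and are excluded from the measurement, so
    one-time costs (such as compilation) do not distort the average.
    Hypothetical helper for illustration.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


# Usage with a cheap placeholder workload standing in for a model call.
call_log = []
avg_seconds = benchmark_callable(lambda: call_log.append(sum(range(1000))), iters=5, warmup=2)
print(f"average seconds per call: {avg_seconds:.6f}")
```

When the workload runs on a GPU, kernels are launched asynchronously, so a measurement like this would also need `torch.cuda.synchronize()` before reading the clock; that step is left out here to keep the sketch free of dependencies.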
272264
273- # Benchmark original model
265+ # Benchmark the original model.
274266print ("Original model time: " , benchmark (rn18 ))
275267
276- # Compile with our custom pattern
268+ # Compile with the custom pattern defined above.
277269compiled_with_pattern_matching = torch .compile (rn18 , backend = "inductor" )
278270
279- # Benchmark compiled model
271+ # Benchmark the compiled model.
280272print ("\n torch.compile (with conv-bn pattern matching and other fusions): " , benchmark (compiled_with_pattern_matching ))
281273
282274
283275############
284- # Conclusion
276+ # Conclusion
285277# ----------
286- # As we can see, torch.compile provides a powerful way to implement
287- # graph transformations and optimizations through pattern matching.
288- # By registering custom patterns, we can extend torch.compile's
289- # optimization capabilities to handle domain-specific transformations.
278+ # As we have seen, torch.compile provides a powerful way to implement graph transformations and optimizations
279+ # through pattern matching. By registering custom patterns, we can extend torch.compile's optimization capabilities
280+ # to handle domain-specific transformations.
290281#
291- # The conv-bn fusion demonstrated here is just one example of what's
292- # possible with torch.compile's pattern matching system .
282+ # The conv-bn fusion shown here is just one example of what is possible
283+ # with torch.compile's pattern matching system.