|
26 | 26 | <meta property="og:site_name" content="Run's Studio"> |
27 | 27 | <meta property="og:description" content="本笔记从小土堆Pytorch教程中记录一些实用的Pytorch相关操作. 1. 加载数据1.1 PILPIL类可以用于加载图像、保存图像等操作 12from PIL import Imageimg = Image.open('data/hymenoptera_data/train/ants/342438950_a3da61deab.jpg') 1.2 DataSetDataSe"> |
28 | 28 | <meta property="og:locale" content="zh_CN"> |
| 29 | +<meta property="og:image" content="https://runsstudio.github.io/2025/04/02/Pytorch%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B/image.png"> |
29 | 30 | <meta property="article:published_time" content="2025-04-02T12:32:28.000Z"> |
30 | | -<meta property="article:modified_time" content="2025-05-08T13:41:08.956Z"> |
| 31 | +<meta property="article:modified_time" content="2025-05-12T12:34:13.555Z"> |
31 | 32 | <meta property="article:tag" content="交通"> |
32 | 33 | <meta name="twitter:card" content="summary_large_image"> |
| 34 | +<meta name="twitter:image" content="https://runsstudio.github.io/2025/04/02/Pytorch%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B/image.png"> |
33 | 35 |
|
34 | 36 |
|
35 | 37 |
|
@@ -339,7 +341,9 @@ <h2 id="5-4-小网络搭建实战"><a href="#5-4-小网络搭建实战" class="h |
339 | 341 |
|
340 | 342 |
|
341 | 343 |
|
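342 | | -<h1 id="6-损失函数与反向传播"><a href="#6-损失函数与反向传播" class="headerlink" title="6 损失函数与反向传播"></a>6 Loss Functions and Backpropagation</h1><h2 id="6-1-损失函数"><a href="#6-1-损失函数" class="headerlink" title="6.1 损失函数"></a>6.1 Loss Functions</h2><p>A loss function (also called an error function) measures the difference between a model's predictions and the ground truth. By quantifying how far the predicted values are from the true values, it evaluates the model's performance.<br>The choice of loss function depends on the task: for regression a common choice is <code>MSELoss</code>, and for classification a common choice is the cross-entropy loss <code>CrossEntropyLoss</code>.<br>The cross-entropy loss can be written as $$\mathrm{loss}(x, class) = -\ln\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \ln\left(\sum_j \exp(x[j])\right)$$<br>For example:</p> |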
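| 344 | +<h1 id="6-损失函数与反向传播"><a href="#6-损失函数与反向传播" class="headerlink" title="6 损失函数与反向传播"></a>6 Loss Functions and Backpropagation</h1><h2 id="6-1-损失函数"><a href="#6-1-损失函数" class="headerlink" title="6.1 损失函数"></a>6.1 Loss Functions</h2><p>A loss function (also called an error function) measures the difference between a model's predictions and the ground truth. By quantifying how far the predicted values are from the true values, it evaluates the model's performance.<br>The choice of loss function depends on the task: for regression a common choice is <code>MSELoss</code>, and for classification a common choice is the cross-entropy loss <code>CrossEntropyLoss</code>.<br>The cross-entropy loss can be written as </p> |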
| 345 | +<p><img src="/2025/04/02/Pytorch%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B/image.png" srcset="/img/loading.gif" lazyload alt="交叉熵的损失函数"></p> |
| 346 | +<p>For example:</p> |
343 | 347 | <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">import</span> torch<br><span class="hljs-keyword">import</span> torch.nn.functional <span class="hljs-keyword">as</span> F<br>x = torch.tensor([[<span class="hljs-number">0.1</span>, <span class="hljs-number">0.2</span>, <span class="hljs-number">0.3</span>]]) <span class="hljs-comment"># raw scores (logits) for three classes; cross_entropy applies softmax internally</span><br>y = torch.tensor([<span class="hljs-number">1</span>]) <span class="hljs-comment"># the correct class index is 1</span><br>loss = F.cross_entropy(x, y) <span class="hljs-comment"># cross-entropy: loss = -0.2 + ln(e^0.1+e^0.2+e^0.3) = 1.10194284823</span><br><span class="hljs-built_in">print</span>(loss) <span class="hljs-comment"># tensor(1.1019)</span><br></code></pre></td></tr></table></figure> |
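The value in the comment above can be checked by hand from the formula loss = -x[class] + ln(sum_j exp(x[j])). The following is a minimal sketch in plain Python; the helper name `cross_entropy` is hypothetical and is not part of PyTorch.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for a single sample: -x[target] + ln(sum_j exp(x[j]))."""
    log_sum_exp = math.log(sum(math.exp(v) for v in logits))
    return -logits[target] + log_sum_exp


# Same inputs as the example above: scores [0.1, 0.2, 0.3], correct class 1.
loss = cross_entropy([0.1, 0.2, 0.3], 1)
print(round(loss, 4))  # 1.1019
```

This naive form is only for illustration; for large scores a numerically stable log-sum-exp is preferable, and that is one reason to rely on the library function in practice.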
344 | 348 | <p>Other loss functions are used in much the same way:</p> |
345 | 349 | <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> nn<br>x = torch.tensor([[<span class="hljs-number">0.55</span>, <span class="hljs-number">0.88</span>]]) <span class="hljs-comment"># predictions</span><br>y = torch.tensor([[<span class="hljs-number">0.5</span>, <span class="hljs-number">0.8</span>]]) <span class="hljs-comment"># ground truth</span><br>loss_l1 = F.l1_loss(x, y) <span class="hljs-comment"># L1 loss (mean absolute error)</span><br>loss_mse = F.mse_loss(x, y) <span class="hljs-comment"># MSE loss (mean squared error)</span><br><span class="hljs-built_in">print</span>(loss_l1) <span class="hljs-comment"># tensor(0.0650)</span><br><span class="hljs-built_in">print</span>(loss_mse) <span class="hljs-comment"># tensor(0.0044)</span><br><br>loss_layer = nn.L1Loss(reduction=<span class="hljs-string">'sum'</span>) <span class="hljs-comment"># note: reduction defaults to 'mean', which would give 0.065</span><br>loss_l1_by_layer = loss_layer(x, y)<br><span class="hljs-built_in">print</span>(loss_l1_by_layer) <span class="hljs-comment"># tensor(0.1300)</span><br></code></pre></td></tr></table></figure> |
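To make the effect of `reduction` concrete, the sketch below recomputes the three losses element by element and compares them with the library results. The variable names (`abs_err`, `l1_mean`, `l1_sum`, `mse`) are chosen here for illustration only.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0.55, 0.88]])  # predictions
y = torch.tensor([[0.5, 0.8]])    # ground truth

abs_err = (x - y).abs()           # element-wise |x - y|, roughly [0.05, 0.08]
l1_mean = abs_err.mean()          # reduction='mean' (the default)
l1_sum = abs_err.sum()            # reduction='sum'
mse = ((x - y) ** 2).mean()       # mean of squared errors

# The manual results should agree with the functional API.
print(torch.isclose(l1_mean, F.l1_loss(x, y)))
print(torch.isclose(l1_sum, F.l1_loss(x, y, reduction='sum')))
print(torch.isclose(mse, F.mse_loss(x, y)))
```

With two elements, the 'sum' result is simply twice the 'mean' result, which matches the 0.13 versus 0.065 values shown above.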
@@ -374,9 +378,9 @@ <h2 id="7-2-对常用模型进行增加或修改"><a href="#7-2-对常用模型 |
374 | 378 | <figure class="highlight stylus"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs stylus"><span class="hljs-selector-tag">del</span> vgg16<span class="hljs-selector-class">.classifier</span><span class="hljs-selector-attr">[7]</span><br></code></pre></td></tr></table></figure> |
375 | 379 |
|
376 | 380 | <ol start="4"> |
377 | | -<li>Freezing some layers<br>Suppose we only want to train the final fc1 layer, which leads to the code below</li> |
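| 381 | +<li>Freezing some layers<br>Suppose we only want to fine-tune the final fc1 layer and keep the parameters of all other layers frozen. Set <code>requires_grad = True</code> or <code>False</code> on each parameter to control whether it takes part in backpropagation.</li> |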
378 | 382 | </ol> |
379 | | -<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-comment"># Freeze the parameters of the fc1 layer</span><br><span class="hljs-keyword">for</span> name, param <span class="hljs-keyword">in</span> model.named_parameters():<br>    <span class="hljs-keyword">if</span> <span class="hljs-string">"fc1"</span> <span class="hljs-keyword">in</span> name:<br>        param.requires_grad = <span class="hljs-literal">False</span><br><br><span class="hljs-comment"># Pass only the parameters that need updating to the optimizer</span><br>optimizer = optim.SGD(<span class="hljs-built_in">filter</span>(<span class="hljs-keyword">lambda</span> p: p.requires_grad, model.parameters()), lr=<span class="hljs-number">1e-2</span>)<br></code></pre></td></tr></table></figure> |
| 383 | +<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-comment"># Train only the fc1 layer; freeze the parameters of all other layers</span><br><span class="hljs-keyword">for</span> name, param <span class="hljs-keyword">in</span> model.named_parameters():<br>    <span class="hljs-keyword">if</span> <span class="hljs-string">"fc1"</span> <span class="hljs-keyword">in</span> name:<br>        param.requires_grad = <span class="hljs-literal">True</span><br>    <span class="hljs-keyword">else</span>:<br>        param.requires_grad = <span class="hljs-literal">False</span><br><br><span class="hljs-comment"># Pass only the parameters that need updating to the optimizer (the filter keeps those with requires_grad=True)</span><br>optimizer = optim.SGD(<span class="hljs-built_in">filter</span>(<span class="hljs-keyword">lambda</span> p: p.requires_grad, model.parameters()), lr=<span class="hljs-number">1e-2</span>)<br></code></pre></td></tr></table></figure> |
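The snippet above assumes an existing `model` and an imported `optim`. As a self-contained sketch of the same idea, the example below uses a hypothetical two-layer `TinyNet` (not from the original notes) and checks that one optimizer step leaves the frozen layer untouched.

```python
import torch
from torch import nn, optim


class TinyNet(nn.Module):
    """Hypothetical stand-in model: a frozen backbone followed by a trainable fc1."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc1 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc1(torch.relu(self.backbone(x)))


model = TinyNet()

# Only parameters whose name contains "fc1" stay trainable.
for name, param in model.named_parameters():
    param.requires_grad = "fc1" in name

# Hand the optimizer only the trainable parameters.
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-2)

backbone_before = model.backbone.weight.detach().clone()

# One dummy training step on random data.
loss = model(torch.randn(16, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc1.weight', 'fc1.bias']
print(torch.equal(model.backbone.weight, backbone_before))  # True: frozen layer unchanged
```

Frozen parameters receive no gradient at all, so their `.grad` stays `None`. Filtering them out of the optimizer is a safeguard and also avoids tracking unused optimizer state.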
380 | 384 |
|
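381 | 385 | <h1 id="8-完整的训练流程"><a href="#8-完整的训练流程" class="headerlink" title="8 完整的训练流程"></a>8 The Complete Training Pipeline</h1><p>This covers dataset preparation, DataLoader setup, network construction, loss function definition, the training loop, loss computation, TensorBoard visualization, and so on.</p> |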
382 | 386 | <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">import</span> torch.nn<br><span class="hljs-keyword">import</span> torchvision<br><span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> nn<br><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader<br><span class="hljs-keyword">from</span> torch.utils.tensorboard <span class="hljs-keyword">import</span> SummaryWriter<br><br>train_data = torchvision.datasets.CIFAR10(root=<span class="hljs-string">'/data'</span>, train=<span class="hljs-literal">True</span>, transform=torchvision.transforms.ToTensor(),<br> download=<span class="hljs-literal">True</span>)<br>test_data = torchvision.datasets.CIFAR10(root=<span class="hljs-string">'/data'</span>, 
train=<span class="hljs-literal">False</span>, transform=torchvision.transforms.ToTensor(),<br> download=<span class="hljs-literal">True</span>)<br><br>train_data_size = <span class="hljs-built_in">len</span>(train_data)<br>test_data_size = <span class="hljs-built_in">len</span>(test_data)<br><br>train_data_loader = DataLoader(train_data, batch_size=<span class="hljs-number">64</span>)<br>test_data_loader = DataLoader(test_data, batch_size=<span class="hljs-number">64</span>)<br><br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">MyNeuralNetwork</span>(nn.Module):<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, *args, **kwargs</span>) -> <span class="hljs-literal">None</span>:<br> <span class="hljs-built_in">super</span>().__init__(*args, **kwargs)<br> <span class="hljs-variable language_">self</span>.model = nn.Sequential(<br> nn.Conv2d(<span class="hljs-number">3</span>, <span class="hljs-number">32</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>),<br> nn.MaxPool2d(<span class="hljs-number">2</span>),<br> nn.Conv2d(<span class="hljs-number">32</span>, <span class="hljs-number">32</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>),<br> nn.MaxPool2d(<span class="hljs-number">2</span>),<br> nn.Conv2d(<span class="hljs-number">32</span>, <span class="hljs-number">64</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>),<br> nn.MaxPool2d(<span class="hljs-number">2</span>),<br> nn.Flatten(),<br> nn.Linear(<span class="hljs-number">64</span> * <span class="hljs-number">4</span> * <span class="hljs-number">4</span>, <span class="hljs-number">64</span>),<br> nn.Linear(<span class="hljs-number">64</span>, <span class="hljs-number">10</span>)<br> )<br><br> <span 
class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, x</span>):<br> <span class="hljs-keyword">return</span> <span class="hljs-variable language_">self</span>.model(x)<br><br><br>mnn = MyNeuralNetwork()<br><span class="hljs-keyword">if</span> torch.cuda.is_available():<br> mnn = mnn.cuda()<br><br>loss_fn = nn.CrossEntropyLoss()<br>loss_fn = loss_fn.cuda()<br><br>lr = <span class="hljs-number">1e-2</span><br>optim = torch.optim.SGD(mnn.parameters(), lr=lr)<br><br>writer = SummaryWriter(<span class="hljs-string">'../logs_train'</span>)<br>epoch = <span class="hljs-number">10</span><br><span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(epoch):<br> mnn.train()<br> train_step = <span class="hljs-number">0</span><br> <span class="hljs-keyword">for</span> data <span class="hljs-keyword">in</span> train_data_loader:<br> imgs, targets = data<br> imgs = imgs.cuda()<br> targets = targets.cuda()<br> outputs = mnn(imgs)<br> loss = loss_fn(outputs, targets)<br><br> optim.zero_grad()<br> loss.backward()<br> optim.step()<br><br> train_step = train_step + <span class="hljs-number">1</span><br> <span class="hljs-keyword">if</span> train_step % <span class="hljs-number">100</span> == <span class="hljs-number">0</span>:<br> <span class="hljs-built_in">print</span>(<span class="hljs-string">"训练次数:{},loss = {}"</span>.<span class="hljs-built_in">format</span>(train_step, loss.item()))<br> writer.add_scalar(<span class="hljs-string">"train_loss"</span>, loss.item(), train_step)<br><br> mnn.<span class="hljs-built_in">eval</span>()<br> total_test_loss = <span class="hljs-number">0</span><br> total_accuracy = <span class="hljs-number">0</span><br> total_test_step = <span class="hljs-number">0</span><br> <span class="hljs-keyword">with</span> torch.no_grad():<br><br> <span class="hljs-keyword">for</span> data <span class="hljs-keyword">in</span> 
test_data_loader:<br> imgs, targets = data<br> imgs = imgs.cuda()<br> targets = targets.cuda()<br> outputs = mnn(imgs)<br> loss = loss_fn(outputs, targets)<br> total_test_loss += loss.item()<br> <span class="hljs-comment"># 求正确率</span><br> accuracy = (outputs.argmax(<span class="hljs-number">1</span>) == targets).<span class="hljs-built_in">sum</span>()<br> total_accuracy += accuracy<br> <span class="hljs-built_in">print</span>(<span class="hljs-string">"整体测试集上的loss:{}"</span>.<span class="hljs-built_in">format</span>(total_test_loss))<br> <span class="hljs-built_in">print</span>(<span class="hljs-string">"整体测试集上的正确率:{}"</span>.<span class="hljs-built_in">format</span>(total_accuracy / test_data_size))<br> writer.add_scalar(<span class="hljs-string">"test_loss"</span>, total_test_loss, total_test_step)<br> writer.add_scalar(<span class="hljs-string">"test_accuracy"</span>, total_accuracy / test_data_size, total_test_step)<br> total_test_step = total_test_step + <span class="hljs-number">1</span><br><br> <span class="hljs-comment"># 保存模型</span><br> torch.save(mnn, <span class="hljs-string">"mnn{}.pth"</span>.<span class="hljs-built_in">format</span>(i))<br> <span class="hljs-built_in">print</span>(<span class="hljs-string">"模型已保存"</span>)<br><br>writer.close()<br><br></code></pre></td></tr></table></figure> |
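Two details of the script above are worth noting: it calls `.cuda()` unconditionally on the loss function and on each batch, which fails on a machine without a GPU, and it resets `train_step` and `total_test_step` inside the epoch loop, so every epoch logs to the same TensorBoard steps. The following is a minimal sketch of one way to handle both, under stated assumptions: a small random dataset stands in for CIFAR10, a single linear layer stands in for `MyNeuralNetwork`, and a plain list `history` stands in for the `SummaryWriter`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Pick the device once; everything below works on CPU or GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumption: 64 random 3x32x32 "images" with 10 classes instead of CIFAR10.
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16)

# Assumption: a tiny stand-in model rather than the full convolutional network.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

train_step = 0  # global counters, defined once outside the epoch loop
eval_step = 0
history = []    # stand-in for writer.add_scalar(tag, value, step)

for epoch in range(2):
    model.train()
    for imgs, targets in loader:
        imgs, targets = imgs.to(device), targets.to(device)
        loss = loss_fn(model(imgs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_step += 1

    model.eval()
    correct = 0
    with torch.no_grad():
        for imgs, targets in loader:
            imgs, targets = imgs.to(device), targets.to(device)
            correct += (model(imgs).argmax(1) == targets).sum().item()
    history.append((eval_step, correct / len(data)))
    eval_step += 1

print(train_step)                 # 8: four batches per epoch, two epochs
print([s for s, _ in history])    # [0, 1]: one evaluation step per epoch
```

In the full script, the same pattern would amount to replacing each `.cuda()` call with `.to(device)` and moving the step counters above the `for i in range(epoch)` line.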
|