1+ {
2+ "cells" : [
3+ {
4+ "cell_type" : " markdown" ,
5+ "metadata" : {
6+ "id" : " view-in-github" ,
7+ "colab_type" : " text"
8+ },
9+ "source" : [
10+ " <a href=\" https://colab.research.google.com/github/lovnishverma/Python-Getting-Started/blob/main/RNN_basics.ipynb\" target=\" _parent\" ><img src=\" https://colab.research.google.com/assets/colab-badge.svg\" alt=\" Open In Colab\" /></a>"
11+ ]
12+ },
13+ {
14+ "cell_type" : " markdown" ,
15+ "id" : " c9345e09" ,
16+ "metadata" : {
17+ "id" : " c9345e09"
18+ },
19+ "source" : [
20+ " # Beginner RNN Basics\n " ,
21+ " M.TECH AI - Practical 4"
22+ ]
23+ },
24+ {
25+ "cell_type" : " code" ,
26+ "execution_count" : null ,
27+ "id" : " b8df6bb4" ,
28+ "metadata" : {
29+ "id" : " b8df6bb4"
30+ },
31+ "outputs" : [],
32+ "source" : [
33+ " import torch\n " ,
34+ " import torch.nn as nn\n " ,
35+ " import torch.optim as optim"
36+ ]
37+ },
38+ {
39+ "cell_type" : " markdown" ,
40+ "id" : " 39b6f123" ,
41+ "metadata" : {
42+ "id" : " 39b6f123"
43+ },
44+ "source" : [
45+ " ## 1. Generate Data"
46+ ]
47+ },
48+ {
49+ "cell_type" : " code" ,
50+ "execution_count" : null ,
51+ "id" : " 18c71a3f" ,
52+ "metadata" : {
53+ "colab" : {
54+ "base_uri" : " https://localhost:8080/"
55+ },
56+ "id" : " 18c71a3f" ,
57+ "outputId" : " 0b808b4f-791d-495a-9f56-2a7d9029adf5"
58+ },
59+ "outputs" : [
60+ {
61+ "output_type" : " stream" ,
62+ "name" : " stdout" ,
63+ "text" : [
64+ " torch.Size([500, 5, 1]) torch.Size([500, 1])\n "
65+ ]
66+ }
67+ ],
68+ "source" : [
69+ " def generate_data(n=500, seq_len=5):\n " ,
70+ " X = torch.randn(n, seq_len, 1)\n " ,
71+ " y = (X.sum(dim=1) > 0).float()\n " ,
72+ " return X, y\n " ,
73+ " \n " ,
74+ " X, y = generate_data()\n " ,
75+ " \n " ,
76+ " print(X.shape, y.shape)"
77+ ]
78+ },
79+ {
80+ "cell_type" : " markdown" ,
81+ "id" : " 334eae52" ,
82+ "metadata" : {
83+ "id" : " 334eae52"
84+ },
85+ "source" : [
86+ " ## 2. Simple RNN Model"
87+ ]
88+ },
89+ {
90+ "cell_type" : " code" ,
91+ "execution_count" : null ,
92+ "id" : " 4542f5ff" ,
93+ "metadata" : {
94+ "id" : " 4542f5ff"
95+ },
96+ "outputs" : [],
97+ "source" : [
98+ " class SimpleRNN(nn.Module):\n " ,
99+ " def __init__(self):\n " ,
100+ " super().__init__()\n " ,
101+ " self.rnn = nn.RNN(input_size=1, hidden_size=8, batch_first=True)\n " ,
102+ " self.fc = nn.Linear(8, 1)\n " ,
103+ " \n " ,
104+ " def forward(self, x):\n " ,
105+ " h0 = torch.zeros(1, x.size(0), 8)\n " ,
106+ " out, _ = self.rnn(x, h0)\n " ,
107+ " out = out[:, -1, :]\n " ,
108+ " out = self.fc(out)\n " ,
109+ " return torch.sigmoid(out)\n " ,
110+ " \n " ,
111+ " model = SimpleRNN()"
112+ ]
113+ },
114+ {
115+ "cell_type" : " markdown" ,
116+ "id" : " 3a79a18b" ,
117+ "metadata" : {
118+ "id" : " 3a79a18b"
119+ },
120+ "source" : [
121+ " ## 3. Training"
122+ ]
123+ },
124+ {
125+ "cell_type" : " code" ,
126+ "execution_count" : null ,
127+ "id" : " 63128b5d" ,
128+ "metadata" : {
129+ "colab" : {
130+ "base_uri" : " https://localhost:8080/"
131+ },
132+ "id" : " 63128b5d" ,
133+ "outputId" : " cf8632e4-3886-4a19-9071-ed160045c983"
134+ },
135+ "outputs" : [
136+ {
137+ "output_type" : " stream" ,
138+ "name" : " stdout" ,
139+ "text" : [
140+ " Epoch 10, Loss: 0.6385\n " ,
141+ " Epoch 20, Loss: 0.5232\n " ,
142+ " Epoch 30, Loss: 0.2741\n " ,
143+ " Epoch 40, Loss: 0.2147\n " ,
144+ " Epoch 50, Loss: 0.1879\n "
145+ ]
146+ }
147+ ],
148+ "source" : [
149+ " loss_fn = nn.BCELoss()\n " ,
150+ " optimizer = optim.Adam(model.parameters(), lr=0.01)\n " ,
151+ " \n " ,
152+ " for epoch in range(50):\n " ,
153+ " preds = model(X)\n " ,
154+ " loss = loss_fn(preds, y)\n " ,
155+ " \n " ,
156+ " optimizer.zero_grad()\n " ,
157+ " loss.backward()\n " ,
158+ " optimizer.step()\n " ,
159+ " \n " ,
160+ " if (epoch+1) % 10 == 0:\n " ,
161+ " print(f\" Epoch {epoch+1}, Loss: {loss.item():.4f}\" )"
162+ ]
163+ },
164+ {
165+ "cell_type" : " markdown" ,
166+ "id" : " 30dc228e" ,
167+ "metadata" : {
168+ "id" : " 30dc228e"
169+ },
170+ "source" : [
171+ " ## 4. Testing"
172+ ]
173+ },
174+ {
175+ "cell_type" : " code" ,
176+ "execution_count" : null ,
177+ "id" : " cf68e653" ,
178+ "metadata" : {
179+ "colab" : {
180+ "base_uri" : " https://localhost:8080/"
181+ },
182+ "id" : " cf68e653" ,
183+ "outputId" : " ae0d5783-1e10-4d51-92ae-cc16340c9912"
184+ },
185+ "outputs" : [
186+ {
187+ "output_type" : " stream" ,
188+ "name" : " stdout" ,
189+ "text" : [
190+ " Predictions: [0.0, 1.0, 0.0, 1.0, 1.0]\n " ,
191+ " Actual: [0.0, 1.0, 0.0, 1.0, 1.0]\n "
192+ ]
193+ }
194+ ],
195+ "source" : [
196+ " test_X, test_y = generate_data(5)\n " ,
197+ " \n " ,
198+ " with torch.no_grad():\n " ,
199+ " preds = model(test_X)\n " ,
200+ " preds = (preds > 0.5).float()\n " ,
201+ " \n " ,
202+ " print(\" Predictions:\" , preds.squeeze().tolist())\n " ,
203+ " print(\" Actual:\" , test_y.squeeze().tolist())"
204+ ]
205+ }
206+ ],
207+ "metadata" : {
208+ "colab" : {
209+ "provenance" : [],
210+ "include_colab_link" : true
211+ },
212+ "language_info" : {
213+ "name" : " python"
214+ },
215+ "kernelspec" : {
216+ "name" : " python3" ,
217+ "display_name" : " Python 3"
218+ }
219+ },
220+ "nbformat" : 4 ,
221+ "nbformat_minor" : 5
222+ }
0 commit comments