Skip to content

Commit 030275a

Browse files
committed
Created using Colab
1 parent 153741a commit 030275a

1 file changed

Lines changed: 222 additions & 0 deletions

File tree

RNN_basics.ipynb

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"id": "view-in-github",
7+
"colab_type": "text"
8+
},
9+
"source": [
10+
"<a href=\"https://colab.research.google.com/github/lovnishverma/Python-Getting-Started/blob/main/RNN_basics.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"id": "c9345e09",
16+
"metadata": {
17+
"id": "c9345e09"
18+
},
19+
"source": [
20+
"# Beginner RNN Basics\n",
21+
"M.TECH AI - Practical 4"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"id": "b8df6bb4",
28+
"metadata": {
29+
"id": "b8df6bb4"
30+
},
31+
"outputs": [],
32+
"source": [
33+
"import torch\n",
34+
"import torch.nn as nn\n",
35+
"import torch.optim as optim"
36+
]
37+
},
38+
{
39+
"cell_type": "markdown",
40+
"id": "39b6f123",
41+
"metadata": {
42+
"id": "39b6f123"
43+
},
44+
"source": [
45+
"## 1. Generate Data"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": null,
51+
"id": "18c71a3f",
52+
"metadata": {
53+
"colab": {
54+
"base_uri": "https://localhost:8080/"
55+
},
56+
"id": "18c71a3f",
57+
"outputId": "0b808b4f-791d-495a-9f56-2a7d9029adf5"
58+
},
59+
"outputs": [
60+
{
61+
"output_type": "stream",
62+
"name": "stdout",
63+
"text": [
64+
"torch.Size([500, 5, 1]) torch.Size([500, 1])\n"
65+
]
66+
}
67+
],
68+
"source": [
69+
"def generate_data(n=500, seq_len=5):\n",
70+
" X = torch.randn(n, seq_len, 1)\n",
71+
" y = (X.sum(dim=1) > 0).float()\n",
72+
" return X, y\n",
73+
"\n",
74+
"X, y = generate_data()\n",
75+
"\n",
76+
"print(X.shape, y.shape)"
77+
]
78+
},
79+
{
80+
"cell_type": "markdown",
81+
"id": "334eae52",
82+
"metadata": {
83+
"id": "334eae52"
84+
},
85+
"source": [
86+
"## 2. Simple RNN Model"
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"id": "4542f5ff",
93+
"metadata": {
94+
"id": "4542f5ff"
95+
},
96+
"outputs": [],
97+
"source": [
98+
"class SimpleRNN(nn.Module):\n",
99+
" def __init__(self):\n",
100+
" super().__init__()\n",
101+
" self.rnn = nn.RNN(input_size=1, hidden_size=8, batch_first=True)\n",
102+
" self.fc = nn.Linear(8, 1)\n",
103+
"\n",
104+
" def forward(self, x):\n",
105+
" h0 = torch.zeros(1, x.size(0), 8)\n",
106+
" out, _ = self.rnn(x, h0)\n",
107+
" out = out[:, -1, :]\n",
108+
" out = self.fc(out)\n",
109+
" return torch.sigmoid(out)\n",
110+
"\n",
111+
"model = SimpleRNN()"
112+
]
113+
},
114+
{
115+
"cell_type": "markdown",
116+
"id": "3a79a18b",
117+
"metadata": {
118+
"id": "3a79a18b"
119+
},
120+
"source": [
121+
"## 3. Training"
122+
]
123+
},
124+
{
125+
"cell_type": "code",
126+
"execution_count": null,
127+
"id": "63128b5d",
128+
"metadata": {
129+
"colab": {
130+
"base_uri": "https://localhost:8080/"
131+
},
132+
"id": "63128b5d",
133+
"outputId": "cf8632e4-3886-4a19-9071-ed160045c983"
134+
},
135+
"outputs": [
136+
{
137+
"output_type": "stream",
138+
"name": "stdout",
139+
"text": [
140+
"Epoch 10, Loss: 0.6385\n",
141+
"Epoch 20, Loss: 0.5232\n",
142+
"Epoch 30, Loss: 0.2741\n",
143+
"Epoch 40, Loss: 0.2147\n",
144+
"Epoch 50, Loss: 0.1879\n"
145+
]
146+
}
147+
],
148+
"source": [
149+
"loss_fn = nn.BCELoss()\n",
150+
"optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
151+
"\n",
152+
"for epoch in range(50):\n",
153+
" preds = model(X)\n",
154+
" loss = loss_fn(preds, y)\n",
155+
"\n",
156+
" optimizer.zero_grad()\n",
157+
" loss.backward()\n",
158+
" optimizer.step()\n",
159+
"\n",
160+
" if (epoch+1) % 10 == 0:\n",
161+
" print(f\"Epoch {epoch+1}, Loss: {loss.item():.4f}\")"
162+
]
163+
},
164+
{
165+
"cell_type": "markdown",
166+
"id": "30dc228e",
167+
"metadata": {
168+
"id": "30dc228e"
169+
},
170+
"source": [
171+
"## 4. Testing"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": null,
177+
"id": "cf68e653",
178+
"metadata": {
179+
"colab": {
180+
"base_uri": "https://localhost:8080/"
181+
},
182+
"id": "cf68e653",
183+
"outputId": "ae0d5783-1e10-4d51-92ae-cc16340c9912"
184+
},
185+
"outputs": [
186+
{
187+
"output_type": "stream",
188+
"name": "stdout",
189+
"text": [
190+
"Predictions: [0.0, 1.0, 0.0, 1.0, 1.0]\n",
191+
"Actual: [0.0, 1.0, 0.0, 1.0, 1.0]\n"
192+
]
193+
}
194+
],
195+
"source": [
196+
"test_X, test_y = generate_data(5)\n",
197+
"\n",
198+
"with torch.no_grad():\n",
199+
" preds = model(test_X)\n",
200+
" preds = (preds > 0.5).float()\n",
201+
"\n",
202+
"print(\"Predictions:\", preds.squeeze().tolist())\n",
203+
"print(\"Actual:\", test_y.squeeze().tolist())"
204+
]
205+
}
206+
],
207+
"metadata": {
208+
"colab": {
209+
"provenance": [],
210+
"include_colab_link": true
211+
},
212+
"language_info": {
213+
"name": "python"
214+
},
215+
"kernelspec": {
216+
"name": "python3",
217+
"display_name": "Python 3"
218+
}
219+
},
220+
"nbformat": 4,
221+
"nbformat_minor": 5
222+
}

0 commit comments

Comments
 (0)