Skip to content

Commit f8132cb

Browse files
committed
Created using Colab
1 parent 030275a commit f8132cb

1 file changed

Lines changed: 253 additions & 0 deletions

File tree

RNN_text_classifier.ipynb

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"id": "view-in-github",
7+
"colab_type": "text"
8+
},
9+
"source": [
10+
"<a href=\"https://colab.research.google.com/github/lovnishverma/Python-Getting-Started/blob/main/RNN_text_classifier.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"id": "f9ed309e",
16+
"metadata": {
17+
"id": "f9ed309e"
18+
},
19+
"source": [
20+
"# Simple RNN for Text Classification\n",
21+
"M.TECH AI - Practical 5"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": 1,
27+
"id": "1264af09",
28+
"metadata": {
29+
"id": "1264af09"
30+
},
31+
"outputs": [],
32+
"source": [
33+
"import torch\n",
34+
"import torch.nn as nn\n",
35+
"import torch.optim as optim"
36+
]
37+
},
38+
{
39+
"cell_type": "markdown",
40+
"id": "2ab48979",
41+
"metadata": {
42+
"id": "2ab48979"
43+
},
44+
"source": [
45+
"## 1. Dataset"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": 2,
51+
"id": "cf700915",
52+
"metadata": {
53+
"id": "cf700915"
54+
},
55+
"outputs": [],
56+
"source": [
57+
"texts = [\"i love this\", \"i hate this\"]\n",
58+
"labels = [1, 0]"
59+
]
60+
},
61+
{
62+
"cell_type": "markdown",
63+
"id": "180a47a2",
64+
"metadata": {
65+
"id": "180a47a2"
66+
},
67+
"source": [
68+
"## 2. Text to Numbers"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": 3,
74+
"id": "361b7124",
75+
"metadata": {
76+
"id": "361b7124"
77+
},
78+
"outputs": [],
79+
"source": [
80+
"vocab = {\"i\":0, \"love\":1, \"hate\":2, \"this\":3}\n",
81+
"\n",
82+
"def encode(text):  # whitespace-tokenise text and map each word to its id in vocab\n",
83+
"    return [vocab[word] for word in text.lower().split()]  # lowercase first (vocab keys are lowercase); unknown words raise KeyError\n",
84+
"\n",
85+
"X = [encode(t) for t in texts]\n",
86+
"y = torch.tensor(labels, dtype=torch.float32)"
87+
]
88+
},
89+
{
90+
"cell_type": "markdown",
91+
"id": "10c699b9",
92+
"metadata": {
93+
"id": "10c699b9"
94+
},
95+
"source": [
96+
"## 3. Padding"
97+
]
98+
},
99+
{
100+
"cell_type": "code",
101+
"execution_count": 4,
102+
"id": "f54a454b",
103+
"metadata": {
104+
"id": "f54a454b"
105+
},
106+
"outputs": [],
107+
"source": [
108+
"max_len = max(len(seq) for seq in X)\n",
109+
"\n",
110+
"def pad(seq):  # NOTE: pad id 0 is also vocab['i']; no padding is actually added while all texts share one length\n",
111+
"    return (seq + [0]*(max_len - len(seq)))[:max_len]  # right-pad with 0, then truncate so every row has exactly max_len ids\n",
112+
"\n",
113+
"X = torch.tensor([pad(seq) for seq in X])"
114+
]
115+
},
116+
{
117+
"cell_type": "markdown",
118+
"id": "f9e4acb6",
119+
"metadata": {
120+
"id": "f9e4acb6"
121+
},
122+
"source": [
123+
"## 4. Model"
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": 5,
129+
"id": "4f826ca7",
130+
"metadata": {
131+
"id": "4f826ca7"
132+
},
133+
"outputs": [],
134+
"source": [
135+
"class SimpleRNN(nn.Module):  # Embedding -> RNN -> Linear -> sigmoid; returns one probability per sequence\n",
136+
"    def __init__(self, vocab_size=4, embed_dim=8, hidden_dim=16):  # defaults reproduce the original fixed sizes\n",
137+
"        super().__init__()\n",
138+
"        self.embedding = nn.Embedding(vocab_size, embed_dim)\n",
139+
"        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)\n",
140+
"        self.fc = nn.Linear(hidden_dim, 1)\n",
141+
"\n",
142+
"    def forward(self, x):  # x: (batch, seq_len) integer token ids\n",
143+
"        x = self.embedding(x)\n",
144+
"        out, _ = self.rnn(x)\n",
145+
"        out = out[:, -1, :]  # keep only the RNN output at the last time step\n",
146+
"        out = self.fc(out)\n",
147+
"        return torch.sigmoid(out).squeeze(-1)  # drop only the feature dim so a batch of 1 keeps shape (1,)\n",
148+
"\n",
149+
"model = SimpleRNN()"
150+
]
151+
},
152+
{
153+
"cell_type": "markdown",
154+
"id": "554d66c4",
155+
"metadata": {
156+
"id": "554d66c4"
157+
},
158+
"source": [
159+
"## 5. Training"
160+
]
161+
},
162+
{
163+
"cell_type": "code",
164+
"execution_count": 6,
165+
"id": "0014cb5a",
166+
"metadata": {
167+
"colab": {
168+
"base_uri": "https://localhost:8080/"
169+
},
170+
"id": "0014cb5a",
171+
"outputId": "9a900acf-9610-44f5-9f5b-5f8d87f9ff0f"
172+
},
173+
"outputs": [
174+
{
175+
"output_type": "stream",
176+
"name": "stdout",
177+
"text": [
178+
"Epoch 10, Loss: 0.5526\n",
179+
"Epoch 20, Loss: 0.1626\n",
180+
"Epoch 30, Loss: 0.0277\n",
181+
"Epoch 40, Loss: 0.0096\n",
182+
"Epoch 50, Loss: 0.0053\n"
183+
]
184+
}
185+
],
186+
"source": [
187+
"loss_fn = nn.BCELoss()  # expects probabilities, which the model's sigmoid output provides\n",
188+
"optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
189+
"\n",
190+
"for epoch in range(50):\n",
191+
"    preds = model(X)\n",
192+
"    loss = loss_fn(preds, y)\n",
193+
"\n",
194+
"    optimizer.zero_grad()  # clear gradients left over from the previous step\n",
195+
"    loss.backward()\n",
196+
"    optimizer.step()\n",
197+
"\n",
198+
"    if (epoch+1) % 10 == 0:  # report every 10th epoch\n",
199+
"        print(f\"Epoch {epoch+1}, Loss: {loss.item():.4f}\")"
200+
]
201+
},
202+
{
203+
"cell_type": "markdown",
204+
"id": "158862b3",
205+
"metadata": {
206+
"id": "158862b3"
207+
},
208+
"source": [
209+
"## 6. Testing"
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": 7,
215+
"id": "199741d1",
216+
"metadata": {
217+
"colab": {
218+
"base_uri": "https://localhost:8080/"
219+
},
220+
"id": "199741d1",
221+
"outputId": "1d897dcc-40be-4529-ee90-fc1f9483b8e1"
222+
},
223+
"outputs": [
224+
{
225+
"output_type": "stream",
226+
"name": "stdout",
227+
"text": [
228+
"Prediction: 0.9947726130485535\n"
229+
]
230+
}
231+
],
232+
"source": [
233+
"test = torch.tensor(pad(encode(\"i love this\"))).unsqueeze(0)\n",
234+
"print(\"Prediction:\", float(model(test)))"
235+
]
236+
}
237+
],
238+
"metadata": {
239+
"colab": {
240+
"provenance": [],
241+
"include_colab_link": true
242+
},
243+
"language_info": {
244+
"name": "python"
245+
},
246+
"kernelspec": {
247+
"name": "python3",
248+
"display_name": "Python 3"
249+
}
250+
},
251+
"nbformat": 4,
252+
"nbformat_minor": 5
253+
}

0 commit comments

Comments
 (0)