1+ {
2+ "cells" : [
3+ {
4+ "cell_type" : " markdown" ,
5+ "metadata" : {
6+ "id" : " view-in-github" ,
7+ "colab_type" : " text"
8+ },
9+ "source" : [
10+ " <a href=\" https://colab.research.google.com/github/lovnishverma/Python-Getting-Started/blob/main/RNN_text_classifier.ipynb\" target=\" _parent\" ><img src=\" https://colab.research.google.com/assets/colab-badge.svg\" alt=\" Open In Colab\" /></a>"
11+ ]
12+ },
13+ {
14+ "cell_type" : " markdown" ,
15+ "id" : " f9ed309e" ,
16+ "metadata" : {
17+ "id" : " f9ed309e"
18+ },
19+ "source" : [
20+ " # Simple RNN for Text Classification\n " ,
21+ " M.TECH AI - Practical 5"
22+ ]
23+ },
24+ {
25+ "cell_type" : " code" ,
26+ "execution_count" : 1 ,
27+ "id" : " 1264af09" ,
28+ "metadata" : {
29+ "id" : " 1264af09"
30+ },
31+ "outputs" : [],
32+ "source" : [
33+ " import torch\n " ,
34+ " import torch.nn as nn\n " ,
35+ " import torch.optim as optim"
36+ ]
37+ },
38+ {
39+ "cell_type" : " markdown" ,
40+ "id" : " 2ab48979" ,
41+ "metadata" : {
42+ "id" : " 2ab48979"
43+ },
44+ "source" : [
45+ " ## 1. Dataset"
46+ ]
47+ },
48+ {
49+ "cell_type" : " code" ,
50+ "execution_count" : 2 ,
51+ "id" : " cf700915" ,
52+ "metadata" : {
53+ "id" : " cf700915"
54+ },
55+ "outputs" : [],
56+ "source" : [
57+ " texts = [\" i love this\" , \" i hate this\" ]\n " ,
58+ " labels = [1, 0]"
59+ ]
60+ },
61+ {
62+ "cell_type" : " markdown" ,
63+ "id" : " 180a47a2" ,
64+ "metadata" : {
65+ "id" : " 180a47a2"
66+ },
67+ "source" : [
68+ " ## 2. Text to Numbers"
69+ ]
70+ },
71+ {
72+ "cell_type" : " code" ,
73+ "execution_count" : 3 ,
74+ "id" : " 361b7124" ,
75+ "metadata" : {
76+ "id" : " 361b7124"
77+ },
78+ "outputs" : [],
79+ "source" : [
80+ " vocab = {\" i\" :0, \" love\" :1, \" hate\" :2, \" this\" :3}\n " ,
81+ " \n " ,
82+ " def encode(text):\n " ,
83+ " return [vocab[word] for word in text.split()]\n " ,
84+ " \n " ,
85+ " X = [encode(t) for t in texts]\n " ,
86+ " y = torch.tensor(labels, dtype=torch.float32)"
87+ ]
88+ },
89+ {
90+ "cell_type" : " markdown" ,
91+ "id" : " 10c699b9" ,
92+ "metadata" : {
93+ "id" : " 10c699b9"
94+ },
95+ "source" : [
96+ " ## 3. Padding"
97+ ]
98+ },
99+ {
100+ "cell_type" : " code" ,
101+ "execution_count" : 4 ,
102+ "id" : " f54a454b" ,
103+ "metadata" : {
104+ "id" : " f54a454b"
105+ },
106+ "outputs" : [],
107+ "source" : [
108+ " max_len = max(len(seq) for seq in X)\n " ,
109+ " \n " ,
110+ " def pad(seq):\n " ,
111+ " return seq + [0]*(max_len - len(seq))\n " ,
112+ " \n " ,
113+ " X = torch.tensor([pad(seq) for seq in X])"
114+ ]
115+ },
116+ {
117+ "cell_type" : " markdown" ,
118+ "id" : " f9e4acb6" ,
119+ "metadata" : {
120+ "id" : " f9e4acb6"
121+ },
122+ "source" : [
123+ " ## 4. Model"
124+ ]
125+ },
126+ {
127+ "cell_type" : " code" ,
128+ "execution_count" : 5 ,
129+ "id" : " 4f826ca7" ,
130+ "metadata" : {
131+ "id" : " 4f826ca7"
132+ },
133+ "outputs" : [],
134+ "source" : [
135+ " class SimpleRNN(nn.Module):\n " ,
136+ " def __init__(self):\n " ,
137+ " super().__init__()\n " ,
138+ " self.embedding = nn.Embedding(4, 8)\n " ,
139+ " self.rnn = nn.RNN(8, 16, batch_first=True)\n " ,
140+ " self.fc = nn.Linear(16, 1)\n " ,
141+ " \n " ,
142+ " def forward(self, x):\n " ,
143+ " x = self.embedding(x)\n " ,
144+ " out, _ = self.rnn(x)\n " ,
145+ " out = out[:, -1, :]\n " ,
146+ " out = self.fc(out)\n " ,
147+ " return torch.sigmoid(out).squeeze()\n " ,
148+ " \n " ,
149+ " model = SimpleRNN()"
150+ ]
151+ },
152+ {
153+ "cell_type" : " markdown" ,
154+ "id" : " 554d66c4" ,
155+ "metadata" : {
156+ "id" : " 554d66c4"
157+ },
158+ "source" : [
159+ " ## 5. Training"
160+ ]
161+ },
162+ {
163+ "cell_type" : " code" ,
164+ "execution_count" : 6 ,
165+ "id" : " 0014cb5a" ,
166+ "metadata" : {
167+ "colab" : {
168+ "base_uri" : " https://localhost:8080/"
169+ },
170+ "id" : " 0014cb5a" ,
171+ "outputId" : " 9a900acf-9610-44f5-9f5b-5f8d87f9ff0f"
172+ },
173+ "outputs" : [
174+ {
175+ "output_type" : " stream" ,
176+ "name" : " stdout" ,
177+ "text" : [
178+ " Epoch 10, Loss: 0.5526\n " ,
179+ " Epoch 20, Loss: 0.1626\n " ,
180+ " Epoch 30, Loss: 0.0277\n " ,
181+ " Epoch 40, Loss: 0.0096\n " ,
182+ " Epoch 50, Loss: 0.0053\n "
183+ ]
184+ }
185+ ],
186+ "source" : [
187+ " loss_fn = nn.BCELoss()\n " ,
188+ " optimizer = optim.Adam(model.parameters(), lr=0.01)\n " ,
189+ " \n " ,
190+ " for epoch in range(50):\n " ,
191+ " preds = model(X)\n " ,
192+ " loss = loss_fn(preds, y)\n " ,
193+ " \n " ,
194+ " optimizer.zero_grad()\n " ,
195+ " loss.backward()\n " ,
196+ " optimizer.step()\n " ,
197+ " \n " ,
198+ " if (epoch+1) % 10 == 0:\n " ,
199+ " print(f\" Epoch {epoch+1}, Loss: {loss.item():.4f}\" )"
200+ ]
201+ },
202+ {
203+ "cell_type" : " markdown" ,
204+ "id" : " 158862b3" ,
205+ "metadata" : {
206+ "id" : " 158862b3"
207+ },
208+ "source" : [
209+ " ## 6. Testing"
210+ ]
211+ },
212+ {
213+ "cell_type" : " code" ,
214+ "execution_count" : 7 ,
215+ "id" : " 199741d1" ,
216+ "metadata" : {
217+ "colab" : {
218+ "base_uri" : " https://localhost:8080/"
219+ },
220+ "id" : " 199741d1" ,
221+ "outputId" : " 1d897dcc-40be-4529-ee90-fc1f9483b8e1"
222+ },
223+ "outputs" : [
224+ {
225+ "output_type" : " stream" ,
226+ "name" : " stdout" ,
227+ "text" : [
228+ " Prediction: 0.9947726130485535\n "
229+ ]
230+ }
231+ ],
232+ "source" : [
233+ " test = torch.tensor([pad(encode(\" i love this\" ))])\n " ,
234+ " print(\" Prediction:\" , model(test).item())"
235+ ]
236+ }
237+ ],
238+ "metadata" : {
239+ "colab" : {
240+ "provenance" : [],
241+ "include_colab_link" : true
242+ },
243+ "language_info" : {
244+ "name" : " python"
245+ },
246+ "kernelspec" : {
247+ "name" : " python3" ,
248+ "display_name" : " Python 3"
249+ }
250+ },
251+ "nbformat" : 4 ,
252+ "nbformat_minor" : 5
253+ }
0 commit comments