|
248 | 248 | "\n", |
249 | 249 | "make_array(left, right)" |
250 | 250 | ] |
251 | | - }, |
252 | | - { |
253 | | - "cell_type": "markdown", |
254 | | - "metadata": {}, |
255 | | - "source": [ |
256 | | - "# Text Classification" |
257 | | - ] |
258 | | - }, |
259 | | - { |
260 | | - "cell_type": "code", |
261 | | - "execution_count": null, |
262 | | - "metadata": {}, |
263 | | - "outputs": [], |
264 | | - "source": [ |
265 | | - "from datasets import load_dataset\n", |
266 | | - "\n", |
267 | | - "sms = load_dataset('ucirvine/sms_spam', split='train').shuffle(seed=42)\n", |
268 | | - "sms_texts = np.array(sms['sms'])\n", |
269 | | - "sms_labels = np.array(sms['label'])\n", |
270 | | - "\n", |
271 | | - "sms_tbl = Table().with_columns('Text', sms_texts, 'Class', sms_labels)\n", |
272 | | - "sms_tbl.group('Class').show()" |
273 | | - ] |
274 | | - }, |
275 | | - { |
276 | | - "cell_type": "code", |
277 | | - "execution_count": null, |
278 | | - "metadata": {}, |
279 | | - "outputs": [], |
280 | | - "source": [ |
281 | | - "sms_tbl.where('Class', 1).sample(with_replacement=False).show(5)" |
282 | | - ] |
283 | | - }, |
284 | | - { |
285 | | - "cell_type": "code", |
286 | | - "execution_count": null, |
287 | | - "metadata": {}, |
288 | | - "outputs": [], |
289 | | - "source": [ |
290 | | - "sms_tbl.where('Class', 0).sample(with_replacement=False).show(5)" |
291 | | - ] |
292 | | - }, |
293 | | - { |
294 | | - "cell_type": "code", |
295 | | - "execution_count": null, |
296 | | - "metadata": {}, |
297 | | - "outputs": [], |
298 | | - "source": [ |
299 | | - "texts = sms_tbl.column('Text')\n", |
300 | | - "\n", |
301 | | - "sms_data = Table().with_columns(\n", |
302 | | - " 'Chars', np.char.str_len(texts),\n", |
303 | | - " 'Digits', sum(np.char.count(texts, str(d)) for d in range(10)),\n", |
304 | | - " 'Caps', sum(np.char.count(texts, chr(c)) for c in range(65, 91)),\n", |
305 | | - " 'Exclamations', np.char.count(texts, '!'),\n", |
306 | | - " 'Class', sms_tbl.column('Class')\n", |
307 | | - ")\n", |
308 | | - "sms_data" |
309 | | - ] |
310 | | - }, |
311 | | - { |
312 | | - "cell_type": "code", |
313 | | - "execution_count": null, |
314 | | - "metadata": { |
315 | | - "scrolled": true |
316 | | - }, |
317 | | - "outputs": [], |
318 | | - "source": [ |
319 | | - "sms_data.scatter('Digits', 'Caps', group='Class')" |
320 | | - ] |
321 | | - }, |
322 | | - { |
323 | | - "cell_type": "code", |
324 | | - "execution_count": null, |
325 | | - "metadata": {}, |
326 | | - "outputs": [], |
327 | | - "source": [ |
328 | | - "shuffled = sms_data.sample(with_replacement=False)\n", |
329 | | - "test_size = 100\n", |
330 | | - "train_sms = shuffled.take(np.arange(test_size, shuffled.num_rows))\n", |
331 | | - "test_sms = shuffled.take(np.arange(test_size))\n", |
332 | | - "\n", |
333 | | - "print('Training:', train_sms.num_rows, ' Test:', test_sms.num_rows)\n", |
334 | | - "evaluate_accuracy(train_sms, test_sms, 5)" |
335 | | - ] |
336 | | - }, |
337 | | - { |
338 | | - "cell_type": "markdown", |
339 | | - "metadata": {}, |
340 | | - "source": [ |
341 | | - "## Rotten Tomatoes Movie Reviews" |
342 | | - ] |
343 | | - }, |
344 | | - { |
345 | | - "cell_type": "code", |
346 | | - "execution_count": null, |
347 | | - "metadata": {}, |
348 | | - "outputs": [], |
349 | | - "source": [ |
350 | | - "reviews_full = load_dataset('rotten_tomatoes', split='train')\n", |
351 | | - "reviews_short = reviews_full.filter(lambda x: 5 <= len(x['text'].split()) <= 10)\n", |
352 | | - "\n", |
353 | | - "reviews = Table().with_columns('Text', reviews_short['text'],\n", |
354 | | - " 'Class', reviews_short['label'])\n", |
355 | | - "reviews = reviews.sample(with_replacement=False) # Permute the rows\n", |
356 | | - "reviews.group('Class')" |
357 | | - ] |
358 | | - }, |
359 | | - { |
360 | | - "cell_type": "code", |
361 | | - "execution_count": null, |
362 | | - "metadata": {}, |
363 | | - "outputs": [], |
364 | | - "source": [ |
365 | | - "reviews.sample(5)" |
366 | | - ] |
367 | | - }, |
368 | | - { |
369 | | - "cell_type": "code", |
370 | | - "execution_count": null, |
371 | | - "metadata": {}, |
372 | | - "outputs": [], |
373 | | - "source": [ |
374 | | - "words = [ # The most common adjectives in the data\n", |
375 | | - " 'good', 'bad', 'funny', 'little', 'much', 'new', 'best',\n", |
376 | | - " 'many', 'own', 'other', 'big', 'great', 'most', 'few',\n", |
377 | | - " 'real', 'first', 'full', 'american', 'romantic', 'same', 'old',\n", |
378 | | - " 'better', 'young', 'original', 'interesting', 'human',\n", |
379 | | - " 'hard', 'cinematic', 'enough', 'emotional', 'last', 'least', 'long',\n", |
380 | | - " 'true', 'predictable', 'visual', 'whole', 'high', 'special',\n", |
381 | | - " 'entertaining', 'sweet', 'enjoyable', 'narrative', 'familiar'\n", |
382 | | - "]\n", |
383 | | - "counts = Table(['Word', 'Positive', 'Negative'])\n", |
384 | | - "for word in words:\n", |
385 | | - " has_word = reviews.where('Text', are.containing(word))\n", |
386 | | - " counts = counts.with_row([word, has_word.where('Class', 1).num_rows,\n", |
387 | | - " has_word.where('Class', 0).num_rows])\n", |
388 | | - "\n", |
389 | | - "counts" |
390 | | - ] |
391 | | - }, |
392 | | - { |
393 | | - "cell_type": "code", |
394 | | - "execution_count": null, |
395 | | - "metadata": { |
396 | | - "scrolled": true |
397 | | - }, |
398 | | - "outputs": [], |
399 | | - "source": [ |
400 | | - "reviews.where('Text', are.containing('funny')).where('Class', 0).sample(5, with_replacement=False)" |
401 | | - ] |
402 | | - }, |
403 | | - { |
404 | | - "cell_type": "code", |
405 | | - "execution_count": null, |
406 | | - "metadata": {}, |
407 | | - "outputs": [], |
408 | | - "source": [ |
409 | | - "texts = reviews.column('Text')\n", |
410 | | - "review_words = Table().with_column('Class', reviews.column('Class'))\n", |
411 | | - "for word in words:\n", |
412 | | - " review_words = review_words.with_column(word, np.char.count(np.char.lower(texts), word))\n", |
413 | | - "\n", |
414 | | - "review_words.sample(5)" |
415 | | - ] |
416 | | - }, |
417 | | - { |
418 | | - "cell_type": "code", |
419 | | - "execution_count": null, |
420 | | - "metadata": {}, |
421 | | - "outputs": [], |
422 | | - "source": [ |
423 | | - "train_reviews = review_words.take(np.arange(test_size, reviews.num_rows))\n", |
424 | | - "test_reviews = review_words.take(np.arange(test_size))\n", |
425 | | - "\n", |
426 | | - "print('Word-count KNN:')\n", |
427 | | - "evaluate_accuracy(train_reviews, test_reviews, 5)" |
428 | | - ] |
429 | | - }, |
430 | | - { |
431 | | - "cell_type": "code", |
432 | | - "execution_count": null, |
433 | | - "metadata": {}, |
434 | | - "outputs": [], |
435 | | - "source": [ |
436 | | - "classify_all(train_reviews, test_reviews, 5).pivot('Prediction', 'Class')" |
437 | | - ] |
438 | | - }, |
439 | | - { |
440 | | - "cell_type": "markdown", |
441 | | - "metadata": {}, |
442 | | - "source": [ |
443 | | - "## Sentence Embeddings" |
444 | | - ] |
445 | | - }, |
446 | | - { |
447 | | - "cell_type": "code", |
448 | | - "execution_count": null, |
449 | | - "metadata": {}, |
450 | | - "outputs": [], |
451 | | - "source": [ |
452 | | - "from sentence_transformers import SentenceTransformer\n", |
453 | | - "\n", |
454 | | - "embedder = SentenceTransformer('all-MiniLM-L6-v2')\n", |
455 | | - "review_emb = embedder.encode(list(reviews.column('Text')), show_progress_bar=True)\n", |
456 | | - "print('Embedding shape:', review_emb.shape)" |
457 | | - ] |
458 | | - }, |
459 | | - { |
460 | | - "cell_type": "code", |
461 | | - "execution_count": null, |
462 | | - "metadata": {}, |
463 | | - "outputs": [], |
464 | | - "source": [ |
465 | | - "n_features = 64 # Increasing this will help, but above 128 datahub will crash\n", |
466 | | - "\n", |
467 | | - "cols = ['Class', reviews.column('Class')]\n", |
468 | | - "for i in range(n_features):\n", |
469 | | - " cols += [f'Embed{i}', review_emb[:, i]]\n", |
470 | | - "\n", |
471 | | - "review_emb_table = Table().with_columns(*cols)\n", |
472 | | - "review_emb_table.row(0)" |
473 | | - ] |
474 | | - }, |
475 | | - { |
476 | | - "cell_type": "code", |
477 | | - "execution_count": null, |
478 | | - "metadata": {}, |
479 | | - "outputs": [], |
480 | | - "source": [ |
481 | | - "train = review_emb_table.take(np.arange(test_size, reviews.num_rows))\n", |
482 | | - "test = review_emb_table.take(np.arange(test_size))\n", |
483 | | - "evaluate_accuracy(train, test, 5)" |
484 | | - ] |
485 | | - }, |
486 | | - { |
487 | | - "cell_type": "code", |
488 | | - "execution_count": null, |
489 | | - "metadata": {}, |
490 | | - "outputs": [], |
491 | | - "source": [ |
492 | | - "classify_all(train, test, 5).pivot('Prediction', 'Class')" |
493 | | - ] |
494 | 251 | } |
495 | 252 | ], |
496 | 253 | "metadata": { |
|
0 commit comments