Skip to content

Commit 94c2e81

Browse files
committed
shorten lec 35
1 parent ebab104 commit 94c2e81

1 file changed

Lines changed: 0 additions & 243 deletions

File tree

lec/lec35/lec35.ipynb

Lines changed: 0 additions & 243 deletions
Original file line numberDiff line numberDiff line change
@@ -248,249 +248,6 @@
248248
"\n",
249249
"make_array(left, right)"
250250
]
251-
},
252-
{
253-
"cell_type": "markdown",
254-
"metadata": {},
255-
"source": [
256-
"# Text Classification"
257-
]
258-
},
259-
{
260-
"cell_type": "code",
261-
"execution_count": null,
262-
"metadata": {},
263-
"outputs": [],
264-
"source": [
265-
"from datasets import load_dataset\n",
266-
"\n",
267-
"sms = load_dataset('ucirvine/sms_spam', split='train').shuffle(seed=42)\n",
268-
"sms_texts = np.array(sms['sms'])\n",
269-
"sms_labels = np.array(sms['label'])\n",
270-
"\n",
271-
"sms_tbl = Table().with_columns('Text', sms_texts, 'Class', sms_labels)\n",
272-
"sms_tbl.group('Class').show()"
273-
]
274-
},
275-
{
276-
"cell_type": "code",
277-
"execution_count": null,
278-
"metadata": {},
279-
"outputs": [],
280-
"source": [
281-
"sms_tbl.where('Class', 1).sample(with_replacement=False).show(5)"
282-
]
283-
},
284-
{
285-
"cell_type": "code",
286-
"execution_count": null,
287-
"metadata": {},
288-
"outputs": [],
289-
"source": [
290-
"sms_tbl.where('Class', 0).sample(with_replacement=False).show(5)"
291-
]
292-
},
293-
{
294-
"cell_type": "code",
295-
"execution_count": null,
296-
"metadata": {},
297-
"outputs": [],
298-
"source": [
299-
"texts = sms_tbl.column('Text')\n",
300-
"\n",
301-
"sms_data = Table().with_columns(\n",
302-
" 'Chars', np.char.str_len(texts),\n",
303-
" 'Digits', sum(np.char.count(texts, str(d)) for d in range(10)),\n",
304-
" 'Caps', sum(np.char.count(texts, chr(c)) for c in range(65, 91)),\n",
305-
" 'Exclamations', np.char.count(texts, '!'),\n",
306-
" 'Class', sms_tbl.column('Class')\n",
307-
")\n",
308-
"sms_data"
309-
]
310-
},
311-
{
312-
"cell_type": "code",
313-
"execution_count": null,
314-
"metadata": {
315-
"scrolled": true
316-
},
317-
"outputs": [],
318-
"source": [
319-
"sms_data.scatter('Digits', 'Caps', group='Class')"
320-
]
321-
},
322-
{
323-
"cell_type": "code",
324-
"execution_count": null,
325-
"metadata": {},
326-
"outputs": [],
327-
"source": [
328-
"shuffled = sms_data.sample(with_replacement=False)\n",
329-
"test_size = 100\n",
330-
"train_sms = shuffled.take(np.arange(test_size, shuffled.num_rows))\n",
331-
"test_sms = shuffled.take(np.arange(test_size))\n",
332-
"\n",
333-
"print('Training:', train_sms.num_rows, ' Test:', test_sms.num_rows)\n",
334-
"evaluate_accuracy(train_sms, test_sms, 5)"
335-
]
336-
},
337-
{
338-
"cell_type": "markdown",
339-
"metadata": {},
340-
"source": [
341-
"## Rotten Tomatoes Movie Reviews"
342-
]
343-
},
344-
{
345-
"cell_type": "code",
346-
"execution_count": null,
347-
"metadata": {},
348-
"outputs": [],
349-
"source": [
350-
"reviews_full = load_dataset('rotten_tomatoes', split='train')\n",
351-
"reviews_short = reviews_full.filter(lambda x: 5 <= len(x['text'].split()) <= 10)\n",
352-
"\n",
353-
"reviews = Table().with_columns('Text', reviews_short['text'],\n",
354-
" 'Class', reviews_short['label'])\n",
355-
"reviews = reviews.sample(with_replacement=False) # Permute the rows\n",
356-
"reviews.group('Class')"
357-
]
358-
},
359-
{
360-
"cell_type": "code",
361-
"execution_count": null,
362-
"metadata": {},
363-
"outputs": [],
364-
"source": [
365-
"reviews.sample(5)"
366-
]
367-
},
368-
{
369-
"cell_type": "code",
370-
"execution_count": null,
371-
"metadata": {},
372-
"outputs": [],
373-
"source": [
374-
"words = [ # The most common adjectives in the data\n",
375-
" 'good', 'bad', 'funny', 'little', 'much', 'new', 'best',\n",
376-
" 'many', 'own', 'other', 'big', 'great', 'most', 'few',\n",
377-
" 'real', 'first', 'full', 'american', 'romantic', 'same', 'old',\n",
378-
" 'better', 'young', 'original', 'interesting', 'human',\n",
379-
" 'hard', 'cinematic', 'enough', 'emotional', 'last', 'least', 'long',\n",
380-
" 'true', 'predictable', 'visual', 'whole', 'high', 'special',\n",
381-
" 'entertaining', 'sweet', 'enjoyable', 'narrative', 'familiar'\n",
382-
"]\n",
383-
"counts = Table(['Word', 'Positive', 'Negative'])\n",
384-
"for word in words:\n",
385-
" has_word = reviews.where('Text', are.containing(word))\n",
386-
" counts = counts.with_row([word, has_word.where('Class', 1).num_rows,\n",
387-
" has_word.where('Class', 0).num_rows])\n",
388-
"\n",
389-
"counts"
390-
]
391-
},
392-
{
393-
"cell_type": "code",
394-
"execution_count": null,
395-
"metadata": {
396-
"scrolled": true
397-
},
398-
"outputs": [],
399-
"source": [
400-
"reviews.where('Text', are.containing('funny')).where('Class', 0).sample(5, with_replacement=False)"
401-
]
402-
},
403-
{
404-
"cell_type": "code",
405-
"execution_count": null,
406-
"metadata": {},
407-
"outputs": [],
408-
"source": [
409-
"texts = reviews.column('Text')\n",
410-
"review_words = Table().with_column('Class', reviews.column('Class'))\n",
411-
"for word in words:\n",
412-
" review_words = review_words.with_column(word, np.char.count(np.char.lower(texts), word))\n",
413-
"\n",
414-
"review_words.sample(5)"
415-
]
416-
},
417-
{
418-
"cell_type": "code",
419-
"execution_count": null,
420-
"metadata": {},
421-
"outputs": [],
422-
"source": [
423-
"train_reviews = review_words.take(np.arange(test_size, reviews.num_rows))\n",
424-
"test_reviews = review_words.take(np.arange(test_size))\n",
425-
"\n",
426-
"print('Word-count KNN:')\n",
427-
"evaluate_accuracy(train_reviews, test_reviews, 5)"
428-
]
429-
},
430-
{
431-
"cell_type": "code",
432-
"execution_count": null,
433-
"metadata": {},
434-
"outputs": [],
435-
"source": [
436-
"classify_all(train_reviews, test_reviews, 5).pivot('Prediction', 'Class')"
437-
]
438-
},
439-
{
440-
"cell_type": "markdown",
441-
"metadata": {},
442-
"source": [
443-
"## Sentence Embeddings"
444-
]
445-
},
446-
{
447-
"cell_type": "code",
448-
"execution_count": null,
449-
"metadata": {},
450-
"outputs": [],
451-
"source": [
452-
"from sentence_transformers import SentenceTransformer\n",
453-
"\n",
454-
"embedder = SentenceTransformer('all-MiniLM-L6-v2')\n",
455-
"review_emb = embedder.encode(list(reviews.column('Text')), show_progress_bar=True)\n",
456-
"print('Embedding shape:', review_emb.shape)"
457-
]
458-
},
459-
{
460-
"cell_type": "code",
461-
"execution_count": null,
462-
"metadata": {},
463-
"outputs": [],
464-
"source": [
465-
"n_features = 64 # Increasing this will help, but above 128 datahub will crash\n",
466-
"\n",
467-
"cols = ['Class', reviews.column('Class')]\n",
468-
"for i in range(n_features):\n",
469-
" cols += [f'Embed{i}', review_emb[:, i]]\n",
470-
"\n",
471-
"review_emb_table = Table().with_columns(*cols)\n",
472-
"review_emb_table.row(0)"
473-
]
474-
},
475-
{
476-
"cell_type": "code",
477-
"execution_count": null,
478-
"metadata": {},
479-
"outputs": [],
480-
"source": [
481-
"train = review_emb_table.take(np.arange(test_size, reviews.num_rows))\n",
482-
"test = review_emb_table.take(np.arange(test_size))\n",
483-
"evaluate_accuracy(train, test, 5)"
484-
]
485-
},
486-
{
487-
"cell_type": "code",
488-
"execution_count": null,
489-
"metadata": {},
490-
"outputs": [],
491-
"source": [
492-
"classify_all(train, test, 5).pivot('Prediction', 'Class')"
493-
]
494251
}
495252
],
496253
"metadata": {

0 commit comments

Comments
 (0)