|
557 | 557 | " n += 1\n", |
558 | 558 | "\n", |
559 | 559 | "print(f'{log_likelihood=}')\n", |
| 560 | + "\n", |
560 | 561 | "nll = -log_likelihood\n", |
| 562 | + "\n", |
561 | 563 | "print(f'{nll=}') # Negative log likelihood\n", |
562 | 564 | "print(f'Average NLL: {nll/n:.4f}') # More descriptive output" |
563 | 565 | ] |
|
694 | 696 | "id": "47", |
695 | 697 | "metadata": {}, |
696 | 698 | "outputs": [], |
| 699 | + "source": [ |
| 700 | + "import torch" |
| 701 | + ] |
| 702 | + }, |
| 703 | + { |
| 704 | + "cell_type": "code", |
| 705 | + "execution_count": null, |
| 706 | + "id": "48", |
| 707 | + "metadata": {}, |
| 708 | + "outputs": [], |
697 | 709 | "source": [ |
698 | 710 | "# Create training set of all bigrams\n", |
699 | 711 | "xs, ys = [], [] # Input and output character indices\n", |
|
711 | 723 | "ys = torch.tensor(ys)" |
712 | 724 | ] |
713 | 725 | }, |
| 726 | + { |
| 727 | + "cell_type": "markdown", |
| 728 | + "id": "49", |
| 729 | + "metadata": {}, |
| 730 | + "source": [ |
| 731 | + "We can inspect the first few tokens from the two tensors:\n", |
| 732 | + "- `xs` will be the input tokens (what the model will see)\n", |
| 733 | + "- `ys` will be the target tokens (what we want the model to predict)" |
| 734 | + ] |
| 735 | + }, |
714 | 736 | { |
715 | 737 | "cell_type": "code", |
716 | 738 | "execution_count": null, |
717 | | - "id": "48", |
| 739 | + "id": "50", |
718 | 740 | "metadata": {}, |
719 | 741 | "outputs": [], |
720 | 742 | "source": [ |
721 | 743 | "for i in range(5):\n", |
722 | 744 | " print(f'For character #{i} \"{itos[xs[i].item()]}\" in xs, we expect the model to predict \"{itos[ys[i].item()]}\"')" |
723 | 745 | ] |
| 746 | + }, |
| 747 | + { |
| 748 | + "cell_type": "markdown", |
| 749 | + "id": "51", |
| 750 | + "metadata": {}, |
| 751 | + "source": [ |
| 752 | + "One important detail about PyTorch tensors: there are `torch.Tensor` (a class) and `torch.tensor()` (a function).\n", |
| 753 | + "They are related, but different:\n", |
| 754 | + "- `torch.Tensor` is the tensor class; every PyTorch tensor is an instance of it.\n", |
| 755 | + "- `torch.tensor()` is a factory function that creates a tensor, with its `dtype` inferred from the input data.\n", |
| 756 | + "\n", |
| 757 | + "Note that calling `torch.Tensor(...)` as a constructor is an alias for `torch.FloatTensor(...)`, so it always produces a tensor with a `dtype` of `torch.float32`, regardless of the input.\n", |
| 758 | + "\n", |
| 759 | + "In practice, you should almost always use `torch.tensor()`; there is rarely a reason to construct a `torch.Tensor` directly, except perhaps to allocate an uninitialized tensor.\n", |
| 761 | + "\n", |
| 762 | + "Here you can find more details from the official docs:\n", |
| 763 | + "- [`torch.tensor()`](https://docs.pytorch.org/docs/stable/generated/torch.tensor.html#torch-tensor)\n", |
| 764 | + "- [`torch.Tensor`](https://docs.pytorch.org/docs/stable/tensors.html)" |
| 765 | + ] |
| 766 | + }, |
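A minimal sketch of the difference described above (standalone, not part of the notebook's cells): `torch.tensor()` infers its `dtype` from the data, while the `torch.Tensor` constructor always yields `float32`.

```python
import torch

# torch.tensor() infers dtype from the data: Python ints give int64
a = torch.tensor([1, 2, 3])
print(a.dtype)  # torch.int64

# torch.Tensor(...) is an alias for torch.FloatTensor(...): always float32
b = torch.Tensor([1, 2, 3])
print(b.dtype)  # torch.float32
```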
| 767 | + { |
| 768 | + "cell_type": "code", |
| 769 | + "execution_count": null, |
| 770 | + "id": "52", |
| 771 | + "metadata": {}, |
| 772 | + "outputs": [], |
| 773 | + "source": [ |
| 774 | + "xs.dtype, xs.type()" |
| 775 | + ] |
| 776 | + }, |
| 777 | + { |
| 778 | + "cell_type": "markdown", |
| 779 | + "id": "53", |
| 780 | + "metadata": {}, |
| 781 | + "source": [ |
| 782 | + "`torch.LongTensor` here means that the tensor contains 64-bit (8-byte) integers." |
| 783 | + ] |
| 784 | + }, |
| 785 | + { |
| 786 | + "cell_type": "markdown", |
| 787 | + "id": "54", |
| 788 | + "metadata": {}, |
| 789 | + "source": [ |
| 790 | + "### Feeding the network" |
| 791 | + ] |
| 792 | + }, |
| 793 | + { |
| 794 | + "cell_type": "markdown", |
| 795 | + "id": "55", |
| 796 | + "metadata": {}, |
| 797 | + "source": [ |
| 798 | + "A neural network is essentially made up of **layers**.\n", |
| 799 | + "At a high level, each layer is typically a linear transformation followed by a non-linear activation function.\n", |
| 800 | + "Something like: $\\text{output} = \\text{activation}(\\text{weights} \\cdot \\text{input} + \\text{bias})$\n", |
| 801 | + "\n", |
| 802 | + "If we fed our characters to the network as raw integer indexes, the input would be a sequence of integers.\n", |
| 803 | + "Since `a` is 1 and `z` is 26, the same weight applied to `z` would contribute 26 times more to the output than it would for `a`. This creates an arbitrary and misleading mathematical relationship between characters.\n", |
| 804 | + "\n", |
| 805 | + "Moreover, during the training (optimization) phase, the updates to the weights are proportional to their input values.\n", |
| 806 | + "Larger input values cause larger updates to the weights, which can lead to unstable training.\n", |
| 807 | + "\n", |
| 808 | + "Lastly, the network has no notion of the valid value range: it doesn't know that the values are constrained to a specific set (0-26 for our 27 characters).\n", |
| 809 | + "\n", |
| 810 | + "To address all these issues, we can use **one-hot encoding**.\n", |
| 811 | + "One-hot encoding each letter means creating a vector where only one position has a value of 1 (corresponding to that letter's position in the alphabet) and all other positions are 0.\n", |
| 812 | + "This gives the neural network a much clearer signal about which letter is present without introducing misleading numerical relationships.\n", |
| 813 | + "This approach is particularly important in language models where the relationships between symbols (letters, words) are learned from their contexts and co-occurrences, not from arbitrary numeric values assigned to them." |
| 814 | + ] |
| 815 | + }, |
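To make the idea concrete, here is a minimal sketch of what one-hot encoding does by hand, assuming the 27-symbol alphabet used in this notebook (index 0 for the `.` separator, 1-26 for `a`-`z`); `one_hot_manual` is a hypothetical helper, not part of the notebook:

```python
import torch

def one_hot_manual(index: int, num_classes: int) -> torch.Tensor:
    # All zeros, except a single 1 at the position of the token
    v = torch.zeros(num_classes)
    v[index] = 1.0
    return v

enc = one_hot_manual(1, 27)  # assuming 'a' maps to index 1
print(enc[:5])  # tensor([0., 1., 0., 0., 0.])
```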
| 816 | + { |
| 817 | + "cell_type": "markdown", |
| 818 | + "id": "56", |
| 819 | + "metadata": {}, |
| 820 | + "source": [ |
| 821 | + "Luckily, PyTorch provides a convenient way to perform one-hot encoding using the `torch.nn.functional.one_hot` function." |
| 822 | + ] |
| 823 | + }, |
| 824 | + { |
| 825 | + "cell_type": "markdown", |
| 826 | + "id": "57", |
| 827 | + "metadata": {}, |
| 828 | + "source": [ |
| 829 | + "Let's encode the first 5 tokens, then print and visualize the result:" |
| 830 | + ] |
| 831 | + }, |
| 832 | + { |
| 833 | + "cell_type": "code", |
| 834 | + "execution_count": null, |
| 835 | + "id": "58", |
| 836 | + "metadata": {}, |
| 837 | + "outputs": [], |
| 838 | + "source": [ |
| 839 | + "import torch.nn.functional as F\n", |
| 840 | + "\n", |
| 841 | + "F.one_hot(xs[:5], num_classes=27)" |
| 842 | + ] |
| 843 | + }, |
| 844 | + { |
| 845 | + "cell_type": "code", |
| 846 | + "execution_count": null, |
| 847 | + "id": "59", |
| 848 | + "metadata": {}, |
| 849 | + "outputs": [], |
| 850 | + "source": [ |
| 851 | + "xenc = F.one_hot(xs[:5], num_classes=27)\n", |
| 852 | + "xenc.shape" |
| 853 | + ] |
| 854 | + }, |
| 855 | + { |
| 856 | + "cell_type": "code", |
| 857 | + "execution_count": null, |
| 858 | + "id": "60", |
| 859 | + "metadata": {}, |
| 860 | + "outputs": [], |
| 861 | + "source": [ |
| 862 | + "%matplotlib inline\n", |
| 863 | + "import matplotlib.pyplot as plt\n", |
| 864 | + "\n", |
| 865 | + "plt.imshow(xenc)" |
| 866 | + ] |
| 867 | + }, |
| 868 | + { |
| 869 | + "cell_type": "markdown", |
| 870 | + "id": "61", |
| 871 | + "metadata": {}, |
| 872 | + "source": [ |
| 873 | + "One problem is the `dtype` of the one-hot encoded tensor.\n", |
| 874 | + "It's `torch.int64` (`F.one_hot` always returns an integer tensor), but we need `torch.float32` so that the input is suitable for the floating-point operations the network will perform.\n", |
| 875 | + "\n", |
| 876 | + "We can convert it using `.float()`:" |
| 877 | + ] |
| 878 | + }, |
| 879 | + { |
| 880 | + "cell_type": "code", |
| 881 | + "execution_count": null, |
| 882 | + "id": "62", |
| 883 | + "metadata": {}, |
| 884 | + "outputs": [], |
| 885 | + "source": [ |
| 886 | + "xenc = F.one_hot(xs, num_classes=27).float()\n", |
| 887 | + "xenc.dtype" |
| 888 | + ] |
724 | 889 | } |
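As a quick sanity check of why the float one-hot rows are useful (a sketch with a hypothetical 27×3 weight matrix `W`, not from the notebook): multiplying a one-hot row by a weight matrix simply selects the corresponding row of that matrix, which is the kind of operation a linear layer performs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(27, 3)  # hypothetical weight matrix for illustration

# Float one-hot inputs: (2, 27) @ (27, 3) -> (2, 3)
xenc = F.one_hot(torch.tensor([1, 5]), num_classes=27).float()
out = xenc @ W

# Each output row is exactly the row of W picked out by the one-hot index
print(torch.allclose(out[0], W[1]), torch.allclose(out[1], W[5]))  # True True
```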
725 | 890 | ], |
726 | 891 | "metadata": { |
727 | 892 | "kernelspec": { |
728 | 893 | "display_name": "Python 3 (ipykernel)", |
729 | 894 | "language": "python", |
730 | 895 | "name": "python3" |
| 896 | + }, |
| 897 | + "language_info": { |
| 898 | + "codemirror_mode": { |
| 899 | + "name": "ipython", |
| 900 | + "version": 3 |
| 901 | + }, |
| 902 | + "file_extension": ".py", |
| 903 | + "mimetype": "text/x-python", |
| 904 | + "name": "python", |
| 905 | + "nbconvert_exporter": "python", |
| 906 | + "pygments_lexer": "ipython3", |
| 907 | + "version": "3.12.10" |
731 | 908 | } |
732 | 909 | }, |
733 | 910 | "nbformat": 4, |
|