Skip to content

Commit b0241ce

Browse files
refactor: initial commit
0 parents  commit b0241ce

8 files changed

Lines changed: 254 additions & 0 deletions

File tree

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# This workflow will upload a Python Package using Twine when a release is created
2+
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3+
4+
# This workflow uses actions that are not certified by GitHub.
5+
# They are provided by a third-party and are governed by
6+
# separate terms of service, privacy policy, and support
7+
# documentation.
8+
9+
name: Upload Python Package
10+
11+
on:
12+
release:
13+
types: [published]
14+
15+
permissions:
16+
contents: read
17+
18+
jobs:
19+
deploy:
20+
21+
runs-on: ubuntu-latest
22+
23+
steps:
24+
- uses: actions/checkout@v3
25+
- name: Set up Python
26+
uses: actions/setup-python@v3
27+
with:
28+
python-version: '3.x'
29+
- name: Install dependencies
30+
run: |
31+
python -m pip install --upgrade pip
32+
pip install build
33+
- name: Build package
34+
run: python -m build
35+
- name: Publish package
36+
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37+
with:
38+
user: __token__
39+
password: ${{ secrets.PYPI_API_TOKEN }}

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
__pycache__
2+
.mypy_cache

.pre-commit-config.yaml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v2.3.0
4+
hooks:
5+
- id: end-of-file-fixer
6+
- id: trailing-whitespace
7+
8+
# Formats code correctly
9+
- repo: https://github.com/psf/black
10+
rev: 22.3.0
11+
hooks:
12+
- id: black
13+
args: [
14+
'--experimental-string-processing'
15+
]
16+
17+
# Sorts imports
18+
- repo: https://github.com/pycqa/isort
19+
rev: 5.10.1
20+
hooks:
21+
- id: isort
22+
name: isort (python)
23+
args: ["--profile", "black"]
24+
25+
# Checks unused imports, like lengths, etc
26+
- repo: https://gitlab.com/pycqa/flake8
27+
rev: 4.0.0
28+
hooks:
29+
- id: flake8
30+
args: [
31+
'--per-file-ignores=__init__.py:F401',
32+
'--max-line-length=88',
33+
'--ignore=E1,W1,E2,W2,E4,W4,E5,W5' # Handled by black
34+
]
35+
36+
# Checks types
37+
- repo: https://github.com/pre-commit/mirrors-mypy
38+
rev: 'v0.971'
39+
hooks:
40+
- id: mypy
41+
additional_dependencies: [data-science-types>=0.2, torch>=1.6]

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 archinet.ai
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
# Bitcodes - PyTorch
3+
4+
A new vector quantization method with binary codes, in PyTorch.
5+
6+
```bash
7+
pip install bitcodes-pytorch
8+
```
9+
[![PyPI - Python Version](https://img.shields.io/pypi/v/bitcodes-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/bitcodes-pytorch/)
10+
11+
12+
## Usage
13+
14+
### Quantize
15+
```python
16+
from bitcodes_pytorch import Bitcodes
17+
18+
bitcodes = Bitcodes(
19+
features=8,
20+
num_bits=4,
21+
temperature=10,
22+
)
23+
24+
# Set to eval during inference to make deterministic
25+
bitcodes.eval()
26+
27+
x = torch.randn(1, 6, 8)
28+
# Computes y, the quantzed version of x, and the bitcodes
29+
y, bits = bitcodes(x)
30+
31+
"""
32+
y.shape = torch.Size([1, 6, 8])
33+
34+
bits = tensor([[
35+
[0, 0, 0, 0],
36+
[1, 0, 1, 1],
37+
[1, 0, 0, 1],
38+
[1, 0, 0, 0],
39+
[0, 1, 1, 1],
40+
[0, 0, 1, 0]
41+
]])
42+
"""
43+
```
44+
45+
### Recover Output from Bits
46+
```python
47+
y_decoded = bitcodes.from_bits(bits)
48+
49+
assert torch.allclose(y, y_decoded) # Assert passes in eval mode!
50+
```
51+
52+
### Decimal-Binary Conversion
53+
```python
54+
from bitcodes_pytorch import to_decimal, to_binary
55+
56+
indices = to_decimal(bits)
57+
# tensor([[ 0, 11, 9, 8, 7, 2]])
58+
59+
bits = to_binary(indices, num_bits=4)
60+
61+
"""
62+
bits = tensor([[
63+
[0, 0, 0, 0],
64+
[1, 0, 1, 1],
65+
[1, 0, 0, 1],
66+
[1, 0, 0, 0],
67+
[0, 1, 1, 1],
68+
[0, 0, 1, 0]
69+
]])
70+
"""
71+
```
72+
73+
## Explaination
74+
75+
TODO

bitcodes_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .bitcodes import Bitcodes, to_bits, to_decimal

bitcodes_pytorch/bitcodes.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Tuple
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from einops import rearrange
6+
from torch import Tensor, einsum, nn
7+
8+
""" Utils """
9+
10+
11+
def to_bits(indices: Tensor, num_bits: int) -> Tensor:
12+
bitmask = 2 ** torch.arange(num_bits - 1, -1, -1)
13+
return indices.unsqueeze(-1).bitwise_and(bitmask).ne(0).long()
14+
15+
16+
def to_decimal(bits: Tensor) -> Tensor:
17+
num_bits = bits.shape[-1]
18+
bitmask = 2 ** torch.arange(num_bits - 1, -1, -1)
19+
return torch.sum(bitmask * bits, dim=-1)
20+
21+
22+
""" Bincodes """
23+
24+
25+
class Bitcodes(nn.Module):
26+
def __init__(self, features: int, num_bits: int, temperature: int):
27+
super().__init__()
28+
self.temperature = temperature
29+
self.codebook = nn.Parameter(torch.randn(2 * num_bits, features))
30+
31+
def from_bits(self, bits: Tensor) -> Tensor:
32+
attn = F.one_hot(bits.long(), num_classes=2).float()
33+
attn = rearrange(attn, "b m p q -> b m (p q)")
34+
out = einsum("b m n, n d -> b m d", attn, self.codebook)
35+
return out
36+
37+
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
38+
sim = einsum("b m d, n d -> b m n", x, self.codebook)
39+
pairs = rearrange(sim, "b m (p q) -> b m p q", q=2)
40+
41+
if self.training:
42+
attn = F.gumbel_softmax(pairs, tau=self.temperature, dim=-1, hard=True)
43+
else:
44+
attn = F.one_hot(pairs.argmax(dim=-1), num_classes=2).float()
45+
46+
attn = rearrange(attn, "b m p q -> b m (p q)")
47+
out = einsum("b m n, n d -> b m d", attn, self.codebook)
48+
bits = pairs.argmax(dim=-1)
49+
return out, bits

setup.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from setuptools import find_packages, setup
2+
3+
setup(
4+
name="bitcodes-pytorch",
5+
packages=find_packages(exclude=[]),
6+
version="0.0.1",
7+
license="MIT",
8+
description="Bitcodes - Pytorch",
9+
long_description_content_type="text/markdown",
10+
author="Flavio Schneider",
11+
author_email="archinetai@protonmail.com",
12+
url="https://github.com/archinetai/bitcodes-pytorch",
13+
keywords=["artificial intelligence", "deep learning"],
14+
install_requires=[
15+
"torch>=1.6",
16+
"data-science-types>=0.2",
17+
"einops>=0.4",
18+
],
19+
classifiers=[
20+
"Development Status :: 4 - Beta",
21+
"Intended Audience :: Developers",
22+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
23+
"License :: OSI Approved :: MIT License",
24+
"Programming Language :: Python :: 3.6",
25+
],
26+
)

0 commit comments

Comments
 (0)