Skip to content

Commit 7cb79f8

Browse files
committed
support for AMD ROCm
1 parent 2bc2212 commit 7cb79f8

3 files changed

Lines changed: 877 additions & 322 deletions

File tree

docs/installation.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ Installation
55
Using Pip
66
=========
77

8-
Simply use ``pip install renard-pipeline``.
8+
For the simplest case, use ``pip install renard-pipeline``. By default, this installs the CPU version of PyTorch. If you want GPU support to accelerate inference:
9+
10+
- CUDA 12.8: ``pip install renard-pipeline[cuda128]``
11+
- ROCm 6.3: ``pip install renard-pipeline[rocm63]``
12+
913

1014
Note that for some modules, you might need to install additional
1115
libraries:

pyproject.toml

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ readme = "README.md"
1010
requires-python = ">=3.9,<3.13"
1111
dependencies = [
1212
"torch>=2.7.0",
13-
"transformers>=4.53.2",
13+
"transformers>=4.56.1",
1414
"nltk>=3.9.1",
1515
"tqdm>=4.67.1",
1616
"networkx>=3.2",
@@ -19,9 +19,10 @@ dependencies = [
1919
"matplotlib>=3.9",
2020
"pytest>=8.4.1",
2121
"tibert>=0.5.2",
22-
"grimbert>=0.1.4",
22+
"grimbert>=0.1.5",
2323
"datasets>=4.0.0",
2424
"rank-bm25>=0.2.2",
25+
"accelerate>=1.10.1"
2526
]
2627

2728
[build-system]
@@ -36,16 +37,51 @@ Homepage = "https://github.com/CompNet/Renard"
3637
Documentation = "https://compnet.github.io/Renard/"
3738
Repository = "https://github.com/CompNet/Renard"
3839

39-
[project.optional-dependencies]
40-
ui = [
41-
"gradio>=4.44.1",
42-
"pyvis>=0.3.2",
43-
]
44-
4540
[dependency-groups]
4641
dev = [
4742
"hypothesis>=6.82",
4843
"Sphinx>=4.3",
4944
"sphinx-rtd-theme>=1.0.0",
5045
"sphinx-autodoc-typehints>=1.12.0",
5146
]
47+
48+
[project.optional-dependencies]
49+
ui = [
50+
"gradio>=4.44.1",
51+
"pyvis>=0.3.2",
52+
]
53+
# torch alternatives (default is cpu torch)
54+
cuda128 = [
55+
"torch>=2.7.1",
56+
]
57+
rocm63 = [
58+
"torch>=2.7.1",
59+
"pytorch-triton-rocm>=3.1.0",
60+
]
61+
62+
[tool.uv]
63+
conflicts = [
64+
[
65+
{ extra = "cuda128" },
66+
{ extra = "rocm63" },
67+
],
68+
]
69+
70+
[tool.uv.sources]
71+
torch = [
72+
{ index = "pytorch-cuda128", extra = "cuda128" },
73+
{ index = "pytorch-rocm63", extra = "rocm63" },
74+
]
75+
pytorch-triton-rocm = [
76+
{ index = "pytorch-rocm63", extra = "rocm63" }
77+
]
78+
79+
[[tool.uv.index]]
80+
name = "pytorch-cuda128"
81+
url = "https://download.pytorch.org/whl/cu128"
82+
explicit = true
83+
84+
[[tool.uv.index]]
85+
name = "pytorch-rocm63"
86+
url = "https://download.pytorch.org/whl/rocm6.3"
87+
explicit = true

0 commit comments

Comments
 (0)