Skip to content

Commit 5359200

Browse files
authored
Merge branch 'main' into renovate/actions-setup-python-6.x
2 parents 441e0ae + 62e52b7 commit 5359200

6 files changed

Lines changed: 17 additions & 29 deletions

File tree

.github/workflows/deploy-docs.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ jobs:
2727

2828
steps:
2929
- name: Checkout repository
30-
uses: actions/checkout@v4
30+
uses: actions/checkout@v6
3131

3232
- name: Set up Python
3333
uses: actions/setup-python@v6
3434
with:
3535
python-version: '3.11'
3636

3737
- name: Install uv
38-
uses: astral-sh/setup-uv@v4
38+
uses: astral-sh/setup-uv@v7
3939

4040
- name: Install dependencies
4141
run: |
@@ -51,7 +51,7 @@ jobs:
5151
uv run sphinx-build -b html docs/source build/html
5252
5353
- name: Upload artifact
54-
uses: actions/upload-pages-artifact@v3
54+
uses: actions/upload-pages-artifact@v4
5555
with:
5656
path: build/html
5757

.github/workflows/python-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ jobs:
1919
python-version: ["3.11", "3.12", "3.13"]
2020

2121
steps:
22-
- uses: actions/checkout@v5
22+
- uses: actions/checkout@v6
2323
- name: Install uv
2424
uses: astral-sh/setup-uv@v7
2525
with:
2626
python-version: ${{ matrix.python-version }}
27-
version: 0.9.10
27+
version: 0.9.12
2828
enable-cache: true
2929
cache-dependency-glob: "uv.lock"
3030
- name: Install the project

.github/workflows/python-publish.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ jobs:
3434
# url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
3535

3636
steps:
37-
- uses: actions/checkout@v5
37+
- uses: actions/checkout@v6
3838
- name: Install uv
3939
uses: astral-sh/setup-uv@v7
4040
with:
4141
enable-cache: true
4242
cache-dependency-glob: "uv.lock"
43-
version: 0.9.10
43+
version: 0.9.12
4444

4545
- name: Change Package Version
4646
run: |

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# torchTextClassifiers
22

3-
[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://inseeflab.github.io/torchTextClassifiers/)
3+
[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://inseefrlab.github.io/torchTextClassifiers/)
44

55
A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
66

@@ -34,8 +34,7 @@ pip install -e .
3434

3535
## 📖 Documentation
3636

37-
Full documentation is available at: **https://inseeflab.github.io/torchTextClassifiers/**
38-
37+
Full documentation is available at: **https://inseefrlab.github.io/torchTextClassifiers/**
3938
The documentation includes:
4039
- **Getting Started**: Installation and quick start guide
4140
- **Architecture**: Understanding the 3-layer design
@@ -58,3 +57,5 @@ See the [examples/](examples/) directory for:
5857
## 📄 License
5958

6059
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
60+
61+

docs/source/architecture/overview.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ custom_head = nn.Sequential(
270270
nn.Linear(64, 5)
271271
)
272272

273-
head = ClassificationHead(linear=custom_head)
273+
head = ClassificationHead(net=custom_head)
274274
```
275275

276276
## Complete Architecture
@@ -611,3 +611,4 @@ torchTextClassifiers provides a **component-based pipeline** for text classifica
611611

612612
Ready to build your classifier? Start with {doc}`../getting_started/quickstart`!
613613

614+

torchTextClassifiers/model/components/classification_head.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(
2424
"""
2525
super().__init__()
2626
if net is not None:
27+
self.net = net
28+
2729
# --- Custom net should either be a Sequential or a Linear ---
2830
if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)):
2931
raise ValueError("net must be an nn.Sequential when provided.")
@@ -43,7 +45,6 @@ def __init__(
4345
# --- Extract features ---
4446
self.input_dim = first.in_features
4547
self.num_classes = last.out_features
46-
self.net = net
4748
else: # if not Sequential, it is a Linear
4849
self.input_dim = net.in_features
4950
self.num_classes = net.out_features
@@ -53,23 +54,8 @@ def __init__(
5354
input_dim is not None and num_classes is not None
5455
), "Either net or both input_dim and num_classes must be provided."
5556
self.net = nn.Linear(input_dim, num_classes)
56-
self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
57+
self.input_dim = input_dim
58+
self.num_classes = num_classes
5759

5860
def forward(self, x: torch.Tensor) -> torch.Tensor:
5961
return self.net(x)
60-
61-
@staticmethod
62-
def _get_linear_input_output_dims(module: nn.Module):
63-
"""
64-
Returns (input_dim, output_dim) for any module containing Linear layers.
65-
Works for Linear, Sequential, or nested models.
66-
"""
67-
# Collect all Linear layers recursively
68-
linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
69-
70-
if not linears:
71-
raise ValueError("No Linear layers found in the given module.")
72-
73-
input_dim = linears[0].in_features
74-
output_dim = linears[-1].out_features
75-
return input_dim, output_dim

0 commit comments

Comments
 (0)