Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion label_studio_ml/examples/huggingface_ner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,18 @@ def fit(self, event, data, **kwargs):
logger.info(f"Skip training: event {event} is not supported")
return

project_id = data['annotation']['project']
# Get project from annotation first if present, otherwise fall back to top-level project field
project = data.get('annotation', {}).get('project') or data.get('project')
# Handle both possible formats
if isinstance(project, dict):
project_id = project.get('id')
else:
project_id = project
# If project_id is still None, log and safely exit
if project_id is None:
logger.error(f"Cannot find project_id in webhook payload: {data}")
return

tasks = self._get_tasks(project_id)

if len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0 and event != 'START_TRAINING':
Expand Down
30 changes: 30 additions & 0 deletions label_studio_ml/examples/huggingface_ner/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,33 @@ def test_fit(client, mock_get_labeled_tasks, mock_start_training, mock_baseline_
# remove './results/finetuned_model' directory after testing
import shutil
shutil.rmtree(results_dir)

def test_fit_missing_annotation(monkeypatch):
# Initialize the model
model = HuggingFaceNER()

# Mock label_interface to avoid AttributeError
model.label_interface = mock.MagicMock()
# Mock get_first_tag_occurence to return fake values
model.label_interface.get_first_tag_occurence.return_value = ('Labels', 'Text', 'text_field_name')

# Mock data payload with annotation missing, only project present
payload = {
"action": "ANNOTATION_UPDATED",
"project": {"id": 123, "name": "Test Project"}
}

# Monkeypatch _get_tasks to return one fake task
monkeypatch.setattr(model, "_get_tasks", lambda project_id: [
{
"id": "1",
"data": {"text_field_name": "Hello world"},
"annotations": []
}
])

# Call fit()
try:
model.fit(event="ANNOTATION_UPDATED", data=payload)
except Exception as e:
pytest.fail(f"fit() raised an exception when annotation is missing: {e}")
Loading