diff --git a/api/collections/serializers.py b/api/collections/serializers.py index decf6499146..2f24a2eb56a 100644 --- a/api/collections/serializers.py +++ b/api/collections/serializers.py @@ -1,7 +1,10 @@ +import waffle + from django.db import IntegrityError from rest_framework import exceptions from rest_framework import serializers as ser +from osf import features from osf.models import AbstractNode, Node, Collection, Guid, Registration from osf.exceptions import ValidationError, NodeStateError from api.base.serializers import LinksField, RelationshipField, LinkedNodesRelationshipSerializer, LinkedRegistrationsRelationshipSerializer, LinkedPreprintsRelationshipSerializer @@ -426,6 +429,11 @@ def create(self, validated_data): raise exceptions.ValidationError('"creator" must be specified.') if not (creator.has_perm('write_collection', collection) or (hasattr(guid.referent, 'has_permission') and guid.referent.has_permission(creator, WRITE))): raise exceptions.PermissionDenied('Must have write permission on either collection or collected object to collect.') + if waffle.switch_is_active(features.COLLECTION_SUBMISSION_WITH_CEDAR) and collection.provider_id: + try: + collection.provider.validate_required_metadata(guid.referent) + except ValidationError as e: + raise InvalidModelValueError(e.message) try: obj = collection.collect_object(guid.referent, creator, **validated_data) except ValidationError as e: @@ -462,6 +470,11 @@ def create(self, validated_data): raise exceptions.ValidationError('"creator" must be specified.') if not (creator.has_perm('write_collection', collection) or (hasattr(guid.referent, 'has_permission') and guid.referent.has_permission(creator, WRITE))): raise exceptions.PermissionDenied('Must have write permission on either collection or collected object to collect.') + if waffle.switch_is_active(features.COLLECTION_SUBMISSION_WITH_CEDAR) and collection.provider_id: + try: + collection.provider.validate_required_metadata(guid.referent) + except ValidationError as e: + raise InvalidModelValueError(e.message) try: obj = collection.collect_object(guid.referent, creator, **validated_data) except ValidationError as e: diff --git a/api_tests/collections/test_views.py b/api_tests/collections/test_views.py index cc04bb729f8..7ae088ab5a7 100644 --- a/api_tests/collections/test_views.py +++ b/api_tests/collections/test_views.py @@ -1,14 +1,24 @@ -import pytest from urllib.parse import urlparse +import pytest from django.utils.timezone import now +from waffle.testutils import override_switch from api.base.settings.defaults import API_BASE from api.taxonomies.serializers import subjects_as_relationships_version -from api_tests.subjects.mixins import UpdateSubjectsMixin, SubjectsFilterMixin, SubjectsListMixin, SubjectsRelationshipMixin +from api_tests.share._utils import mock_update_share +from api_tests.subjects.mixins import UpdateSubjectsMixin, SubjectsFilterMixin, SubjectsListMixin, \ + SubjectsRelationshipMixin +from api_tests.utils import disconnected_from_listeners from framework.auth.core import Auth +from osf import features +from osf.models import Collection, VersionedGuidMixin +from osf.utils.permissions import ADMIN, WRITE, READ +from osf.utils.sanitize import strip_html from osf_tests.factories import ( + CedarMetadataTemplateFactory, CollectionFactory, + CollectionProviderFactory, NodeFactory, RegistrationFactory, PreprintFactory, @@ -16,15 +26,9 @@ AuthUserFactory, SubjectFactory, ) -from osf.models import Collection, VersionedGuidMixin -from osf.utils.sanitize import strip_html -from osf.utils.permissions import ADMIN, WRITE, READ from website.project.signals import contributor_removed -from api_tests.utils import disconnected_from_listeners -from api_tests.share._utils import mock_update_share from website.views import find_bookmark_collection - url_collection_list = f'/{API_BASE}collections/' @@ -4384,6 +4388,80 @@ def test_filters(self, app, collection_with_one_collection_submission, collectio assert len(res.json['data']) == 1 +@pytest.mark.django_db +class TestCollectionSubmissionWithCedarSwitch: + + @pytest.fixture() + def cedar_template(self): + return CedarMetadataTemplateFactory( + schema_name='Test Schema', + cedar_id='https://cedar.example.com/template/1', + template_version=1, + ) + + @pytest.fixture() + def provider(self, cedar_template): + provider = CollectionProviderFactory() + provider.required_metadata_template = cedar_template + provider.save() + return provider + + @pytest.fixture() + def collection(self, user_one, provider): + c = CollectionFactory(creator=user_one) + c.provider = provider + c.save() + return c + + @pytest.fixture() + def collection_no_provider(self, user_one): + return CollectionFactory(creator=user_one) + + @pytest.fixture() + def project(self, user_one): + return ProjectFactory(creator=user_one) + + @pytest.fixture() + def url(self, collection): + return f'/{API_BASE}collections/{collection._id}/collected_metadata/' + + @pytest.fixture() + def url_no_provider(self, collection_no_provider): + return f'/{API_BASE}collections/{collection_no_provider._id}/collected_metadata/' + + @pytest.fixture() + def payload(self): + def make_collection_payload(**attributes): + return { + 'data': { + 'type': 'collected-metadata', + 'attributes': attributes, + } + } + return make_collection_payload + + def test_switch_active_no_provider_submission_succeeds(self, app, user_one, project, url_no_provider, payload): + with mock_update_share(): + with override_switch(features.COLLECTION_SUBMISSION_WITH_CEDAR, active=True): + res = app.post_json_api( + url_no_provider, + payload(guid=project._id), + auth=user_one.auth, + ) + assert res.status_code == 201 + + def test_switch_active_missing_cedar_record_submission_fails(self, app, user_one, project, url, payload): + with override_switch(features.COLLECTION_SUBMISSION_WITH_CEDAR, active=True): + res = app.post_json_api( + url, + payload(guid=project._id), + auth=user_one.auth, + expect_errors=True, + ) + assert res.status_code == 400 + assert 'CEDAR metadata record' in res.json['errors'][0]['detail'] + + class TestCollectedMetaSubjectFiltering(SubjectsFilterMixin): @pytest.fixture() def project_one(self, user): diff --git a/osf/features.yaml b/osf/features.yaml index cce490a25a4..9f826966eed 100644 --- a/osf/features.yaml +++ b/osf/features.yaml @@ -113,3 +113,8 @@ switches: name: populate_notification_types note: This is used to enable auto population of notification types. active: false + + - flag_name: COLLECTION_SUBMISSION_WITH_CEDAR + name: collection_submission_with_cedar + note: When active, enforces that objects submitted to a collection have a CEDAR metadata record matching the provider's required_metadata_template. + active: false diff --git a/osf/models/provider.py b/osf/models/provider.py index f1dbdec65b3..6adb107c4ee 100644 --- a/osf/models/provider.py +++ b/osf/models/provider.py @@ -165,6 +165,23 @@ def update_or_create_from_json(cls, provider_data, user): related_name='required_by_providers', ) + def validate_required_metadata(self, obj): + """ + Raises ValidationError if obj does not have a CedarMetadataRecord for + this provider's required_metadata_template. + Does nothing when required_metadata_template is not set. + """ + if not self.required_metadata_template_id: + return + guid = obj.guids.first() + if guid is None or not guid.cedar_metadata_records.filter( + template_id=self.required_metadata_template_id + ).exists(): + raise ValidationError( + f'Submitted object must have a CEDAR metadata record for template ' + f'"{self.required_metadata_template.schema_name}" to be submitted to this collection.' + ) + def __repr__(self): return ('(name={self.name!r}, default_license={self.default_license!r}, ' 'allow_submissions={self.allow_submissions!r}) with id {self.id!r}').format(self=self)