diff --git a/.flake8 b/.flake8 index f649211e..9968b3a7 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,4 @@ [flake8] max-line-length=120 exclude = docs/*,.tox/* +ignore = E203, W503, W504, E704 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..2caebff7 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,5 @@ +# Initial commit to format the codebase with black +e454428fdf5b3d32f8ffbe9bded3878a80e7b169 +# Initial commit to sort import on the codebase +2d69e4fb6daf037cb2f4358fde0c6c96e3f7ca16 + \ No newline at end of file diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 00000000..faf097ac --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,19 @@ +name: Check format + +on: + - pull_request +jobs: + check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install deps + run: pip install -r requirements-dev.txt + - name: Run black + run: black --check opentaxii/ tests/ + - name: Run isort + run: isort --check-only opentaxii/ tests/ diff --git a/docs/conf.py b/docs/conf.py index 56fa10aa..acec863b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,8 +13,9 @@ # serve to show the default. import datetime -import sys import os +import sys + import sphinx_rtd_theme # If extensions (or modules to document with autodoc) are in another directory, diff --git a/docs/update_db_schema_diagram.py b/docs/update_db_schema_diagram.py index 7268f73f..09d1349a 100755 --- a/docs/update_db_schema_diagram.py +++ b/docs/update_db_schema_diagram.py @@ -5,9 +5,10 @@ current_dir = os.path.dirname(__file__) sys.path.append(os.path.abspath(os.path.join(current_dir, ".."))) -from opentaxii.persistence.sqldb.models import Base from sqla_graphs import TableGrapher +from opentaxii.persistence.sqldb.models import Base + grapher = TableGrapher( style={"node_table_header": {"bgcolor": "#000080"}}, graph_options={"size": "30,30!"}, # inches, this maps to 2880px diff --git a/examples/hooks.py b/examples/hooks.py index 5a38ef43..eec6486b 100644 --- a/examples/hooks.py +++ b/examples/hooks.py @@ -1,6 +1,8 @@ from opentaxii.signals import ( - CONTENT_BLOCK_CREATED, INBOX_MESSAGE_CREATED, SUBSCRIPTION_CREATED + CONTENT_BLOCK_CREATED, + INBOX_MESSAGE_CREATED, + SUBSCRIPTION_CREATED, ) diff --git a/opentaxii/__init__.py b/opentaxii/__init__.py index 70565526..0f172433 100644 --- a/opentaxii/__init__.py +++ b/opentaxii/__init__.py @@ -1,11 +1,11 @@ ''' - OpenTAXII, TAXII server implementation from EclecticIQ. +OpenTAXII, TAXII server implementation from EclecticIQ. ''' + # flake8: noqa from ._version import __version__ -from .server import TAXIIServer from .config import ServerConfig from .entities import Account - from .local import context +from .server import TAXIIServer diff --git a/opentaxii/auth/api.py b/opentaxii/auth/api.py index ffc5713a..db3e7bac 100644 --- a/opentaxii/auth/api.py +++ b/opentaxii/auth/api.py @@ -1,4 +1,3 @@ - class OpenTAXIIAuthAPI: '''Abstract class that represents OpenTAXII Authentication API. diff --git a/opentaxii/auth/manager.py b/opentaxii/auth/manager.py index aa4774ab..7ed7a3a8 100644 --- a/opentaxii/auth/manager.py +++ b/opentaxii/auth/manager.py @@ -47,9 +47,7 @@ def update_account(self, account, password): for colname, permission in list(account.permissions.items()): collection = self.server.servers.taxii1.persistence.get_collection(colname) if not collection: - log.warning( - "update_account.unknown_collection", - collection=colname) + log.warning("update_account.unknown_collection", collection=colname) account.permissions.pop(colname) account = self.api.update_account(account, password) return account diff --git a/opentaxii/auth/sqldb/api.py b/opentaxii/auth/sqldb/api.py index 29808341..78da33fd 100644 --- a/opentaxii/auth/sqldb/api.py +++ b/opentaxii/auth/sqldb/api.py @@ -2,10 +2,11 @@ import jwt import structlog +from sqlalchemy.orm import exc + from opentaxii.auth import OpenTAXIIAuthAPI from opentaxii.common.sqldb import BaseSQLDatabaseAPI from opentaxii.entities import Account as AccountEntity -from sqlalchemy.orm import exc from .models import Account, Base @@ -32,16 +33,19 @@ class SQLDatabaseAPI(BaseSQLDatabaseAPI, OpenTAXIIAuthAPI): BASEMODEL = Base def __init__( - self, - db_connection, - create_tables=False, - secret=None, - token_ttl_secs=None, - **engine_parameters): + self, + db_connection, + create_tables=False, + secret=None, + token_ttl_secs=None, + **engine_parameters, + ): super().__init__(db_connection, create_tables, **engine_parameters) if not secret: - raise ValueError('Secret is not defined for %s.%s' % ( - self.__module__, self.__class__.__name__)) + raise ValueError( + 'Secret is not defined for %s.%s' + % (self.__module__, self.__class__.__name__) + ) self.secret = secret self.token_ttl_secs = token_ttl_secs or 60 * 60 # 60min @@ -71,7 +75,9 @@ def get_account(self, token): return account_to_account_entity(account) def delete_account(self, username): - account = self.db.session.query(Account).filter_by(username=username).one_or_none() + account = ( + self.db.session.query(Account).filter_by(username=username).one_or_none() + ) if account: self.db.session.delete(account) self.db.session.commit() @@ -79,10 +85,15 @@ def delete_account(self, username): def get_accounts(self): return [ account_to_account_entity(account) - for account in self.db.session.query(Account).all()] + for account in self.db.session.query(Account).all() + ] def update_account(self, obj, password=None): - account = self.db.session.query(Account).filter_by(username=obj.username).one_or_none() + account = ( + self.db.session.query(Account) + .filter_by(username=obj.username) + .one_or_none() + ) if not account: account = Account(username=obj.username) self.db.session.add(account) @@ -120,4 +131,5 @@ def account_to_account_entity(account): id=account.id, username=account.username, is_admin=account.is_admin, - permissions=account.permissions) + permissions=account.permissions, + ) diff --git a/opentaxii/auth/sqldb/models.py b/opentaxii/auth/sqldb/models.py index ce512fe8..a296ee92 100644 --- a/opentaxii/auth/sqldb/models.py +++ b/opentaxii/auth/sqldb/models.py @@ -2,10 +2,7 @@ from sqlalchemy import schema, types from sqlalchemy.ext.declarative import declarative_base - -from werkzeug.security import ( - check_password_hash, generate_password_hash -) +from werkzeug.security import check_password_hash, generate_password_hash __all__ = ['Base', 'Account'] @@ -42,6 +39,8 @@ def permissions(self, permissions): for collection_name, permission in permissions.items(): if permission not in ALL_PERMISSIONS: raise ValueError( - "Unknown permission '{}' specified for collection '{}'" - .format(permission, collection_name)) + "Unknown permission '{}' specified for collection '{}'".format( + permission, collection_name + ) + ) self._permissions = json.dumps(permissions) diff --git a/opentaxii/cli/__init__.py b/opentaxii/cli/__init__.py index 7e9114a7..ec25ff64 100644 --- a/opentaxii/cli/__init__.py +++ b/opentaxii/cli/__init__.py @@ -1,9 +1,8 @@ -from opentaxii.server import TAXIIServer from opentaxii.config import ServerConfig from opentaxii.middleware import create_app +from opentaxii.server import TAXIIServer from opentaxii.utils import configure_logging - config = ServerConfig() configure_logging(config['logging'], plain=True) diff --git a/opentaxii/cli/auth.py b/opentaxii/cli/auth.py index 428b81fe..bbbc212a 100644 --- a/opentaxii/cli/auth.py +++ b/opentaxii/cli/auth.py @@ -7,7 +7,7 @@ def create_account(argv=None): parser = argparse.ArgumentParser( description="Create Account via OpenTAXII Auth API", - formatter_class=argparse.ArgumentDefaultsHelpFormatter + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("-u", "--username", required=True) parser.add_argument("-p", "--password", required=True) @@ -38,7 +38,7 @@ def is_truely(text): def update_account(argv=None): parser = argparse.ArgumentParser( description="Update Account via OpenTAXII Auth API", - formatter_class=argparse.ArgumentDefaultsHelpFormatter + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) fields = ("password", "admin") parser.add_argument("-u", "--username", required=True) diff --git a/opentaxii/cli/run.py b/opentaxii/cli/run.py index 7b0919a9..302ef396 100644 --- a/opentaxii/cli/run.py +++ b/opentaxii/cli/run.py @@ -1,4 +1,3 @@ - from opentaxii.cli import app diff --git a/opentaxii/common/sqldb.py b/opentaxii/common/sqldb.py index 65720a1b..2b66b33e 100644 --- a/opentaxii/common/sqldb.py +++ b/opentaxii/common/sqldb.py @@ -20,7 +20,7 @@ def __init__(self, db_connection, create_tables=False, **engine_parameters): "autocommit": False, "autoflush": True, }, - **engine_parameters + **engine_parameters, ) if create_tables: self.db.create_all_tables() diff --git a/opentaxii/config.py b/opentaxii/config.py index b8f2f87a..5e257cf1 100644 --- a/opentaxii/config.py +++ b/opentaxii/config.py @@ -99,7 +99,7 @@ def _get_env_config(env=os.environ, optional_env_var=None): continue if key == optional_env_var: continue - key = key[len(ENV_VAR_PREFIX):].lstrip("_").lower() + key = key[len(ENV_VAR_PREFIX) :].lstrip("_").lower() value = yaml.safe_load(value) container = result diff --git a/opentaxii/entities.py b/opentaxii/entities.py index 9ff27a15..8e357eaf 100644 --- a/opentaxii/entities.py +++ b/opentaxii/entities.py @@ -1,4 +1,3 @@ - class Account: '''Represents Account entity. @@ -9,8 +8,7 @@ class Account: :param dict details: additional details of an account ''' - def __init__( - self, id, username, permissions, is_admin=False, **details): + def __init__(self, id, username, permissions, is_admin=False, **details): self.id = id self.username = username self.permissions = permissions @@ -18,16 +16,13 @@ def __init__( self.details = details def can_read(self, collection_name): - return ( - self.is_admin or - self.permissions.get(collection_name) in ('read', 'modify')) + return self.is_admin or self.permissions.get(collection_name) in ( + 'read', + 'modify', + ) def can_modify(self, collection_name): - return ( - self.is_admin or - self.permissions.get(collection_name) == 'modify') + return self.is_admin or self.permissions.get(collection_name) == 'modify' def __repr__(self): - return ( - 'Account(username={}, is_admin={})' - .format(self.username, self.is_admin)) + return 'Account(username={}, is_admin={})'.format(self.username, self.is_admin) diff --git a/opentaxii/exceptions.py b/opentaxii/exceptions.py index b4aaabcd..2f372ba5 100644 --- a/opentaxii/exceptions.py +++ b/opentaxii/exceptions.py @@ -1,4 +1,3 @@ - from .taxii.exceptions import UnauthorizedStatus diff --git a/opentaxii/http.py b/opentaxii/http.py index 7a57b282..f78fa29e 100644 --- a/opentaxii/http.py +++ b/opentaxii/http.py @@ -1,10 +1,8 @@ - -from .middleware import create_app from .config import ServerConfig +from .middleware import create_app from .server import TAXIIServer from .utils import configure_logging - # This module is also used as a Gunicorn configuration module, i.e. passed # as ``--config python:opentaxii.http``. ``logconfig_dict`` module-level # variable is recognised by Gunicorn >= 19.8. The desired effect is to @@ -15,12 +13,7 @@ 'version': 1, 'disable_existing_loggers': False, 'root': {}, - 'loggers': { - 'gunicorn.error': { - 'level': 'INFO', - 'propagate': True - } - } + 'loggers': {'gunicorn.error': {'level': 'INFO', 'propagate': True}}, } config_obj = ServerConfig() diff --git a/opentaxii/middleware.py b/opentaxii/middleware.py index 02f8ebb1..bd5636c5 100644 --- a/opentaxii/middleware.py +++ b/opentaxii/middleware.py @@ -2,8 +2,7 @@ import structlog from flask import Flask, request -from marshmallow.exceptions import \ - ValidationError as MarshmallowValidationError +from marshmallow.exceptions import ValidationError as MarshmallowValidationError from werkzeug.exceptions import HTTPException from .exceptions import InvalidAuthHeader @@ -42,7 +41,9 @@ def create_app(server): app.register_error_handler(500, server.handle_internal_error) app.register_error_handler(StatusMessageException, server.handle_status_exception) app.register_error_handler(HTTPException, server.handle_http_exception) - app.register_error_handler(MarshmallowValidationError, server.handle_validation_exception) + app.register_error_handler( + MarshmallowValidationError, server.handle_validation_exception + ) app.before_request(functools.partial(create_context_before_request, server)) app.after_request(cleanup_context) return app diff --git a/opentaxii/persistence/__init__.py b/opentaxii/persistence/__init__.py index e02c81c3..1bcf8433 100644 --- a/opentaxii/persistence/__init__.py +++ b/opentaxii/persistence/__init__.py @@ -1,4 +1,7 @@ # flake8: noqa from .api import OpenTAXII2PersistenceAPI, OpenTAXIIPersistenceAPI -from .manager import (BasePersistenceManager, Taxii1PersistenceManager, - Taxii2PersistenceManager) +from .manager import ( + BasePersistenceManager, + Taxii1PersistenceManager, + Taxii2PersistenceManager, +) diff --git a/opentaxii/persistence/api.py b/opentaxii/persistence/api.py index d5782f33..0fc37af3 100644 --- a/opentaxii/persistence/api.py +++ b/opentaxii/persistence/api.py @@ -1,9 +1,14 @@ import datetime from typing import Dict, List, Optional, Tuple -from opentaxii.taxii2.entities import (ApiRoot, Collection, Job, - ManifestRecord, STIXObject, - VersionRecord) +from opentaxii.taxii2.entities import ( + ApiRoot, + Collection, + Job, + ManifestRecord, + STIXObject, + VersionRecord, +) class OpenTAXIIPersistenceAPI: @@ -267,6 +272,7 @@ class OpenTAXII2PersistenceAPI: Stub, pending implementation. """ + @staticmethod def get_next_param(self, kwargs: Dict) -> str: """ @@ -295,9 +301,7 @@ def get_api_roots(self) -> List[ApiRoot]: def get_api_root(self, api_root_id: str) -> Optional[ApiRoot]: raise NotImplementedError - def get_job_and_details( - self, api_root_id: str, job_id: str - ) -> Optional[Job]: + def get_job_and_details(self, api_root_id: str, job_id: str) -> Optional[Job]: raise NotImplementedError def get_collections(self, api_root_id: str) -> List[Collection]: @@ -334,7 +338,9 @@ def get_objects( ) -> Tuple[List[STIXObject], bool, Optional[str]]: raise NotImplementedError - def add_objects(self, api_root_id: str, collection_id: str, objects: List[Dict]) -> Job: + def add_objects( + self, api_root_id: str, collection_id: str, objects: List[Dict] + ) -> Job: raise NotImplementedError def get_object( diff --git a/opentaxii/persistence/exceptions.py b/opentaxii/persistence/exceptions.py index 81dc03f4..6f2ea7e1 100644 --- a/opentaxii/persistence/exceptions.py +++ b/opentaxii/persistence/exceptions.py @@ -1,4 +1,3 @@ - class ResultsNotReady(Exception): pass diff --git a/opentaxii/persistence/manager.py b/opentaxii/persistence/manager.py index 8f3649d3..49dfe36a 100644 --- a/opentaxii/persistence/manager.py +++ b/opentaxii/persistence/manager.py @@ -2,16 +2,27 @@ from typing import Dict, List, Optional, Tuple import structlog + from opentaxii.local import context -from opentaxii.persistence.exceptions import (DoesNotExistError, - NoReadNoWritePermission, - NoReadPermission, - NoWritePermission) -from opentaxii.signals import (CONTENT_BLOCK_CREATED, INBOX_MESSAGE_CREATED, - SUBSCRIPTION_CREATED) -from opentaxii.taxii2.entities import (ApiRoot, Collection, Job, - ManifestRecord, STIXObject, - VersionRecord) +from opentaxii.persistence.exceptions import ( + DoesNotExistError, + NoReadNoWritePermission, + NoReadPermission, + NoWritePermission, +) +from opentaxii.signals import ( + CONTENT_BLOCK_CREATED, + INBOX_MESSAGE_CREATED, + SUBSCRIPTION_CREATED, +) +from opentaxii.taxii2.entities import ( + ApiRoot, + Collection, + Job, + ManifestRecord, + STIXObject, + VersionRecord, +) log = structlog.getLogger(__name__) diff --git a/opentaxii/persistence/sqldb/common.py b/opentaxii/persistence/sqldb/common.py index c37cbd77..2c330fc6 100644 --- a/opentaxii/persistence/sqldb/common.py +++ b/opentaxii/persistence/sqldb/common.py @@ -1,4 +1,5 @@ """A module to put common database helper components.""" + import uuid from datetime import timezone diff --git a/opentaxii/persistence/sqldb/converters.py b/opentaxii/persistence/sqldb/converters.py index bd028ea3..8c736d43 100644 --- a/opentaxii/persistence/sqldb/converters.py +++ b/opentaxii/persistence/sqldb/converters.py @@ -1,4 +1,5 @@ import json + import pytz from opentaxii.taxii import entities @@ -17,7 +18,7 @@ def to_collection_entity(model): supported_content=deserialize_content_bindings(model.bindings), # TODO: Explicit integer # pending: https://github.com/TAXIIProject/libtaxii/issues/191 - volume=int(model.volume) + volume=int(model.volume), ) @@ -32,7 +33,8 @@ def to_block_entity(model): content=model.content, timestamp_label=enforce_timezone(model.timestamp_label), content_binding=entities.ContentBindingEntity( - model.binding_id, subtypes=subtypes), + model.binding_id, subtypes=subtypes + ), message=model.message, inbox_message_id=model.inbox_message_id, ) @@ -60,9 +62,12 @@ def to_inbox_message_entity(model): subscription_collection_name=model.subscription_collection_name, subscription_id=model.subscription_id, exclusive_begin_timestamp_label=enforce_timezone( - model.exclusive_begin_timestamp_label), + model.exclusive_begin_timestamp_label + ), inclusive_end_timestamp_label=enforce_timezone( - model.inclusive_end_timestamp_label)) + model.inclusive_end_timestamp_label + ), + ) def to_result_set_entity(model): @@ -74,7 +79,9 @@ def to_result_set_entity(model): content_bindings=deserialize_content_bindings(model.bindings), timeframe=( enforce_timezone(model.begin_time), - enforce_timezone(model.end_time))) + enforce_timezone(model.end_time), + ), + ) def to_subscription_entity(model): @@ -85,7 +92,8 @@ def to_subscription_entity(model): parsed = dict(json.loads(model.params)) if parsed['content_bindings']: parsed['content_bindings'] = deserialize_content_bindings( - parsed['content_bindings']) + parsed['content_bindings'] + ) params = entities.PollRequestParametersEntity(**parsed) else: params = None @@ -95,7 +103,7 @@ def to_subscription_entity(model): subscription_id=model.id, collection_id=model.collection_id, poll_request_params=params, - status=model.status + status=model.status, ) @@ -103,9 +111,8 @@ def to_service_entity(model): if not model: return return entities.ServiceEntity( - id=model.id, - type=model.type, - properties=model.properties) + id=model.id, type=model.type, properties=model.properties + ) def serialize_content_bindings(content_bindings): @@ -116,7 +123,7 @@ def deserialize_content_bindings(content_bindings): raw_bindings = json.loads(content_bindings) bindings = [] - for (binding, subtypes) in raw_bindings: + for binding, subtypes in raw_bindings: entity = entities.ContentBindingEntity(binding, subtypes=subtypes) bindings.append(entity) diff --git a/opentaxii/persistence/sqldb/models.py b/opentaxii/persistence/sqldb/models.py index cd694084..272f6c25 100644 --- a/opentaxii/persistence/sqldb/models.py +++ b/opentaxii/persistence/sqldb/models.py @@ -1,14 +1,21 @@ import json -import pytz from datetime import datetime +import pytz from sqlalchemy import schema, types -from sqlalchemy.orm import relationship, validates -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.dialects import mysql +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, validates -__all__ = ['Base', 'ContentBlock', 'DataCollection', 'Service', - 'InboxMessage', 'ResultSet', 'Subscription'] +__all__ = [ + 'Base', + 'ContentBlock', + 'DataCollection', + 'Service', + 'InboxMessage', + 'ResultSet', + 'Subscription', +] Base = declarative_base(name='Model') @@ -22,8 +29,7 @@ def get_utc_now(): class AbstractModel(Base): __abstract__ = True - date_created = schema.Column( - types.DateTime(timezone=True), default=get_utc_now) + date_created = schema.Column(types.DateTime(timezone=True), default=get_utc_now) collection_to_content_block = schema.Table( @@ -32,13 +38,15 @@ class AbstractModel(Base): schema.Column( 'collection_id', types.Integer, - schema.ForeignKey('data_collections.id', ondelete='CASCADE')), + schema.ForeignKey('data_collections.id', ondelete='CASCADE'), + ), schema.Column( 'content_block_id', types.Integer, schema.ForeignKey('content_blocks.id', ondelete='CASCADE'), - index=True), - schema.PrimaryKeyConstraint('collection_id', 'content_block_id') + index=True, + ), + schema.PrimaryKeyConstraint('collection_id', 'content_block_id'), ) @@ -50,14 +58,14 @@ class ContentBlock(AbstractModel): message = schema.Column(types.Text, nullable=True) timestamp_label = schema.Column( - types.DateTime(timezone=True), - default=get_utc_now, index=True) + types.DateTime(timezone=True), default=get_utc_now, index=True + ) inbox_message_id = schema.Column( types.Integer, - schema.ForeignKey( - 'inbox_messages.id', onupdate='CASCADE', ondelete='CASCADE'), - nullable=True) + schema.ForeignKey('inbox_messages.id', onupdate='CASCADE', ondelete='CASCADE'), + nullable=True, + ) content_type = types.LargeBinary().with_variant(MYSQL_LARGE_BINARY, 'mysql') content = schema.Column(content_type, nullable=False) @@ -69,7 +77,8 @@ class ContentBlock(AbstractModel): 'DataCollection', secondary=collection_to_content_block, backref='content_blocks', - lazy='dynamic') + lazy='dynamic', + ) @validates('collections', include_removes=True, include_backrefs=True) def _update_volume(self, key, collection, is_remove): @@ -80,9 +89,11 @@ def _update_volume(self, key, collection, is_remove): return collection def __repr__(self): - return ('ContentBlock(id={obj.id}, ' - 'inbox_message={obj.inbox_message_id}, ' - 'binding={obj.binding_subtype})').format(obj=self) + return ( + 'ContentBlock(id={obj.id}, ' + 'inbox_message={obj.inbox_message_id}, ' + 'binding={obj.binding_subtype})' + ).format(obj=self) service_to_collection = schema.Table( @@ -91,12 +102,14 @@ def __repr__(self): schema.Column( 'service_id', types.String(150), - schema.ForeignKey('services.id', ondelete='CASCADE')), + schema.ForeignKey('services.id', ondelete='CASCADE'), + ), schema.Column( 'collection_id', types.Integer, - schema.ForeignKey('data_collections.id', ondelete='CASCADE')), - schema.PrimaryKeyConstraint('service_id', 'collection_id') + schema.ForeignKey('data_collections.id', ondelete='CASCADE'), + ), + schema.PrimaryKeyConstraint('service_id', 'collection_id'), ) @@ -110,12 +123,10 @@ class Service(AbstractModel): _properties = schema.Column(types.Text, nullable=False) collections = relationship( - 'DataCollection', - secondary=service_to_collection, - backref='services') + 'DataCollection', secondary=service_to_collection, backref='services' + ) - date_updated = schema.Column( - types.DateTime(timezone=True), default=get_utc_now) + date_updated = schema.Column(types.DateTime(timezone=True), default=get_utc_now) @property def properties(self): @@ -142,8 +153,7 @@ class DataCollection(AbstractModel): volume = schema.Column(types.Integer, default=0) def __repr__(self): - return ('DataCollection(name={obj.name}, type={obj.type})' - .format(obj=self)) + return 'DataCollection(name={obj.name}, type={obj.type})'.format(obj=self) class InboxMessage(AbstractModel): @@ -162,11 +172,15 @@ class InboxMessage(AbstractModel): subscription_id = schema.Column(types.Text, nullable=True) exclusive_begin_timestamp_label = schema.Column( - types.DateTime(timezone=True), nullable=True) + types.DateTime(timezone=True), nullable=True + ) inclusive_end_timestamp_label = schema.Column( - types.DateTime(timezone=True), nullable=True) + types.DateTime(timezone=True), nullable=True + ) - original_message_type = types.LargeBinary().with_variant(MYSQL_LARGE_BINARY, 'mysql') + original_message_type = types.LargeBinary().with_variant( + MYSQL_LARGE_BINARY, 'mysql' + ) original_message = schema.Column(original_message_type, nullable=False) content_block_count = schema.Column(types.Integer) @@ -176,14 +190,15 @@ class InboxMessage(AbstractModel): service_id = schema.Column( types.String(150), - schema.ForeignKey( - 'services.id', onupdate="CASCADE", ondelete="CASCADE")) + schema.ForeignKey('services.id', onupdate="CASCADE", ondelete="CASCADE"), + ) service = relationship('Service', backref='inbox_messages') def __repr__(self): - return ('InboxMessage(id={obj.message_id}, created={obj.date_created})' - .format(obj=self)) + return 'InboxMessage(id={obj.message_id}, created={obj.date_created})'.format( + obj=self + ) class ResultSet(AbstractModel): @@ -195,7 +210,9 @@ class ResultSet(AbstractModel): collection_id = schema.Column( types.Integer, schema.ForeignKey( - 'data_collections.id', onupdate='CASCADE', ondelete='CASCADE')) + 'data_collections.id', onupdate='CASCADE', ondelete='CASCADE' + ), + ) collection = relationship('DataCollection', backref='result_sets') @@ -214,7 +231,9 @@ class Subscription(AbstractModel): collection_id = schema.Column( types.Integer, schema.ForeignKey( - 'data_collections.id', onupdate='CASCADE', ondelete='CASCADE')) + 'data_collections.id', onupdate='CASCADE', ondelete='CASCADE' + ), + ) collection = relationship('DataCollection', backref='subscriptions') params = schema.Column(types.Text, nullable=True) @@ -224,6 +243,6 @@ class Subscription(AbstractModel): service_id = schema.Column( types.String(150), - schema.ForeignKey( - 'services.id', onupdate="CASCADE", ondelete="CASCADE")) + schema.ForeignKey('services.id', onupdate="CASCADE", ondelete="CASCADE"), + ) service = relationship('Service', backref='subscriptions') diff --git a/opentaxii/persistence/sqldb/taxii2models.py b/opentaxii/persistence/sqldb/taxii2models.py index ead4333d..3a09d2f5 100644 --- a/opentaxii/persistence/sqldb/taxii2models.py +++ b/opentaxii/persistence/sqldb/taxii2models.py @@ -1,14 +1,16 @@ """Database models for taxii2 entities.""" + import datetime import uuid import sqlalchemy -from opentaxii.persistence.sqldb.common import GUID, UTCDateTime -from opentaxii.taxii2 import entities from sqlalchemy import literal from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship +from opentaxii.persistence.sqldb.common import GUID, UTCDateTime +from opentaxii.taxii2 import entities + Base = declarative_base() diff --git a/opentaxii/server.py b/opentaxii/server.py index 4b848f7a..87827e5b 100644 --- a/opentaxii/server.py +++ b/opentaxii/server.py @@ -11,20 +11,30 @@ import structlog from flask import Flask, Response, request -from werkzeug.exceptions import (Forbidden, MethodNotAllowed, NotAcceptable, - NotFound, RequestEntityTooLarge, Unauthorized, - UnsupportedMediaType) - -from opentaxii.persistence.exceptions import (DoesNotExistError, - NoReadNoWritePermission, - NoReadPermission, - NoWritePermission) +from werkzeug.exceptions import ( + Forbidden, + MethodNotAllowed, + NotAcceptable, + NotFound, + RequestEntityTooLarge, + Unauthorized, + UnsupportedMediaType, +) + +from opentaxii.persistence.exceptions import ( + DoesNotExistError, + NoReadNoWritePermission, + NoReadPermission, + NoWritePermission, +) from opentaxii.taxii2.utils import taxii2_datetimeformat -from opentaxii.taxii2.validation import (validate_delete_filter_params, - validate_envelope, - validate_list_filter_params, - validate_object_filter_params, - validate_versions_filter_params) +from opentaxii.taxii2.validation import ( + validate_delete_filter_params, + validate_envelope, + validate_list_filter_params, + validate_object_filter_params, + validate_versions_filter_params, +) from opentaxii.utils import register_handler from .auth import AuthManager @@ -32,20 +42,30 @@ from .entities import Account from .exceptions import UnauthorizedException from .local import context -from .persistence import (BasePersistenceManager, Taxii1PersistenceManager, - Taxii2PersistenceManager) +from .persistence import ( + BasePersistenceManager, + Taxii1PersistenceManager, + Taxii2PersistenceManager, +) from .taxii2.http import make_taxii2_response -from .taxii.bindings import (ALL_PROTOCOL_BINDINGS, MESSAGE_BINDINGS, - SERVICE_BINDINGS) -from .taxii.exceptions import (FailureStatus, StatusMessageException, - raise_failure) -from .taxii.http import (HTTP_ALLOW, HTTP_X_TAXII_CONTENT_TYPES, - get_content_type, get_http_headers, - make_taxii_response, validate_request_headers, - validate_request_headers_post_parse, - validate_response_headers) -from .taxii.services import (CollectionManagementService, DiscoveryService, - InboxService, PollService) +from .taxii.bindings import ALL_PROTOCOL_BINDINGS, MESSAGE_BINDINGS, SERVICE_BINDINGS +from .taxii.exceptions import FailureStatus, StatusMessageException, raise_failure +from .taxii.http import ( + HTTP_ALLOW, + HTTP_X_TAXII_CONTENT_TYPES, + get_content_type, + get_http_headers, + make_taxii_response, + validate_request_headers, + validate_request_headers_post_parse, + validate_response_headers, +) +from .taxii.services import ( + CollectionManagementService, + DiscoveryService, + InboxService, + PollService, +) from .taxii.services.abstract import TAXIIService from .taxii.status import process_status_exception from .taxii.utils import configure_libtaxii_xml_parser, parse_message diff --git a/opentaxii/sqldb_helper.py b/opentaxii/sqldb_helper.py index a80b0b55..8d18e78f 100644 --- a/opentaxii/sqldb_helper.py +++ b/opentaxii/sqldb_helper.py @@ -54,8 +54,7 @@ def create_scoped_session(self, options=None): options.setdefault('query_cls', self.Query) - return orm.scoped_session( - self.create_session(options), scopefunc=get_ident) + return orm.scoped_session(self.create_session(options), scopefunc=get_ident) def create_session(self, options): kwargs = { diff --git a/opentaxii/taxii/bindings.py b/opentaxii/taxii/bindings.py index fa3f8597..4af06e47 100644 --- a/opentaxii/taxii/bindings.py +++ b/opentaxii/taxii/bindings.py @@ -1,15 +1,20 @@ from collections import namedtuple -import libtaxii.messages_11 as tm11 import libtaxii.messages_10 as tm10 -from libtaxii.validation import SchemaValidator +import libtaxii.messages_11 as tm11 from libtaxii.constants import ( - CB_STIX_XML_10, CB_STIX_XML_101, CB_STIX_XML_11, CB_STIX_XML_111, - VID_TAXII_HTTP_10, VID_TAXII_HTTPS_10, - VID_TAXII_XML_10, VID_TAXII_XML_11, - VID_TAXII_SERVICES_10, VID_TAXII_SERVICES_11 + CB_STIX_XML_10, + CB_STIX_XML_11, + CB_STIX_XML_101, + CB_STIX_XML_111, + VID_TAXII_HTTP_10, + VID_TAXII_HTTPS_10, + VID_TAXII_SERVICES_10, + VID_TAXII_SERVICES_11, + VID_TAXII_XML_10, + VID_TAXII_XML_11, ) - +from libtaxii.validation import SchemaValidator ValidatorAndParser = namedtuple('ValidatorAndParser', ['validator', 'parser']) @@ -20,31 +25,19 @@ CB_STIX_XML_111, ] -ALL_PROTOCOL_BINDINGS = [ - VID_TAXII_HTTP_10, - VID_TAXII_HTTPS_10 -] +ALL_PROTOCOL_BINDINGS = [VID_TAXII_HTTP_10, VID_TAXII_HTTPS_10] -PROTOCOL_TO_SCHEME = { - VID_TAXII_HTTP_10: 'http://', - VID_TAXII_HTTPS_10: 'https://' -} +PROTOCOL_TO_SCHEME = {VID_TAXII_HTTP_10: 'http://', VID_TAXII_HTTPS_10: 'https://'} -MESSAGE_BINDINGS = [ - VID_TAXII_XML_10, - VID_TAXII_XML_11 -] +MESSAGE_BINDINGS = [VID_TAXII_XML_10, VID_TAXII_XML_11] -SERVICE_BINDINGS = [ - VID_TAXII_SERVICES_10, - VID_TAXII_SERVICES_11 -] +SERVICE_BINDINGS = [VID_TAXII_SERVICES_10, VID_TAXII_SERVICES_11] MESSAGE_VALIDATOR_PARSER = { VID_TAXII_XML_10: ValidatorAndParser( - SchemaValidator(SchemaValidator.TAXII_10_SCHEMA), - tm10.get_message_from_xml), + SchemaValidator(SchemaValidator.TAXII_10_SCHEMA), tm10.get_message_from_xml + ), VID_TAXII_XML_11: ValidatorAndParser( - SchemaValidator(SchemaValidator.TAXII_11_SCHEMA), - tm11.get_message_from_xml) + SchemaValidator(SchemaValidator.TAXII_11_SCHEMA), tm11.get_message_from_xml + ), } diff --git a/opentaxii/taxii/converters.py b/opentaxii/taxii/converters.py index 050d6976..0a12e90f 100644 --- a/opentaxii/taxii/converters.py +++ b/opentaxii/taxii/converters.py @@ -1,27 +1,29 @@ -import six -import libtaxii.messages_11 as tm11 import libtaxii.messages_10 as tm10 - +import libtaxii.messages_11 as tm11 +import six from libtaxii.constants import ( - SVC_COLLECTION_MANAGEMENT, SVC_FEED_MANAGEMENT, - VID_TAXII_SERVICES_10, VID_TAXII_SERVICES_11 + SVC_COLLECTION_MANAGEMENT, + SVC_FEED_MANAGEMENT, + VID_TAXII_SERVICES_10, + VID_TAXII_SERVICES_11, ) from .entities import ( - ContentBindingEntity, InboxMessageEntity, ContentBlockEntity, - ServiceEntity + ContentBindingEntity, + ContentBlockEntity, + InboxMessageEntity, + ServiceEntity, ) def parse_content_binding(raw_content_binding, version): if version == 10: - return ContentBindingEntity( - binding=raw_content_binding, - subtypes=None) + return ContentBindingEntity(binding=raw_content_binding, subtypes=None) if version == 11: return ContentBindingEntity( binding=raw_content_binding.binding_id, - subtypes=raw_content_binding.subtype_ids) + subtypes=raw_content_binding.subtype_ids, + ) raise ValueError('invalid version') @@ -34,15 +36,15 @@ def content_binding_entity_to_content_binding(content_binding, version): return content_binding.binding if version == 11: return tm11.ContentBinding( - binding_id=content_binding.binding, - subtype_ids=content_binding.subtypes) + binding_id=content_binding.binding, subtype_ids=content_binding.subtypes + ) raise ValueError('invalid version') def content_binding_entities_to_content_bindings(content_bindings, version): return [ - content_binding_entity_to_content_binding(c, version) - for c in content_bindings] + content_binding_entity_to_content_binding(c, version) for c in content_bindings + ] def service_to_service_instances(service, version): @@ -64,7 +66,7 @@ def service_to_service_instances(service, version): protocol_binding=binding, service_address=address, message_bindings=service.supported_message_bindings, - message=service.description + message=service.description, ) elif version == 11: instance = tm11.ServiceInstance( @@ -74,7 +76,7 @@ def service_to_service_instances(service, version): protocol_binding=binding, service_address=address, message_bindings=service.supported_message_bindings, - message=service.description + message=service.description, ) else: raise ValueError('invalid version') @@ -84,8 +86,9 @@ def service_to_service_instances(service, version): # PollingServiceInstance vs PollInstance -def poll_service_to_polling_service_instance(service, version, - is_poll_instance_cls=False): +def poll_service_to_polling_service_instance( + service, version, is_poll_instance_cls=False +): instances = [] @@ -101,7 +104,8 @@ def poll_service_to_polling_service_instance(service, version, instance = cls( poll_protocol=binding, poll_address=address, - poll_message_bindings=service.supported_message_bindings) + poll_message_bindings=service.supported_message_bindings, + ) instances.append(instance) @@ -119,7 +123,7 @@ def subscription_service_to_subscription_method(service, version): instance = module.SubscriptionMethod( subscription_protocol=binding, subscription_address=address, - subscription_message_bindings=service.supported_message_bindings + subscription_message_bindings=service.supported_message_bindings, ) instances.append(instance) @@ -131,12 +135,14 @@ def inbox_to_receiving_inbox_instance(inbox): for protocol_binding in inbox.supported_protocol_bindings: - inbox_instances.append(tm11.ReceivingInboxService( - inbox_protocol=protocol_binding, - inbox_address=inbox.get_absolute_address(protocol_binding), - inbox_message_bindings=inbox.supported_message_bindings, - supported_contents=inbox.get_supported_content(version=11) - )) + inbox_instances.append( + tm11.ReceivingInboxService( + inbox_protocol=protocol_binding, + inbox_address=inbox.get_absolute_address(protocol_binding), + inbox_message_bindings=inbox.supported_message_bindings, + supported_contents=inbox.get_supported_content(version=11), + ) + ) return inbox_instances @@ -146,20 +152,23 @@ def collection_to_feedcollection_information(service, collection, version): polling_instances = [] for poll in service.get_polling_services(collection): polling_instances.extend( - poll_service_to_polling_service_instance(poll, version=version)) + poll_service_to_polling_service_instance(poll, version=version) + ) push_methods = service.get_push_methods(collection) subscription_methods = [] for s in service.get_subscription_services(collection): subscription_methods.extend( - subscription_service_to_subscription_method(s, version=version)) + subscription_service_to_subscription_method(s, version=version) + ) if collection.accept_all_content: supported_content = [] else: supported_content = content_binding_entities_to_content_bindings( - collection.supported_content, version=version) + collection.supported_content, version=version + ) if version == 11: inbox_instances = [] @@ -171,14 +180,12 @@ def collection_to_feedcollection_information(service, collection, version): collection_description=collection.description, supported_contents=supported_content, available=collection.available, - push_methods=push_methods, polling_service_instances=polling_instances, subscription_methods=subscription_methods, - collection_volume=collection.volume, collection_type=collection.type, - receiving_inbox_services=inbox_instances + receiving_inbox_services=inbox_instances, ) if version == 10: @@ -187,10 +194,9 @@ def collection_to_feedcollection_information(service, collection, version): feed_description=collection.description, supported_contents=supported_content, available=collection.available, - push_methods=push_methods, polling_service_instances=polling_instances, - subscription_methods=subscription_methods + subscription_methods=subscription_methods, # collection_volume, collection_type, and # receiving_inbox_services are not supported in TAXII 1.0 ) @@ -198,15 +204,17 @@ def collection_to_feedcollection_information(service, collection, version): raise ValueError('invalid version') -def subscription_to_subscription_instance(subscription, polling_services, - version, - subscription_parameters=None): +def subscription_to_subscription_instance( + subscription, polling_services, version, subscription_parameters=None +): polling_instances = [] for poll in polling_services: polling_instances.extend( poll_service_to_polling_service_instance( - poll, version=version, is_poll_instance_cls=True)) + poll, version=version, is_poll_instance_cls=True + ) + ) params = dict( subscription_id=subscription.subscription_id, @@ -220,18 +228,21 @@ def subscription_to_subscription_instance(subscription, polling_services, if version == 11: push_params = None - params.update(dict( - status=subscription.status, - push_parameters=push_params, - )) + params.update( + dict( + status=subscription.status, + push_parameters=push_params, + ) + ) if subscription_parameters: bindings = content_binding_entities_to_content_bindings( - subscription_parameters.content_bindings, version=version) + subscription_parameters.content_bindings, version=version + ) params['subscription_parameters'] = tm11.SubscriptionParameters( response_type=subscription_parameters.response_type, - content_bindings=bindings + content_bindings=bindings, ) return tm11.SubscriptionInstance(**params) @@ -243,11 +254,10 @@ def inbox_message_to_inbox_message_entity(inbox_message, service_id, version): params = dict( message_id=inbox_message.message_id, - # FIXME: how to get raw value? original_message=inbox_message.to_xml(), content_block_count=len(inbox_message.content_blocks), - service_id=service_id + service_id=service_id, ) if version == 10: @@ -255,47 +265,55 @@ def inbox_message_to_inbox_message_entity(inbox_message, service_id, version): si = inbox_message.subscription_information begin = si.inclusive_begin_timestamp_label end = si.inclusive_end_timestamp_label - params.update(dict( - subscription_collection_name=si.feed_name, - subscription_id=si.subscription_id, - - # TODO: Match up exclusive vs inclusive - exclusive_begin_timestamp_label=begin, - inclusive_end_timestamp_label=end - )) + params.update( + dict( + subscription_collection_name=si.feed_name, + subscription_id=si.subscription_id, + # TODO: Match up exclusive vs inclusive + exclusive_begin_timestamp_label=begin, + inclusive_end_timestamp_label=end, + ) + ) return InboxMessageEntity(**params) if version == 11: - params.update(dict( - result_id=inbox_message.result_id, - destination_collections=inbox_message.destination_collection_names, - )) + params.update( + dict( + result_id=inbox_message.result_id, + destination_collections=inbox_message.destination_collection_names, + ) + ) if inbox_message.record_count: - params.update(dict( - record_count=inbox_message.record_count.record_count, - partial_count=inbox_message.record_count.partial_count - )) + params.update( + dict( + record_count=inbox_message.record_count.record_count, + partial_count=inbox_message.record_count.partial_count, + ) + ) if inbox_message.subscription_information: si = inbox_message.subscription_information begin = si.exclusive_begin_timestamp_label end = si.inclusive_end_timestamp_label - params.update(dict( - subscription_collection_name=si.collection_name, - subscription_id=si.subscription_id, - exclusive_begin_timestamp_label=begin, - inclusive_end_timestamp_label=end - )) + params.update( + dict( + subscription_collection_name=si.collection_name, + subscription_id=si.subscription_id, + exclusive_begin_timestamp_label=begin, + inclusive_end_timestamp_label=end, + ) + ) return InboxMessageEntity(**params) raise ValueError('invalid version') -def content_block_to_content_block_entity(content_block, version, - inbox_message_id=None): +def content_block_to_content_block_entity( + content_block, version, inbox_message_id=None +): content_binding = parse_content_binding( - content_block.content_binding, - version=version) + content_block.content_binding, version=version + ) message = content_block.message if version == 11 else None @@ -306,33 +324,36 @@ def content_block_to_content_block_entity(content_block, version, inbox_message_id=inbox_message_id, content=content_block.content, timestamp_label=content_block.timestamp_label, - content_binding=content_binding + content_binding=content_binding, # padding = content_block.padding, ) def content_block_entity_to_content_block(entity, version): content_bindings = content_binding_entity_to_content_binding( - entity.content_binding, - version=version) + entity.content_binding, version=version + ) # Libtaxii requires content to be unicode content = ( entity.content if isinstance(entity.content, six.string_types) - else entity.content.decode('utf-8')) + else entity.content.decode('utf-8') + ) if version == 10: return tm10.ContentBlock( content_binding=content_bindings, content=content, - timestamp_label=entity.timestamp_label) + timestamp_label=entity.timestamp_label, + ) if version == 11: return tm11.ContentBlock( content_binding=content_bindings, content=content, timestamp_label=entity.timestamp_label, - message=entity.message) + message=entity.message, + ) raise ValueError('invalid version') diff --git a/opentaxii/taxii/entities.py b/opentaxii/taxii/entities.py index 20f8688e..b1ad5111 100644 --- a/opentaxii/taxii/entities.py +++ b/opentaxii/taxii/entities.py @@ -1,6 +1,14 @@ import six -from libtaxii.constants import (CT_DATA_FEED, CT_DATA_SET, RT_COUNT_ONLY, - RT_FULL, SS_ACTIVE, SS_PAUSED, SS_UNSUBSCRIBED) +from libtaxii.constants import ( + CT_DATA_FEED, + CT_DATA_SET, + RT_COUNT_ONLY, + RT_FULL, + SS_ACTIVE, + SS_PAUSED, + SS_UNSUBSCRIBED, +) + from opentaxii.common.entities import Entity from .utils import is_content_supported @@ -37,7 +45,7 @@ def __init__(self, binding, subtypes=None): def deserialize_content_bindings(supported_content): bindings = [] - for content in (supported_content or []): + for content in supported_content or []: if not content: continue if isinstance(content, six.string_types): @@ -70,9 +78,17 @@ class CollectionEntity(Entity): TYPE_FEED = CT_DATA_FEED TYPE_SET = CT_DATA_SET - def __init__(self, name, id=None, description=None, type=TYPE_FEED, - volume=None, accept_all_content=False, - supported_content=None, available=True): + def __init__( + self, + name, + id=None, + description=None, + type=TYPE_FEED, + volume=None, + accept_all_content=False, + supported_content=None, + available=True, + ): self.id = id self.name = name @@ -83,8 +99,7 @@ def __init__(self, name, id=None, description=None, type=TYPE_FEED, if type not in [self.TYPE_FEED, self.TYPE_SET]: raise ValueError('Unknown collection type "%s"' % type) self.type = type - self.supported_content = ( - deserialize_content_bindings(supported_content)) + self.supported_content = deserialize_content_bindings(supported_content) def is_content_supported(self, content_binding): if self.accept_all_content: @@ -120,20 +135,22 @@ def get_matching_bindings(self, requested_bindings): overlap.append(supported) continue - subtypes_overlap = ( - set(supported.subtypes).intersection(requested.subtypes)) + subtypes_overlap = set(supported.subtypes).intersection( + requested.subtypes + ) - overlap.append(ContentBindingEntity( - binding=requested.binding, - subtypes=subtypes_overlap - )) + overlap.append( + ContentBindingEntity( + binding=requested.binding, subtypes=subtypes_overlap + ) + ) return overlap def __repr__(self): - return ( - "CollectionEntity(name={}, type={}, supported_content={})" - .format(self.name, self.type, self.supported_content)) + return "CollectionEntity(name={}, type={}, supported_content={})".format( + self.name, self.type, self.supported_content + ) class ContentBlockEntity(Entity): @@ -147,8 +164,15 @@ class ContentBlockEntity(Entity): :param str inbox_message_id: internal ID of the inbox message entity ''' - def __init__(self, content, timestamp_label, content_binding=None, id=None, - message=None, inbox_message_id=None): + def __init__( + self, + content, + timestamp_label, + content_binding=None, + id=None, + message=None, + inbox_message_id=None, + ): self.content = content @@ -188,12 +212,22 @@ class InboxMessageEntity(Entity): subscription's inclusive begin timestamp label ''' - def __init__(self, message_id, original_message, content_block_count, - service_id, id=None, result_id=None, - destination_collections=None, record_count=None, - partial_count=False, subscription_collection_name=None, - subscription_id=None, exclusive_begin_timestamp_label=None, - inclusive_end_timestamp_label=None): + def __init__( + self, + message_id, + original_message, + content_block_count, + service_id, + id=None, + result_id=None, + destination_collections=None, + record_count=None, + partial_count=False, + subscription_collection_name=None, + subscription_id=None, + exclusive_begin_timestamp_label=None, + inclusive_end_timestamp_label=None, + ): self.id = id @@ -227,8 +261,7 @@ class ResultSetEntity(Entity): a timeframe of the Result Set in a form of ``(begin, end)`` ''' - def __init__(self, id, collection_id, content_bindings=None, - timeframe=None): + def __init__(self, id, collection_id, content_bindings=None, timeframe=None): self.id = id @@ -268,11 +301,13 @@ class PollRequestParametersEntity(SubscriptionParameters): list of :class:`ContentBindingEntity` instances ''' - def __init__(self, response_type=SubscriptionParameters.FULL, - content_bindings=None): + def __init__( + self, response_type=SubscriptionParameters.FULL, content_bindings=None + ): super(PollRequestParametersEntity, self).__init__( - response_type=response_type, content_bindings=content_bindings) + response_type=response_type, content_bindings=content_bindings + ) class SubscriptionEntity(Entity): @@ -291,8 +326,14 @@ class SubscriptionEntity(Entity): PAUSED = SS_PAUSED UNSUBSCRIBED = SS_UNSUBSCRIBED - def __init__(self, service_id, collection_id, subscription_id=None, - status=ACTIVE, poll_request_params=None): + def __init__( + self, + service_id, + collection_id, + subscription_id=None, + status=ACTIVE, + poll_request_params=None, + ): self.service_id = service_id self.collection_id = collection_id diff --git a/opentaxii/taxii/exceptions.py b/opentaxii/taxii/exceptions.py index cf784165..e4cab367 100644 --- a/opentaxii/taxii/exceptions.py +++ b/opentaxii/taxii/exceptions.py @@ -1,17 +1,22 @@ import sys + import six -from libtaxii.constants import ( - ST_BAD_MESSAGE, ST_FAILURE, ST_UNAUTHORIZED -) +from libtaxii.constants import ST_BAD_MESSAGE, ST_FAILURE, ST_UNAUTHORIZED class StatusMessageException(Exception): - def __init__(self, status_type, in_response_to='0', message=None, - status_details=None, extended_headers=None, e=None): + def __init__( + self, + status_type, + in_response_to='0', + message=None, + status_details=None, + extended_headers=None, + e=None, + ): - super(StatusMessageException, self).__init__( - e or message or status_type) + super(StatusMessageException, self).__init__(e or message or status_type) self.in_response_to = in_response_to self.status_type = status_type @@ -24,22 +29,22 @@ class BadMessageStatus(StatusMessageException): def __init__(self, message, **kwargs): super(BadMessageStatus, self).__init__( - ST_BAD_MESSAGE, message=message, **kwargs) + ST_BAD_MESSAGE, message=message, **kwargs + ) class FailureStatus(StatusMessageException): def __init__(self, message, **kwargs): - super(FailureStatus, self).__init__( - ST_FAILURE, message=message, **kwargs) + super(FailureStatus, self).__init__(ST_FAILURE, message=message, **kwargs) class UnauthorizedStatus(StatusMessageException): def __init__(self, status_type=ST_UNAUTHORIZED, **kwargs): super(UnauthorizedStatus, self).__init__( - status_type=status_type.upper(), - **kwargs) + status_type=status_type.upper(), **kwargs + ) def raise_failure(message, in_response_to='0'): @@ -47,4 +52,5 @@ def raise_failure(message, in_response_to='0'): six.reraise( FailureStatus, FailureStatus(message, in_response_to=in_response_to, e=ei), - tb=tb) + tb=tb, + ) diff --git a/opentaxii/taxii/http.py b/opentaxii/taxii/http.py index cb75ffe8..399f62be 100644 --- a/opentaxii/taxii/http.py +++ b/opentaxii/taxii/http.py @@ -1,7 +1,12 @@ from flask import Response, make_response -from libtaxii.constants import (VID_TAXII_HTTP_10, VID_TAXII_HTTPS_10, - VID_TAXII_SERVICES_10, VID_TAXII_SERVICES_11, - VID_TAXII_XML_10, VID_TAXII_XML_11) +from libtaxii.constants import ( + VID_TAXII_HTTP_10, + VID_TAXII_HTTPS_10, + VID_TAXII_SERVICES_10, + VID_TAXII_SERVICES_11, + VID_TAXII_XML_10, + VID_TAXII_XML_11, +) from .exceptions import raise_failure @@ -26,35 +31,38 @@ REQUIRED_REQUEST_HEADERS = BASIC_REQUEST_HEADERS + (HTTP_X_TAXII_SERVICES,) REQUIRED_RESPONSE_HEADERS = ( - HTTP_CONTENT_TYPE, HTTP_X_TAXII_CONTENT_TYPE, - HTTP_X_TAXII_PROTOCOL, HTTP_X_TAXII_SERVICES) + HTTP_CONTENT_TYPE, + HTTP_X_TAXII_CONTENT_TYPE, + HTTP_X_TAXII_PROTOCOL, + HTTP_X_TAXII_SERVICES, +) TAXII_11_HTTPS_HEADERS = { HTTP_CONTENT_TYPE: HTTP_CONTENT_XML, HTTP_X_TAXII_CONTENT_TYPE: VID_TAXII_XML_11, HTTP_X_TAXII_PROTOCOL: VID_TAXII_HTTPS_10, - HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_11 + HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_11, } TAXII_11_HTTP_HEADERS = { HTTP_CONTENT_TYPE: HTTP_CONTENT_XML, HTTP_X_TAXII_CONTENT_TYPE: VID_TAXII_XML_11, HTTP_X_TAXII_PROTOCOL: VID_TAXII_HTTP_10, - HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_11 + HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_11, } TAXII_10_HTTPS_HEADERS = { HTTP_CONTENT_TYPE: HTTP_CONTENT_XML, HTTP_X_TAXII_CONTENT_TYPE: VID_TAXII_XML_10, HTTP_X_TAXII_PROTOCOL: VID_TAXII_HTTPS_10, - HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_10 + HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_10, } TAXII_10_HTTP_HEADERS = { HTTP_CONTENT_TYPE: HTTP_CONTENT_XML, HTTP_X_TAXII_CONTENT_TYPE: VID_TAXII_XML_10, HTTP_X_TAXII_PROTOCOL: VID_TAXII_HTTP_10, - HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_10 + HTTP_X_TAXII_SERVICES: VID_TAXII_SERVICES_10, } @@ -78,12 +86,13 @@ def get_http_headers(version, is_secure): # FIXME: should raise a custom error raise ValueError( - "Unknown combination: version={}, is_secure={}" - .format(version, is_secure)) + "Unknown combination: version={}, is_secure={}".format(version, is_secure) + ) -def validate_request_headers_post_parse(headers, supported_message_bindings, - service_bindings, protocol_bindings): +def validate_request_headers_post_parse( + headers, supported_message_bindings, service_bindings, protocol_bindings +): for h in REQUIRED_REQUEST_HEADERS: if h not in headers: @@ -98,21 +107,20 @@ def validate_request_headers_post_parse(headers, supported_message_bindings, # Validate the X-TAXII-Services header if taxii_services not in service_bindings: raise_failure( - "The value of {} was not recognized".format(HTTP_X_TAXII_SERVICES)) + "The value of {} was not recognized".format(HTTP_X_TAXII_SERVICES) + ) # Validate the X-TAXII-Protocol header # FIXME: Look into the service properties # instead of assuming both are supported if taxii_protocol and taxii_protocol not in protocol_bindings: - raise_failure( - "The specified value of X-TAXII-Protocol is not supported") + raise_failure("The specified value of X-TAXII-Protocol is not supported") # Validate the X-TAXII-Accept header # FIXME: Accept more "complex" accept headers # (e.g., ones that specify# more than one value) if taxii_accept and taxii_accept not in supported_message_bindings: - raise_failure( - "The specified value of X-TAXII-Accept is not recognized") + raise_failure("The specified value of X-TAXII-Accept is not recognized") def validate_request_headers(headers, supported_message_bindings): @@ -122,8 +130,10 @@ def validate_request_headers(headers, supported_message_bindings): if headers[HTTP_X_TAXII_CONTENT_TYPE] not in supported_message_bindings: raise_failure( - 'TAXII Content Type "{}" is not supported' - .format(headers[HTTP_X_TAXII_CONTENT_TYPE])) + 'TAXII Content Type "{}" is not supported'.format( + headers[HTTP_X_TAXII_CONTENT_TYPE] + ) + ) if 'application/xml' not in headers[HTTP_CONTENT_TYPE]: raise_failure("The specified value of Content-Type is not supported") @@ -132,8 +142,7 @@ def validate_request_headers(headers, supported_message_bindings): def validate_response_headers(headers): for h in REQUIRED_RESPONSE_HEADERS: if h not in headers: - raise ValueError( - "Required response header not specified: {}".format(h)) + raise ValueError("Required response header not specified: {}".format(h)) def make_taxii_response(taxii_xml, taxii_headers) -> Response: diff --git a/opentaxii/taxii/services/__init__.py b/opentaxii/taxii/services/__init__.py index b7b83cbd..c8128ae9 100644 --- a/opentaxii/taxii/services/__init__.py +++ b/opentaxii/taxii/services/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa -from .inbox import InboxService -from .discovery import DiscoveryService from .collection_management import CollectionManagementService +from .discovery import DiscoveryService +from .inbox import InboxService from .poll import PollService diff --git a/opentaxii/taxii/services/abstract.py b/opentaxii/taxii/services/abstract.py index a187d09d..de99efbe 100644 --- a/opentaxii/taxii/services/abstract.py +++ b/opentaxii/taxii/services/abstract.py @@ -1,13 +1,10 @@ import structlog - from libtaxii.common import generate_message_id -from libtaxii.constants import ( - VID_TAXII_XML_10, VID_TAXII_XML_11 -) +from libtaxii.constants import VID_TAXII_XML_10, VID_TAXII_XML_11 -from ..exceptions import raise_failure from ..bindings import PROTOCOL_TO_SCHEME from ..converters import service_to_service_instances +from ..exceptions import raise_failure class TAXIIService: @@ -37,9 +34,17 @@ class TAXIIService: supported_message_bindings = [VID_TAXII_XML_10, VID_TAXII_XML_11] supported_protocol_bindings = () - def __init__(self, id, server, address, description=None, path=None, - protocol_bindings=None, available=True, - authentication_required=False): + def __init__( + self, + id, + server, + address, + description=None, + path=None, + protocol_bindings=None, + available=True, + authentication_required=False, + ): self.id = id self.server = server @@ -49,19 +54,21 @@ def __init__(self, id, server, address, description=None, path=None, self.description = description self.supported_protocol_bindings = ( - protocol_bindings or self.supported_protocol_bindings) + protocol_bindings or self.supported_protocol_bindings + ) self.available = available self.authentication_required = authentication_required self.log = structlog.getLogger( - "{}.{}".format(self.__module__, self.__class__.__name__), - service_id=id) + "{}.{}".format(self.__module__, self.__class__.__name__), service_id=id + ) if not self.supported_protocol_bindings: self.log.warning( "No protocol bindings specified, service will be invisible", - service=self.id) + service=self.id, + ) def generate_id(self): return generate_message_id() @@ -72,7 +79,8 @@ def process(self, headers, message): "Processing message", message_id=message.message_id, message_type=message.message_type, - message_version=message.version) + message_version=message.version, + ) handler = self.get_message_handler(message) @@ -82,9 +90,9 @@ def process(self, headers, message): response_message = handler.handle_message(self, message) if not response_message: raise_failure( - "The message handler {} did not return a TAXII Message" - .format(handler), - in_response_to=message.message_id) + "The message handler {} did not return a TAXII Message".format(handler), + in_response_to=message.message_id, + ) return response_message @@ -96,10 +104,12 @@ def get_message_handler(self, message): "Message not supported", message_id=message.message_id, message_type=message.message_type, - message_version=message.version) + message_version=message.version, + ) raise_failure( "Message not supported by this service", - in_response_to=message.message_id) + in_response_to=message.message_id, + ) def to_service_instances(self, version): return service_to_service_instances(self, version) @@ -112,12 +122,11 @@ def get_absolute_address(self, binding): if scheme and not address.lower().startswith(scheme): address = scheme + address else: - self.log.warning("binding.not_recognized", - binding=binding, address=address) + self.log.warning("binding.not_recognized", binding=binding, address=address) return address def __repr__(self): - return ( - "{}(id={}, address={})" - .format(self.__class__.__name__, self.id, self.address)) + return "{}(id={}, address={})".format( + self.__class__.__name__, self.id, self.address + ) diff --git a/opentaxii/taxii/services/collection_management.py b/opentaxii/taxii/services/collection_management.py index 39441316..3b0870b8 100644 --- a/opentaxii/taxii/services/collection_management.py +++ b/opentaxii/taxii/services/collection_management.py @@ -1,39 +1,34 @@ from libtaxii.constants import ( - SVC_COLLECTION_MANAGEMENT, - MSG_COLLECTION_INFORMATION_REQUEST, MSG_FEED_INFORMATION_REQUEST, + MSG_COLLECTION_INFORMATION_REQUEST, + MSG_FEED_INFORMATION_REQUEST, MSG_MANAGE_COLLECTION_SUBSCRIPTION_REQUEST, MSG_MANAGE_FEED_SUBSCRIPTION_REQUEST, + SVC_COLLECTION_MANAGEMENT, ) from .abstract import TAXIIService -from .handlers import ( - CollectionInformationRequestHandler, - SubscriptionRequestHandler -) +from .handlers import CollectionInformationRequestHandler, SubscriptionRequestHandler class CollectionManagementService(TAXIIService): handlers = { - MSG_COLLECTION_INFORMATION_REQUEST: - CollectionInformationRequestHandler, - MSG_FEED_INFORMATION_REQUEST: - CollectionInformationRequestHandler, + MSG_COLLECTION_INFORMATION_REQUEST: CollectionInformationRequestHandler, + MSG_FEED_INFORMATION_REQUEST: CollectionInformationRequestHandler, } subscription_handlers = { - MSG_MANAGE_COLLECTION_SUBSCRIPTION_REQUEST: - SubscriptionRequestHandler, - MSG_MANAGE_FEED_SUBSCRIPTION_REQUEST: - SubscriptionRequestHandler + MSG_MANAGE_COLLECTION_SUBSCRIPTION_REQUEST: SubscriptionRequestHandler, + MSG_MANAGE_FEED_SUBSCRIPTION_REQUEST: SubscriptionRequestHandler, } service_type = SVC_COLLECTION_MANAGEMENT subscription_message = "Default subscription message" subscription_supported = True - def __init__(self, subscription_supported=True, subscription_message=None, - **kwargs): + def __init__( + self, subscription_supported=True, subscription_message=None, **kwargs + ): super(CollectionManagementService, self).__init__(**kwargs) self.subscription_message = subscription_message @@ -41,8 +36,7 @@ def __init__(self, subscription_supported=True, subscription_message=None, if self.subscription_supported: self.handlers = dict(CollectionManagementService.handlers) - self.handlers.update( - CollectionManagementService.subscription_handlers) + self.handlers.update(CollectionManagementService.subscription_handlers) @property def advertised_collections(self): @@ -61,7 +55,8 @@ def get_polling_services(self, collection): def get_subscription_services(self, collection): services = [] all_services = self.server.get_services_for_collection( - collection, 'collection_management') + collection, 'collection_management' + ) for s in all_services: if s.subscription_supported: services.append(s) diff --git a/opentaxii/taxii/services/discovery.py b/opentaxii/taxii/services/discovery.py index f91fc79c..4e950f4d 100644 --- a/opentaxii/taxii/services/discovery.py +++ b/opentaxii/taxii/services/discovery.py @@ -1,4 +1,4 @@ -from libtaxii.constants import SVC_DISCOVERY, MSG_DISCOVERY_REQUEST +from libtaxii.constants import MSG_DISCOVERY_REQUEST, SVC_DISCOVERY from .abstract import TAXIIService from .handlers import DiscoveryRequestHandler @@ -8,9 +8,7 @@ class DiscoveryService(TAXIIService): service_type = SVC_DISCOVERY - handlers = { - MSG_DISCOVERY_REQUEST: DiscoveryRequestHandler - } + handlers = {MSG_DISCOVERY_REQUEST: DiscoveryRequestHandler} advertised_services = [] diff --git a/opentaxii/taxii/services/handlers/__init__.py b/opentaxii/taxii/services/handlers/__init__.py index e36c6838..43f70f05 100644 --- a/opentaxii/taxii/services/handlers/__init__.py +++ b/opentaxii/taxii/services/handlers/__init__.py @@ -1,11 +1,8 @@ # flake8: noqa +from .collection_information_request_handlers import CollectionInformationRequestHandler from .discovery_request_handlers import DiscoveryRequestHandler from .inbox_message_handlers import InboxMessageHandler -from .collection_information_request_handlers import CollectionInformationRequestHandler - -from .poll_request_handlers import PollRequestHandler from .poll_fulfilment_request_handlers import PollFulfilmentRequestHandler - +from .poll_request_handlers import PollRequestHandler from .subscription_request_handlers import SubscriptionRequestHandler - diff --git a/opentaxii/taxii/services/handlers/base_handlers.py b/opentaxii/taxii/services/handlers/base_handlers.py index 9e3fb6c0..f58e8a68 100644 --- a/opentaxii/taxii/services/handlers/base_handlers.py +++ b/opentaxii/taxii/services/handlers/base_handlers.py @@ -1,12 +1,17 @@ +from libtaxii.common import generate_message_id from libtaxii.constants import ( - VID_TAXII_XML_10, VID_TAXII_XML_11, - VID_TAXII_SERVICES_10, VID_TAXII_SERVICES_11 + VID_TAXII_SERVICES_10, + VID_TAXII_SERVICES_11, + VID_TAXII_XML_10, + VID_TAXII_XML_11, ) -from libtaxii.common import generate_message_id from ...exceptions import raise_failure from ...http import ( - HTTP_X_TAXII_CONTENT_TYPE, HTTP_X_TAXII_SERVICES, HTTP_X_TAXII_ACCEPT) + HTTP_X_TAXII_ACCEPT, + HTTP_X_TAXII_CONTENT_TYPE, + HTTP_X_TAXII_SERVICES, +) class BaseMessageHandler: @@ -36,40 +41,51 @@ def validate_headers(cls, headers, in_response_to=None): else: raise ValueError( 'The variable "supported_request_messages" ' - 'contained a non-libtaxii message module: {}'. - format(message.__module__)) + 'contained a non-libtaxii message module: {}'.format( + message.__module__ + ) + ) no_support_service_11 = ( - taxii_services == VID_TAXII_SERVICES_11 and not supports_taxii_11) + taxii_services == VID_TAXII_SERVICES_11 and not supports_taxii_11 + ) no_support_service_10 = ( - taxii_services == VID_TAXII_SERVICES_10 and not supports_taxii_10) + taxii_services == VID_TAXII_SERVICES_10 and not supports_taxii_10 + ) if no_support_service_11 or no_support_service_10: raise_failure( 'The specified value of {} is not supported'.format( - HTTP_X_TAXII_SERVICES), - in_response_to) + HTTP_X_TAXII_SERVICES + ), + in_response_to, + ) no_support_content_type_11 = ( - taxii_content_type == VID_TAXII_XML_11 and not supports_taxii_11) + taxii_content_type == VID_TAXII_XML_11 and not supports_taxii_11 + ) no_support_content_type_10 = ( - taxii_content_type == VID_TAXII_XML_10 and not supports_taxii_10) + taxii_content_type == VID_TAXII_XML_10 and not supports_taxii_10 + ) if no_support_content_type_11 or no_support_content_type_10: raise_failure( 'The specified value of X-TAXII-Content-Type is not supported', - in_response_to) + in_response_to, + ) no_support_accept_11 = ( - taxii_accept == VID_TAXII_XML_11 and not supports_taxii_11) + taxii_accept == VID_TAXII_XML_11 and not supports_taxii_11 + ) no_support_accept_10 = ( - taxii_accept == VID_TAXII_XML_10 and not supports_taxii_10) + taxii_accept == VID_TAXII_XML_10 and not supports_taxii_10 + ) if taxii_accept and (no_support_accept_11 or no_support_accept_10): raise_failure( - "The specified value of X-TAXII-Accept is not supported", - in_response_to) + "The specified value of X-TAXII-Accept is not supported", in_response_to + ) return True @@ -78,7 +94,8 @@ def verify_message_is_supported(cls, taxii_message): if taxii_message.__class__ not in cls.supported_request_messages: raise raise_failure( "TAXII Message not supported by Message Handler", - taxii_message.message_id) + taxii_message.message_id, + ) @classmethod def handle_message(cls, service, request): diff --git a/opentaxii/taxii/services/handlers/collection_information_request_handlers.py b/opentaxii/taxii/services/handlers/collection_information_request_handlers.py index 1d387165..a012edce 100644 --- a/opentaxii/taxii/services/handlers/collection_information_request_handlers.py +++ b/opentaxii/taxii/services/handlers/collection_information_request_handlers.py @@ -1,14 +1,10 @@ - -from .base_handlers import BaseMessageHandler - -import libtaxii.messages_11 as tm11 import libtaxii.messages_10 as tm10 +import libtaxii.messages_11 as tm11 from opentaxii.taxii.exceptions import raise_failure -from ...converters import ( - collection_to_feedcollection_information -) +from ...converters import collection_to_feedcollection_information +from .base_handlers import BaseMessageHandler class CollectionInformationRequest11Handler(BaseMessageHandler): @@ -19,11 +15,13 @@ class CollectionInformationRequest11Handler(BaseMessageHandler): def handle_message(cls, service, request): response = tm11.CollectionInformationResponse( - message_id=cls.generate_id(), in_response_to=request.message_id) + message_id=cls.generate_id(), in_response_to=request.message_id + ) for collection in service.advertised_collections: coll = collection_to_feedcollection_information( - service, collection, version=11) + service, collection, version=11 + ) response.collection_informations.append(coll) return response @@ -37,12 +35,13 @@ class FeedInformationRequest10Handler(BaseMessageHandler): def handle_message(cls, service, request): response = tm10.FeedInformationResponse( - message_id=cls.generate_id(), - in_response_to=request.message_id) + message_id=cls.generate_id(), in_response_to=request.message_id + ) for collection in service.advertised_collections: feed = collection_to_feedcollection_information( - service, collection, version=10) + service, collection, version=10 + ) response.feed_informations.append(feed) return response @@ -51,17 +50,19 @@ def handle_message(cls, service, request): class CollectionInformationRequestHandler(BaseMessageHandler): supported_request_messages = [ - tm10.FeedInformationRequest, tm11.CollectionInformationRequest] + tm10.FeedInformationRequest, + tm11.CollectionInformationRequest, + ] @classmethod def handle_message(cls, service, request): if isinstance(request, tm10.FeedInformationRequest): - return FeedInformationRequest10Handler.handle_message( - service, request) + return FeedInformationRequest10Handler.handle_message(service, request) if isinstance(request, tm11.CollectionInformationRequest): return CollectionInformationRequest11Handler.handle_message( - service, request) + service, request + ) raise_failure( - "TAXII Message not supported by message handler", - request.message_id) + "TAXII Message not supported by message handler", request.message_id + ) diff --git a/opentaxii/taxii/services/handlers/discovery_request_handlers.py b/opentaxii/taxii/services/handlers/discovery_request_handlers.py index 1590bbf6..9b21c633 100644 --- a/opentaxii/taxii/services/handlers/discovery_request_handlers.py +++ b/opentaxii/taxii/services/handlers/discovery_request_handlers.py @@ -1,9 +1,8 @@ +import libtaxii.messages_10 as tm10 +import libtaxii.messages_11 as tm11 -from .base_handlers import BaseMessageHandler from ...exceptions import raise_failure - -import libtaxii.messages_11 as tm11 -import libtaxii.messages_10 as tm10 +from .base_handlers import BaseMessageHandler class DiscoveryRequest11Handler(BaseMessageHandler): @@ -13,8 +12,7 @@ class DiscoveryRequest11Handler(BaseMessageHandler): @classmethod def handle_message(cls, service, request): - response = tm11.DiscoveryResponse( - cls.generate_id(), request.message_id) + response = tm11.DiscoveryResponse(cls.generate_id(), request.message_id) for service in service.advertised_services: service_instances = service.to_service_instances(version=11) response.service_instances.extend(service_instances) @@ -29,8 +27,7 @@ class DiscoveryRequest10Handler(BaseMessageHandler): @classmethod def handle_message(cls, service, request): - response = tm10.DiscoveryResponse( - cls.generate_id(), request.message_id) + response = tm10.DiscoveryResponse(cls.generate_id(), request.message_id) for service in service.advertised_services: service_instances = service.to_service_instances(version=10) diff --git a/opentaxii/taxii/services/handlers/inbox_message_handlers.py b/opentaxii/taxii/services/handlers/inbox_message_handlers.py index b92def38..c5ce3535 100644 --- a/opentaxii/taxii/services/handlers/inbox_message_handlers.py +++ b/opentaxii/taxii/services/handlers/inbox_message_handlers.py @@ -1,15 +1,14 @@ -import structlog - -import libtaxii.messages_11 as tm11 import libtaxii.messages_10 as tm10 +import libtaxii.messages_11 as tm11 +import structlog from libtaxii.constants import ST_SUCCESS -from .base_handlers import BaseMessageHandler -from ...exceptions import raise_failure from ...converters import ( + content_block_to_content_block_entity, inbox_message_to_inbox_message_entity, - content_block_to_content_block_entity ) +from ...exceptions import raise_failure +from .base_handlers import BaseMessageHandler log = structlog.getLogger(__name__) @@ -22,51 +21,60 @@ class InboxMessage11Handler(BaseMessageHandler): def handle_message(cls, service, request): collections = service.validate_destination_collection_names( - request.destination_collection_names, request.message_id) + request.destination_collection_names, request.message_id + ) inbox_message = service.server.persistence.create_inbox_message( inbox_message_to_inbox_message_entity( - request, service_id=service.id, version=11)) + request, service_id=service.id, version=11 + ) + ) for content_block in request.content_blocks: is_supported = service.is_content_supported( - content_block.content_binding, version=11) + content_block.content_binding, version=11 + ) # FIXME: is it correct to skip unsupported content blocks? # 3.2 Inbox Exchange # version1.1/TAXII_Services_Specification.pdf if not is_supported: - log.warning("Content binding is not supported: {}" - .format(content_block.content_binding)) + log.warning( + "Content binding is not supported: {}".format( + content_block.content_binding + ) + ) continue correct_binding_collections = [ - c for c in collections - if c.is_content_supported(content_block.content_binding)] + c + for c in collections + if c.is_content_supported(content_block.content_binding) + ] if not correct_binding_collections: # There's nothing to add this content block to log.warning( "No accessible collection that support " - "binding {} were found" - .format(content_block.content_binding)) + "binding {} were found".format(content_block.content_binding) + ) continue - block = content_block_to_content_block_entity( - content_block, version=11) + block = content_block_to_content_block_entity(content_block, version=11) service.server.persistence.create_content( block, collections=correct_binding_collections, service_id=service.id, - inbox_message_id=inbox_message.id if inbox_message else None) + inbox_message_id=inbox_message.id if inbox_message else None, + ) # Create and return a Status Message indicating success status_message = tm11.StatusMessage( message_id=cls.generate_id(), in_response_to=request.message_id, - status_type=ST_SUCCESS + status_type=ST_SUCCESS, ) return status_message @@ -83,29 +91,36 @@ def handle_message(cls, service, request): inbox_message = service.server.persistence.create_inbox_message( inbox_message_to_inbox_message_entity( - request, service_id=service.id, version=10)) + request, service_id=service.id, version=10 + ) + ) for content_block in request.content_blocks: is_supported = service.is_content_supported( - content_block.content_binding, version=10) + content_block.content_binding, version=10 + ) if not is_supported: - log.warning("Content block binding is not supported: {}" - .format(content_block.content_binding)) + log.warning( + "Content block binding is not supported: {}".format( + content_block.content_binding + ) + ) continue - block = content_block_to_content_block_entity( - content_block, version=10) + block = content_block_to_content_block_entity(content_block, version=10) service.server.persistence.create_content( - block, collections=collections, + block, + collections=collections, service_id=service.id, - inbox_message_id=inbox_message.id if inbox_message else None) + inbox_message_id=inbox_message.id if inbox_message else None, + ) status_message = tm10.StatusMessage( message_id=cls.generate_id(), in_response_to=request.message_id, - status_type=ST_SUCCESS + status_type=ST_SUCCESS, ) return status_message diff --git a/opentaxii/taxii/services/handlers/poll_fulfilment_request_handlers.py b/opentaxii/taxii/services/handlers/poll_fulfilment_request_handlers.py index 81b55084..de6e59eb 100644 --- a/opentaxii/taxii/services/handlers/poll_fulfilment_request_handlers.py +++ b/opentaxii/taxii/services/handlers/poll_fulfilment_request_handlers.py @@ -1,10 +1,7 @@ import libtaxii.messages_11 as tm11 -from libtaxii.constants import ( - ST_NOT_FOUND, SD_ITEM -) +from libtaxii.constants import SD_ITEM, ST_NOT_FOUND from ...exceptions import StatusMessageException, raise_failure - from .base_handlers import BaseMessageHandler from .poll_request_handlers import PollRequest11Handler, retrieve_collection @@ -25,13 +22,15 @@ def handle_message(cls, service, request): result_set = service.get_result_set(result_id) collection = retrieve_collection( - 11, service, collection_name, in_response_to=request.message_id) + 11, service, collection_name, in_response_to=request.message_id + ) if not result_set or result_set.collection_id != collection.id: raise StatusMessageException( ST_NOT_FOUND, in_response_to=request.message_id, - status_details={SD_ITEM: result_id}) + status_details={SD_ITEM: result_id}, + ) response = PollRequest11Handler.prepare_poll_response( service=service, @@ -42,7 +41,8 @@ def handle_message(cls, service, request): result_part=part_number, allow_async=True, return_content=True, - result_id=result_id) + result_id=result_id, + ) return response @@ -60,5 +60,6 @@ def handle_message(cls, service, request): service=service, request=request, ) - raise_failure("TAXII Message not supported by message handler", - request.message_id) + raise_failure( + "TAXII Message not supported by message handler", request.message_id + ) diff --git a/opentaxii/taxii/services/handlers/poll_request_handlers.py b/opentaxii/taxii/services/handlers/poll_request_handlers.py index 73b71d7d..9f6dc121 100644 --- a/opentaxii/taxii/services/handlers/poll_request_handlers.py +++ b/opentaxii/taxii/services/handlers/poll_request_handlers.py @@ -1,16 +1,27 @@ import libtaxii.messages_10 as tm10 import libtaxii.messages_11 as tm11 import structlog -from libtaxii.constants import (RT_FULL, SD_ESTIMATED_WAIT, SD_ITEM, - SD_RESULT_ID, SD_SUPPORTED_CONTENT, - SD_WILL_PUSH, ST_DENIED, ST_NOT_FOUND, - ST_PENDING, ST_UNSUPPORTED_CONTENT_BINDING) +from libtaxii.constants import ( + RT_FULL, + SD_ESTIMATED_WAIT, + SD_ITEM, + SD_RESULT_ID, + SD_SUPPORTED_CONTENT, + SD_WILL_PUSH, + ST_DENIED, + ST_NOT_FOUND, + ST_PENDING, + ST_UNSUPPORTED_CONTENT_BINDING, +) + from opentaxii.local import context from ....persistence.exceptions import ResultsNotReady -from ...converters import (content_binding_entities_to_content_bindings, - content_block_entity_to_content_block, - parse_content_bindings) +from ...converters import ( + content_binding_entities_to_content_bindings, + content_block_entity_to_content_block, + parse_content_bindings, +) from ...exceptions import FailureStatus, StatusMessageException, raise_failure from ...utils import get_utc_now from .base_handlers import BaseMessageHandler @@ -22,13 +33,13 @@ def retrieve_subscription(version, service, subscription_id, in_response_to): subscription = service.get_subscription(subscription_id) if not subscription: message = "Requested subscription was not found" - details = ( - {SD_ITEM: subscription_id} if version == 11 else subscription_id) + details = {SD_ITEM: subscription_id} if version == 11 else subscription_id raise StatusMessageException( ST_NOT_FOUND, message=message, in_response_to=in_response_to, - status_details=details) + status_details=details, + ) return subscription @@ -40,12 +51,14 @@ def retrieve_collection(version, service, collection_name, in_response_to): ST_NOT_FOUND, message="Requested collection was not found", in_response_to=in_response_to, - status_details=details) + status_details=details, + ) if not collection.available: raise FailureStatus( message="The collection is not available", in_response_to=in_response_to, - status_details=details) + status_details=details, + ) return collection @@ -60,25 +73,30 @@ def handle_message(cls, service, request): raise StatusMessageException( ST_DENIED, message="Subscription id is required", - in_response_to=request.message_id) + in_response_to=request.message_id, + ) if request.subscription_id and request.poll_parameters: message = "Both subscription ID and Poll Parameters present" - log.warning(message, service_id=service.id, - subscription_id=request.subscription_id) + log.warning( + message, service_id=service.id, subscription_id=request.subscription_id + ) collection = retrieve_collection( - 11, service, request.collection_name, request.message_id) + 11, service, request.collection_name, request.message_id + ) if request.subscription_id: subscription = retrieve_subscription( - 11, service, request.subscription_id, request.message_id) + 11, service, request.subscription_id, request.message_id + ) if collection.id != subscription.collection_id: raise StatusMessageException( ST_NOT_FOUND, status_details={SD_ITEM: request.collection_name}, - in_response_to=request.message_id) + in_response_to=request.message_id, + ) content_bindings = subscription.params.content_bindings response_type = subscription.params.response_type @@ -87,18 +105,18 @@ def handle_message(cls, service, request): else: params = request.poll_parameters raw_bindings = params.content_bindings - requested_bindings = parse_content_bindings( - raw_bindings, version=11) - content_bindings = collection.get_matching_bindings( - requested_bindings) + requested_bindings = parse_content_bindings(raw_bindings, version=11) + content_bindings = collection.get_matching_bindings(requested_bindings) if requested_bindings and not content_bindings: supported = content_binding_entities_to_content_bindings( - collection.supported_content, version=11) + collection.supported_content, version=11 + ) raise StatusMessageException( ST_UNSUPPORTED_CONTENT_BINDING, in_response_to=request.message_id, - status_details={SD_SUPPORTED_CONTENT: supported}) + status_details={SD_SUPPORTED_CONTENT: supported}, + ) response_type = params.response_type allow_async = params.allow_asynch @@ -106,8 +124,10 @@ def handle_message(cls, service, request): end = request.inclusive_end_timestamp_label if (start and end) and (start > end): - message = ("Exclusive begin timestamp label is later " - "than inclusive end timestamp label") + message = ( + "Exclusive begin timestamp label is later " + "than inclusive end timestamp label" + ) raise_failure(message, request.message_id) return cls.prepare_poll_response( @@ -118,13 +138,23 @@ def handle_message(cls, service, request): in_response_to=request.message_id, allow_async=allow_async, return_content=(response_type == RT_FULL), - subscription_id=request.subscription_id) + subscription_id=request.subscription_id, + ) @classmethod def prepare_poll_response( - cls, service, collection, in_response_to, timeframe=None, - content_bindings=None, result_part=1, allow_async=False, - return_content=True, result_id=None, subscription_id=None): + cls, + service, + collection, + in_response_to, + timeframe=None, + content_bindings=None, + result_part=1, + allow_async=False, + return_content=True, + result_id=None, + subscription_id=None, + ): timeframe = timeframe or (None, None) try: @@ -132,23 +162,26 @@ def prepare_poll_response( collection, timeframe=timeframe, content_bindings=content_bindings, - part_number=result_part) + part_number=result_part, + ) except ResultsNotReady: if not allow_async: - message = ("The content is not available now and " - "the request has allow_asynch set to false") - raise_failure( - message=message, in_response_to=in_response_to) + message = ( + "The content is not available now and " + "the request has allow_asynch set to false" + ) + raise_failure(message=message, in_response_to=in_response_to) result_set = service.create_result_set( - collection, timeframe=timeframe, - content_bindings=content_bindings) + collection, timeframe=timeframe, content_bindings=content_bindings + ) if not result_set: raise StatusMessageException( ST_DENIED, message="Poll fulfilment is not supported", - in_response_to=in_response_to) + in_response_to=in_response_to, + ) return tm11.StatusMessage( message_id=service.generate_id(), @@ -157,7 +190,9 @@ def prepare_poll_response( status_detail={ SD_ESTIMATED_WAIT: service.wait_time, SD_RESULT_ID: result_set.id, - SD_WILL_PUSH: service.can_push}) + SD_WILL_PUSH: service.can_push, + }, + ) # TODO: temporary fix, pending: # https://github.com/TAXIIProject/libtaxii/issues/191 @@ -166,12 +201,11 @@ def prepare_poll_response( if context.server.servers.taxii1.config['count_blocks_in_poll_responses']: # dividing instead of multiplying to be safe from overflow total_count = service.get_content_blocks_count( - collection, timeframe=timeframe, - content_bindings=content_bindings) - has_more = ( - (float(total_count) / service.max_result_size) > result_part) + collection, timeframe=timeframe, content_bindings=content_bindings + ) + has_more = (float(total_count) / service.max_result_size) > result_part capped_count = min(service.max_result_count, total_count) - is_partial = (capped_count < total_count) + is_partial = capped_count < total_count else: has_more = len(content_blocks) == service.max_result_size capped_count = None @@ -179,9 +213,8 @@ def prepare_poll_response( if has_more and not result_id: result_set = service.create_result_set( - collection, - timeframe=timeframe, - content_bindings=content_bindings) + collection, timeframe=timeframe, content_bindings=content_bindings + ) result_id = result_set.id response = tm11.PollResponse( @@ -198,13 +231,16 @@ def prepare_poll_response( record_count=( tm11.RecordCount(int(capped_count), is_partial) if capped_count is not None - else None), - subscription_id=subscription_id) + else None + ), + subscription_id=subscription_id, + ) if return_content: for block in content_blocks: response.content_blocks.append( - content_block_entity_to_content_block(block, version=11)) + content_block_entity_to_content_block(block, version=11) + ) return response @@ -217,50 +253,60 @@ class PollRequest10Handler(BaseMessageHandler): def handle_message(cls, service, request): collection = retrieve_collection( - 10, service, request.feed_name, request.message_id) + 10, service, request.feed_name, request.message_id + ) if request.subscription_id: subscription = retrieve_subscription( - 10, service, request.subscription_id, request.message_id) + 10, service, request.subscription_id, request.message_id + ) if collection.id != subscription.collection_id: details = {SD_ITEM: request.collection_name} raise StatusMessageException( - ST_NOT_FOUND, status_details=details, - in_response_to=request.message_id) + ST_NOT_FOUND, + status_details=details, + in_response_to=request.message_id, + ) content_bindings = subscription.params.content_bindings else: requested_bindings = parse_content_bindings( - request.content_bindings, version=10) + request.content_bindings, version=10 + ) - content_bindings = collection.get_matching_bindings( - requested_bindings) + content_bindings = collection.get_matching_bindings(requested_bindings) if requested_bindings and not content_bindings: - supported_bindings = ( - content_binding_entities_to_content_bindings( - collection.supported_content, version=10)) + supported_bindings = content_binding_entities_to_content_bindings( + collection.supported_content, version=10 + ) details = {SD_SUPPORTED_CONTENT: supported_bindings} raise StatusMessageException( ST_UNSUPPORTED_CONTENT_BINDING, in_response_to=request.message_id, - status_details=details) + status_details=details, + ) # Only Data Feeds existed in TAXII 1.0 if collection.type != collection.TYPE_FEED: - message = ("The Named Data Collection is not a Data Feed, " - "it is a Data Set. Only Data Feeds can be polled " - "in TAXII 1.0") + message = ( + "The Named Data Collection is not a Data Feed, " + "it is a Data Set. Only Data Feeds can be polled " + "in TAXII 1.0" + ) raise StatusMessageException( ST_NOT_FOUND, message=message, status_details={SD_ITEM: request.feed_name}, - in_response_to=request.message_id) + in_response_to=request.message_id, + ) - start, end = (request.exclusive_begin_timestamp_label, - request.inclusive_end_timestamp_label) + start, end = ( + request.exclusive_begin_timestamp_label, + request.inclusive_end_timestamp_label, + ) end_response = end or get_utc_now() @@ -268,19 +314,19 @@ def handle_message(cls, service, request): message_id=service.generate_id(), in_response_to=request.message_id, feed_name=collection.name, - # FIXME: exclusive/inclusive clash inclusive_begin_timestamp_label=start, - inclusive_end_timestamp_label=end_response) + inclusive_end_timestamp_label=end_response, + ) content_blocks = service.get_content_blocks( - collection, - timeframe=(start, end), - content_bindings=content_bindings) + collection, timeframe=(start, end), content_bindings=content_bindings + ) for block in content_blocks: response.content_blocks.append( - content_block_entity_to_content_block(block, version=10)) + content_block_entity_to_content_block(block, version=10) + ) return response diff --git a/opentaxii/taxii/services/handlers/subscription_request_handlers.py b/opentaxii/taxii/services/handlers/subscription_request_handlers.py index af4d07aa..fd100f66 100644 --- a/opentaxii/taxii/services/handlers/subscription_request_handlers.py +++ b/opentaxii/taxii/services/handlers/subscription_request_handlers.py @@ -1,22 +1,24 @@ -import structlog - -import libtaxii.messages_11 as tm11 import libtaxii.messages_10 as tm10 +import libtaxii.messages_11 as tm11 +import structlog from libtaxii.constants import ( - SD_SUPPORTED_CONTENT, ST_UNSUPPORTED_CONTENT_BINDING, - SD_ITEM, ST_NOT_FOUND, ST_BAD_MESSAGE, - ACT_SUBSCRIBE, ACT_UNSUBSCRIBE, ACT_PAUSE, - ACT_RESUME, ACT_STATUS, ACT_TYPES_11, - ACT_TYPES_10 + ACT_PAUSE, + ACT_RESUME, + ACT_STATUS, + ACT_SUBSCRIBE, + ACT_TYPES_10, + ACT_TYPES_11, + ACT_UNSUBSCRIBE, + SD_ITEM, + SD_SUPPORTED_CONTENT, + ST_BAD_MESSAGE, + ST_NOT_FOUND, + ST_UNSUPPORTED_CONTENT_BINDING, ) -from ...exceptions import StatusMessageException, raise_failure -from ...converters import ( - subscription_to_subscription_instance, - parse_content_bindings -) +from ...converters import parse_content_bindings, subscription_to_subscription_instance from ...entities import PollRequestParametersEntity, SubscriptionEntity - +from ...exceptions import StatusMessageException, raise_failure from .base_handlers import BaseMessageHandler from .poll_request_handlers import retrieve_collection @@ -38,11 +40,10 @@ def action_subscribe(request, service, collection, version, **kwargs): supported_contents = [] else: requested_bindings = parse_content_bindings( - params.content_bindings, - version=version) + params.content_bindings, version=version + ) - supported_contents = \ - collection.get_matching_bindings(requested_bindings) + supported_contents = collection.get_matching_bindings(requested_bindings) if requested_bindings and not supported_contents: supported = collection.get_supported_content(version=version) @@ -50,7 +51,8 @@ def action_subscribe(request, service, collection, version, **kwargs): raise StatusMessageException( ST_UNSUPPORTED_CONTENT_BINDING, in_response_to=request.message_id, - status_details=details) + status_details=details, + ) else: supported_contents = [] @@ -67,7 +69,7 @@ def action_subscribe(request, service, collection, version, **kwargs): service_id=service.id, collection_id=collection.id, poll_request_params=poll_request_params, - status=SubscriptionEntity.ACTIVE + status=SubscriptionEntity.ACTIVE, ) return service.create_subscription(subscription) @@ -88,7 +90,8 @@ def action_unsubscribe(request, service, subscription, **kwargs): collection_id=None, service_id=service.id, subscription_id=request.subscription_id, - status=SubscriptionEntity.UNSUBSCRIBED) + status=SubscriptionEntity.UNSUBSCRIBED, + ) def action_status(service, subscription, **kwargs): @@ -131,7 +134,7 @@ def action_resume(service, subscription, **kwargs): ACT_UNSUBSCRIBE: action_unsubscribe, ACT_PAUSE: action_pause, ACT_RESUME: action_resume, - ACT_STATUS: action_status + ACT_STATUS: action_status, } @@ -147,8 +150,10 @@ def validate_request(cls, request, subscription): if action not in ACT_TYPES_11: error_message = "The specified action was invalid" - elif action in (ACT_UNSUBSCRIBE, ACT_PAUSE, ACT_RESUME) \ - and not request.subscription_id: + elif ( + action in (ACT_UNSUBSCRIBE, ACT_PAUSE, ACT_RESUME) + and not request.subscription_id + ): error_message = 'Action "%s" requires a subscription id' % action else: @@ -156,19 +161,18 @@ def validate_request(cls, request, subscription): if error_message: raise StatusMessageException( - ST_BAD_MESSAGE, - message=error_message, - in_response_to=request.message_id) + ST_BAD_MESSAGE, message=error_message, in_response_to=request.message_id + ) - if (not subscription and ( - action in (ACT_PAUSE, ACT_RESUME) or - (action == ACT_STATUS and request.subscription_id))): + if not subscription and ( + action in (ACT_PAUSE, ACT_RESUME) + or (action == ACT_STATUS and request.subscription_id) + ): details = {SD_ITEM: request.subscription_id} raise StatusMessageException( - ST_NOT_FOUND, - status_details=details, - in_response_to=request.message_id) + ST_NOT_FOUND, status_details=details, in_response_to=request.message_id + ) @classmethod def handle_message(cls, service, request): @@ -181,13 +185,15 @@ def handle_message(cls, service, request): cls.validate_request(request, subscription) collection = retrieve_collection( - 11, service, request.collection_name, request.message_id) + 11, service, request.collection_name, request.message_id + ) if subscription and subscription.collection_id != collection.id: raise StatusMessageException( ST_NOT_FOUND, status_details={SD_ITEM: request.collection_name}, - in_response_to=request.message_id) + in_response_to=request.message_id, + ) response = tm11.ManageCollectionSubscriptionResponse( message_id=cls.generate_id(), @@ -197,8 +203,12 @@ def handle_message(cls, service, request): ) result = ACTIONS[request.action]( - service=service, request=request, collection=collection, - subscription=subscription, version=11) + service=service, + request=request, + collection=collection, + subscription=subscription, + version=11, + ) if isinstance(result, (list, tuple)): results = result @@ -212,7 +222,7 @@ def handle_message(cls, service, request): subscription=_result, polling_services=polling_services, version=11, - subscription_parameters=_result.params + subscription_parameters=_result.params, ) response.subscription_instances.append(instance) @@ -237,9 +247,8 @@ def validate_request(cls, request): if error_message: raise StatusMessageException( - ST_BAD_MESSAGE, - message=error_message, - in_response_to=request.message_id) + ST_BAD_MESSAGE, message=error_message, in_response_to=request.message_id + ) @classmethod def handle_message(cls, service, request): @@ -247,7 +256,8 @@ def handle_message(cls, service, request): cls.validate_request(request) collection = retrieve_collection( - 10, service, request.feed_name, request.message_id) + 10, service, request.feed_name, request.message_id + ) if request.subscription_id: subscription = service.get_subscription(request.subscription_id) @@ -258,7 +268,8 @@ def handle_message(cls, service, request): raise StatusMessageException( ST_NOT_FOUND, status_details=request.feed_name, - in_response_to=request.message_id) + in_response_to=request.message_id, + ) response = tm10.ManageFeedSubscriptionResponse( message_id=cls.generate_id(), @@ -272,7 +283,8 @@ def handle_message(cls, service, request): request=request, collection=collection, subscription=subscription, - version=10) + version=10, + ) if not isinstance(results, (list, tuple)): results = [results] @@ -281,9 +293,8 @@ def handle_message(cls, service, request): for _result in results: instance = subscription_to_subscription_instance( - subscription=_result, - polling_services=polling_services, - version=10) + subscription=_result, polling_services=polling_services, version=10 + ) response.subscription_instances.append(instance) return response @@ -291,17 +302,17 @@ def handle_message(cls, service, request): class SubscriptionRequestHandler(BaseMessageHandler): - supported_request_messages = [tm11.ManageCollectionSubscriptionRequest, - tm10.ManageFeedSubscriptionRequest] + supported_request_messages = [ + tm11.ManageCollectionSubscriptionRequest, + tm10.ManageFeedSubscriptionRequest, + ] @classmethod def handle_message(cls, service, request): if isinstance(request, tm10.ManageFeedSubscriptionRequest): - return SubscriptionRequest10Handler.handle_message( - service, request) + return SubscriptionRequest10Handler.handle_message(service, request) if isinstance(request, tm11.ManageCollectionSubscriptionRequest): - return SubscriptionRequest11Handler.handle_message( - service, request) + return SubscriptionRequest11Handler.handle_message(service, request) raise_failure( - "TAXII Message not supported by message handler", - request.message_id) + "TAXII Message not supported by message handler", request.message_id + ) diff --git a/opentaxii/taxii/services/inbox.py b/opentaxii/taxii/services/inbox.py index 85c23421..a3e3d3c8 100644 --- a/opentaxii/taxii/services/inbox.py +++ b/opentaxii/taxii/services/inbox.py @@ -1,21 +1,22 @@ - from libtaxii.constants import ( - SVC_INBOX, MSG_INBOX_MESSAGE, SD_ACCEPTABLE_DESTINATION, - ST_DESTINATION_COLLECTION_ERROR, ST_NOT_FOUND, SD_ITEM + MSG_INBOX_MESSAGE, + SD_ACCEPTABLE_DESTINATION, + SD_ITEM, + ST_DESTINATION_COLLECTION_ERROR, + ST_NOT_FOUND, + SVC_INBOX, ) -from opentaxii.local import context from opentaxii.exceptions import UnauthorizedException - -from ..utils import is_content_supported -from ..entities import ContentBindingEntity -from ..exceptions import StatusMessageException +from opentaxii.local import context from ..converters import ( content_binding_entities_to_content_bindings, - service_to_service_instances + service_to_service_instances, ) - +from ..entities import ContentBindingEntity +from ..exceptions import StatusMessageException +from ..utils import is_content_supported from .abstract import TAXIIService from .handlers import InboxMessageHandler @@ -24,25 +25,26 @@ class InboxService(TAXIIService): service_type = SVC_INBOX - handlers = { - MSG_INBOX_MESSAGE: InboxMessageHandler - } + handlers = {MSG_INBOX_MESSAGE: InboxMessageHandler} destination_collection_required = False accept_all_content = False supported_content = [] - def __init__(self, accept_all_content=False, - destination_collection_required=False, - supported_content=None, **kwargs): + def __init__( + self, + accept_all_content=False, + destination_collection_required=False, + supported_content=None, + **kwargs, + ): super(InboxService, self).__init__(**kwargs) self.accept_all_content = accept_all_content supported_content = supported_content or [] - self.supported_content = [ - ContentBindingEntity(c) for c in supported_content] + self.supported_content = [ContentBindingEntity(c) for c in supported_content] self.destination_collection_required = destination_collection_required @@ -50,7 +52,8 @@ def is_content_supported(self, content_binding, version=None): if self.accept_all_content: return True return is_content_supported( - self.supported_content, content_binding, version=version) + self.supported_content, content_binding, version=version + ) def get_destination_collections(self): return self.server.persistence.get_collections(self.id) @@ -59,24 +62,31 @@ def validate_destination_collection_names(self, name_list, in_response_to): name_list = name_list or [] - if ((self.destination_collection_required and not name_list) - or (not self.destination_collection_required and name_list)): + if (self.destination_collection_required and not name_list) or ( + not self.destination_collection_required and name_list + ): if not name_list: - message = ('A Destination_Collection_Name is required ' - 'and none were specified') + message = ( + 'A Destination_Collection_Name is required ' + 'and none were specified' + ) else: - message = ('Destination_Collection_Names are prohibited ' - 'for this Inbox Service') + message = ( + 'Destination_Collection_Names are prohibited ' + 'for this Inbox Service' + ) details = { SD_ACCEPTABLE_DESTINATION: [ - c.name for c in self.get_destination_collections() - if c.available]} + c.name for c in self.get_destination_collections() if c.available + ] + } raise StatusMessageException( ST_DESTINATION_COLLECTION_ERROR, message=message, in_response_to=in_response_to, - status_details=details) + status_details=details, + ) # If we reach this point and name_list is empty, # self.destination_collection_required must be False @@ -84,8 +94,7 @@ def validate_destination_collection_names(self, name_list, in_response_to): name_list = [c.name for c in self.get_destination_collections()] collections = [] - destinations_map = { - c.name: c for c in self.get_destination_collections()} + destinations_map = {c.name: c for c in self.get_destination_collections()} for name in name_list: if name in destinations_map: @@ -94,14 +103,15 @@ def validate_destination_collection_names(self, name_list, in_response_to): collections.append(collection) else: raise UnauthorizedException( - message=('User can not write to collection {}' - .format(name))) + message=('User can not write to collection {}'.format(name)) + ) else: raise StatusMessageException( ST_NOT_FOUND, message='Collection {} was not found'.format(name), in_response_to=in_response_to, - extended_headers={SD_ITEM: name}) + extended_headers={SD_ITEM: name}, + ) return collections @@ -113,8 +123,9 @@ def to_service_instances(self, version): return service_instances for instance in service_instances: - instance.inbox_service_accepted_content = ( - self.get_supported_content(version)) + instance.inbox_service_accepted_content = self.get_supported_content( + version + ) return service_instances @@ -124,4 +135,5 @@ def get_supported_content(self, version): return [] return content_binding_entities_to_content_bindings( - self.supported_content, version) + self.supported_content, version + ) diff --git a/opentaxii/taxii/services/poll.py b/opentaxii/taxii/services/poll.py index e2146114..1841fa31 100644 --- a/opentaxii/taxii/services/poll.py +++ b/opentaxii/taxii/services/poll.py @@ -1,12 +1,9 @@ import structlog - -from libtaxii.constants import ( - MSG_POLL_REQUEST, MSG_POLL_FULFILLMENT_REQUEST, SVC_POLL -) +from libtaxii.constants import MSG_POLL_FULFILLMENT_REQUEST, MSG_POLL_REQUEST, SVC_POLL from ..entities import ResultSetEntity from .abstract import TAXIIService -from .handlers import PollRequestHandler, PollFulfilmentRequestHandler +from .handlers import PollFulfilmentRequestHandler, PollRequestHandler log = structlog.getLogger(__name__) @@ -18,7 +15,7 @@ class PollService(TAXIIService): handlers = { MSG_POLL_REQUEST: PollRequestHandler, - MSG_POLL_FULFILLMENT_REQUEST: PollFulfilmentRequestHandler + MSG_POLL_FULFILLMENT_REQUEST: PollFulfilmentRequestHandler, } service_type = SVC_POLL @@ -34,20 +31,25 @@ class PollService(TAXIIService): max_result_size = None max_result_count = None - def __init__(self, subscription_required=False, max_result_size=-1, - max_result_count=-1, **kwargs): + def __init__( + self, + subscription_required=False, + max_result_size=-1, + max_result_count=-1, + **kwargs, + ): super(PollService, self).__init__(**kwargs) self.subscription_required = subscription_required self.max_result_size = ( - max_result_size if max_result_size >= 0 - else DEFAULT_MAX_RESULT_SIZE) + max_result_size if max_result_size >= 0 else DEFAULT_MAX_RESULT_SIZE + ) self.max_result_count = ( - max_result_count if max_result_count >= 0 - else DEFAULT_MAX_RESULT_COUNT) + max_result_count if max_result_count >= 0 else DEFAULT_MAX_RESULT_COUNT + ) def get_collection(self, name): return self.server.persistence.get_collection(name, self.id) @@ -60,17 +62,19 @@ def get_offset_limit(self, part_number): return offset, limit def get_content_blocks_count( - self, collection, timeframe=None, content_bindings=None): + self, collection, timeframe=None, content_bindings=None + ): start_time, end_time = timeframe or (None, None) return self.server.persistence.get_content_blocks_count( collection_id=collection.id, start_time=start_time, end_time=end_time, - bindings=content_bindings) + bindings=content_bindings, + ) def get_content_blocks( - self, collection, timeframe=None, content_bindings=None, - part_number=1): + self, collection, timeframe=None, content_bindings=None, part_number=1 + ): start_time, end_time = timeframe or (None, None) offset, limit = self.get_offset_limit(part_number) return self.server.persistence.get_content_blocks( @@ -79,16 +83,16 @@ def get_content_blocks( end_time=end_time, bindings=content_bindings, offset=offset, - limit=limit) + limit=limit, + ) - def create_result_set(self, collection, content_bindings=None, - timeframe=None): + def create_result_set(self, collection, content_bindings=None, timeframe=None): entity = ResultSetEntity( id=self.generate_id(), collection_id=collection.id, content_bindings=content_bindings, - timeframe=timeframe + timeframe=timeframe, ) return self.server.persistence.create_result_set(entity) diff --git a/opentaxii/taxii/status.py b/opentaxii/taxii/status.py index a731b44a..b7ce8a0b 100644 --- a/opentaxii/taxii/status.py +++ b/opentaxii/taxii/status.py @@ -1,16 +1,14 @@ -import libtaxii.messages_11 as tm11 import libtaxii.messages_10 as tm10 +import libtaxii.messages_11 as tm11 from libtaxii.common import generate_message_id - from libtaxii.constants import ( - VID_TAXII_XML_11, VID_TAXII_XML_10, - VID_TAXII_SERVICES_10, VID_TAXII_SERVICES_11 + VID_TAXII_SERVICES_10, + VID_TAXII_SERVICES_11, + VID_TAXII_XML_10, + VID_TAXII_XML_11, ) -from .http import ( - HTTP_X_TAXII_ACCEPT, HTTP_X_TAXII_CONTENT_TYPE, - get_http_headers -) +from .http import HTTP_X_TAXII_ACCEPT, HTTP_X_TAXII_CONTENT_TYPE, get_http_headers def process_status_exception(exception, headers, is_secure): @@ -45,7 +43,8 @@ def exception_to_status(exception, format_version): extended_headers=exception.extended_headers, status_type=exception.status_type, status_detail=exception.status_details, - message=exception.message) + message=exception.message, + ) if format_version == VID_TAXII_XML_11: sm = tm11.StatusMessage(**data) elif format_version == VID_TAXII_XML_10: diff --git a/opentaxii/taxii/utils.py b/opentaxii/taxii/utils.py index 9ab12e90..c8b1a070 100644 --- a/opentaxii/taxii/utils.py +++ b/opentaxii/taxii/utils.py @@ -2,11 +2,11 @@ import pytz import structlog -from lxml import etree from libtaxii.common import set_xml_parser +from lxml import etree -from .exceptions import BadMessageStatus from .bindings import MESSAGE_VALIDATOR_PARSER +from .exceptions import BadMessageStatus log = structlog.getLogger(__name__) @@ -25,12 +25,14 @@ def is_content_supported(supported_bindings, content_binding, version=None): # FIXME: may be not the best option subtype = ( - content_binding.subtype_ids[0] if content_binding.subtype_ids - else None) + content_binding.subtype_ids[0] if content_binding.subtype_ids else None + ) matches = [ - ((supported.binding == binding_id) and - (not supported.subtypes or subtype in supported.subtypes)) + ( + (supported.binding == binding_id) + and (not supported.subtypes or subtype in supported.subtypes) + ) for supported in supported_bindings ] @@ -47,8 +49,10 @@ def parse_message(content_type, body, do_validate=True): if not result.valid: errors = '; '.join([str(err) for err in result.error_log]) raise BadMessageStatus( - 'Request was not schema valid: "{}" for content type "{}"' - .format(errors, content_type)) + 'Request was not schema valid: "{}" for content type "{}"'.format( + errors, content_type + ) + ) except etree.XMLSyntaxError as e: log.error("Invalid XML received", exc_info=True) raise BadMessageStatus('Request was invalid XML', e=e) @@ -63,34 +67,37 @@ def configure_libtaxii_xml_parser(huge_tree=False): Set custom XML parser as a default libtaxii parser ''' # set XML parser in libraxii right away - set_xml_parser(etree.XMLParser( - # inject default attributes from DTD or XMLSchema - attribute_defaults=False, - # validate against a DTD referenced by the document - dtd_validation=False, - # use DTD for parsing - load_dtd=False, - # prevent network access for related files (default: True) - no_network=True, - # clean up redundant namespace declarations - ns_clean=True, - # try hard to parse through broken XML - recover=False, - # discard blank text nodes that appear ignorable - remove_blank_text=False, - # discard comments - remove_comments=False, - # discard processing instructions - remove_pis=False, - # replace CDATA sections by normal text content (default: True) - strip_cdata=True, - # save memory for short text content (default: True) - compact=True, - # use a hash table of XML IDs for fast access - # (default: True, always True with DTD validation) - collect_ids=True, - # replace entities by their text value (default: True) - resolve_entities=False, - # enable/disable security restrictions and support very deep - # trees and very long text content - huge_tree=huge_tree)) + set_xml_parser( + etree.XMLParser( + # inject default attributes from DTD or XMLSchema + attribute_defaults=False, + # validate against a DTD referenced by the document + dtd_validation=False, + # use DTD for parsing + load_dtd=False, + # prevent network access for related files (default: True) + no_network=True, + # clean up redundant namespace declarations + ns_clean=True, + # try hard to parse through broken XML + recover=False, + # discard blank text nodes that appear ignorable + remove_blank_text=False, + # discard comments + remove_comments=False, + # discard processing instructions + remove_pis=False, + # replace CDATA sections by normal text content (default: True) + strip_cdata=True, + # save memory for short text content (default: True) + compact=True, + # use a hash table of XML IDs for fast access + # (default: True, always True with DTD validation) + collect_ids=True, + # replace entities by their text value (default: True) + resolve_entities=False, + # enable/disable security restrictions and support very deep + # trees and very long text content + huge_tree=huge_tree, + ) + ) diff --git a/opentaxii/taxii2/entities.py b/opentaxii/taxii2/entities.py index e660f225..2da976c1 100644 --- a/opentaxii/taxii2/entities.py +++ b/opentaxii/taxii2/entities.py @@ -1,4 +1,5 @@ """Taxii2 entities.""" + from datetime import datetime from typing import List, NamedTuple, Optional diff --git a/opentaxii/taxii2/exceptions.py b/opentaxii/taxii2/exceptions.py index a6bc44a6..8039514e 100644 --- a/opentaxii/taxii2/exceptions.py +++ b/opentaxii/taxii2/exceptions.py @@ -1,5 +1,4 @@ -from marshmallow.exceptions import \ - ValidationError as MarshmallowValidationError +from marshmallow.exceptions import ValidationError as MarshmallowValidationError class ValidationError(MarshmallowValidationError): diff --git a/opentaxii/taxii2/http.py b/opentaxii/taxii2/http.py index a0408d33..eb0a82ba 100644 --- a/opentaxii/taxii2/http.py +++ b/opentaxii/taxii2/http.py @@ -1,11 +1,14 @@ """Taxii2 http helper functions.""" + import json from typing import Dict, Optional from flask import Response, make_response -def make_taxii2_response(data, status: Optional[int] = 200, extra_headers: Optional[Dict] = None) -> Response: +def make_taxii2_response( + data, status: Optional[int] = 200, extra_headers: Optional[Dict] = None +) -> Response: """Turn input data into valid taxii2 response.""" if not isinstance(data, str): data = json.dumps(data) diff --git a/opentaxii/taxii2/utils.py b/opentaxii/taxii2/utils.py index 2893fbcb..d4598556 100644 --- a/opentaxii/taxii2/utils.py +++ b/opentaxii/taxii2/utils.py @@ -1,4 +1,5 @@ """Utility functions for taxii2.""" + import datetime DATETIMEFORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" diff --git a/opentaxii/taxii2/validation.py b/opentaxii/taxii2/validation.py index 07a8e3b5..4c7c6903 100644 --- a/opentaxii/taxii2/validation.py +++ b/opentaxii/taxii2/validation.py @@ -1,15 +1,17 @@ """Taxii2 validation functions.""" + import datetime import json from marshmallow import Schema, fields -from opentaxii.persistence.api import OpenTAXII2PersistenceAPI -from opentaxii.taxii2.exceptions import ValidationError -from opentaxii.taxii2.utils import DATETIMEFORMAT from stix2 import parse from stix2.exceptions import STIXError from werkzeug.datastructures import ImmutableMultiDict +from opentaxii.persistence.api import OpenTAXII2PersistenceAPI +from opentaxii.taxii2.exceptions import ValidationError +from opentaxii.taxii2.utils import DATETIMEFORMAT + def validate_envelope(json_data: str, allow_custom: bool = False) -> None: """ @@ -85,6 +87,7 @@ def _deserialize(self, value, attr, data, **kwargs): class PersistenceApiMxin: """Store persistence api on schema instance, to reference in `Taxii2Next`""" + def __init__(self, persistence_api: OpenTAXII2PersistenceAPI, *args, **kwargs): self.persistence_api = persistence_api super().__init__(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index d25ff27c..439585a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,3 +10,11 @@ ignore-semiprivate = true omit-covered-files = true verbose = 0 exclude = ["tests"] + +[tool.black] +line-length = 88 +skip_string_normalization = true +target_version = ["py310", "py311"] + +[tool.isort] +profile = "black" diff --git a/requirements-dev.txt b/requirements-dev.txt index 16a1ba0f..f4f574a3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,4 +5,6 @@ pytest-pythonpath flake8 ipdb factory-boy>=3.2.1 +black==25.11.0 +isort==7.0.0 -r requirements-interrogate.txt diff --git a/tests/conftest.py b/tests/conftest.py index 56192d9a..e9ef3054 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,18 +5,30 @@ import pytest from flask.testing import FlaskClient + from opentaxii.config import ServerConfig from opentaxii.local import context, release_context from opentaxii.middleware import create_app -from opentaxii.persistence.sqldb.taxii2models import (ApiRoot, Collection, Job, - JobDetail, STIXObject) +from opentaxii.persistence.sqldb.taxii2models import ( + ApiRoot, + Collection, + Job, + JobDetail, + STIXObject, +) from opentaxii.server import TAXIIServer from opentaxii.taxii.converters import dict_to_service_entity from opentaxii.taxii.http import HTTP_AUTHORIZATION from opentaxii.utils import configure_logging - -from tests.fixtures import (ACCOUNT, COLLECTIONS_B, DOMAIN, PASSWORD, SERVICES, - USERNAME, VALID_TOKEN) +from tests.fixtures import ( + ACCOUNT, + COLLECTIONS_B, + DOMAIN, + PASSWORD, + SERVICES, + USERNAME, + VALID_TOKEN, +) from tests.taxii2.utils import API_ROOTS, COLLECTIONS, JOBS, STIX_OBJECTS @@ -48,7 +60,6 @@ def dbconn(): except FileNotFoundError: pass - elif DBTYPE in ("mysql", "mariadb"): import MySQLdb @@ -61,7 +72,6 @@ def dbconn(): port = 3307 yield f"mysql+mysqldb://root:@127.0.0.1:{port}/test?charset=utf8" - elif DBTYPE == "postgres": import platform @@ -75,7 +85,6 @@ def dbconn(): def dbconn(): yield "postgresql+psycopg2://test:test@127.0.0.1:5432/test" - else: raise NotImplementedError(f"dbtype {DBTYPE} not supported") @@ -135,7 +144,7 @@ def anonymous_user(): def clean_db(dbconn): # drop and recreate db to provide clean state at beginning if DBTYPE == "sqlite": - filename = dbconn[len("sqlite:///"):] + filename = dbconn[len("sqlite:///") :] os.remove(filename) elif DBTYPE == "postgres": with psycopg2.connect( @@ -207,7 +216,7 @@ def transaction_app(dbconn, taxiiserver): connections.append(connection) sessions.append(manager.api.db.session) yield app - for (transaction, connection, session, manager) in zip( + for transaction, connection, session, manager in zip( transactions, connections, sessions, managers ): transaction.rollback() @@ -270,14 +279,17 @@ def authenticated_client(client): } client.headers = headers client.account = ACCOUNT - with patch.object( - client.application.taxii_server.auth.api, - "authenticate", - side_effect=MOCK_AUTHENTICATE, - ), patch.object( - client.application.taxii_server.auth.api, - "get_account", - side_effect=MOCK_GET_ACCOUNT, + with ( + patch.object( + client.application.taxii_server.auth.api, + "authenticate", + side_effect=MOCK_AUTHENTICATE, + ), + patch.object( + client.application.taxii_server.auth.api, + "get_account", + side_effect=MOCK_GET_ACCOUNT, + ), ): yield client diff --git a/tests/fixtures.py b/tests/fixtures.py index 6fae04bd..c0246caa 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,7 +1,7 @@ from uuid import uuid4 -from libtaxii.constants import (CB_STIX_XML_111, VID_TAXII_HTTP_10, - VID_TAXII_HTTPS_10) +from libtaxii.constants import CB_STIX_XML_111, VID_TAXII_HTTP_10, VID_TAXII_HTTPS_10 + from opentaxii.entities import Account from opentaxii.taxii import entities @@ -17,7 +17,7 @@ destination_collection_required=False, address='/relative/path/inbox-a', accept_all_content=True, - protocol_bindings=PROTOCOL_BINDINGS + protocol_bindings=PROTOCOL_BINDINGS, ) INBOX_B = dict( @@ -27,7 +27,7 @@ destination_collection_required='yes', address='/relative/path/inbox-b', supported_content=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING], - protocol_bindings=PROTOCOL_BINDINGS + protocol_bindings=PROTOCOL_BINDINGS, ) DISCOVERY_A = dict( @@ -36,9 +36,14 @@ description='discovery-A description', address='/relative/path/discovery-a', advertised_services=[ - 'inbox-A', 'inbox-B', 'discovery-A', 'discovery-B', - 'collection-management-A', 'poll-A'], - protocol_bindings=PROTOCOL_BINDINGS + 'inbox-A', + 'inbox-B', + 'discovery-A', + 'discovery-B', + 'collection-management-A', + 'poll-A', + ], + protocol_bindings=PROTOCOL_BINDINGS, ) DISCOVERY_B = dict( @@ -46,7 +51,7 @@ type='discovery', description='External discovery-B service', address='http://something.com/absolute/path/discovery-b', - protocol_bindings=[VID_TAXII_HTTP_10] + protocol_bindings=[VID_TAXII_HTTP_10], ) SUBSCRIPTION_MESSAGE = 'message about subscription' @@ -57,7 +62,7 @@ description='Collection management description', address='/relative/path/collection-management', protocol_bindings=PROTOCOL_BINDINGS, - subscription_message=SUBSCRIPTION_MESSAGE + subscription_message=SUBSCRIPTION_MESSAGE, ) POLL_RESULT_SIZE = 20 @@ -70,13 +75,12 @@ address='/relative/path/poll', protocol_bindings=PROTOCOL_BINDINGS, max_result_size=POLL_RESULT_SIZE, - max_result_count=POLL_MAX_COUNT + max_result_count=POLL_MAX_COUNT, ) DOMAIN = 'www.some-example.local' -INTERNAL_SERVICES = [ - INBOX_A, INBOX_B, DISCOVERY_A, COLLECTION_MANAGEMENT, POLL] +INTERNAL_SERVICES = [INBOX_A, INBOX_B, DISCOVERY_A, COLLECTION_MANAGEMENT, POLL] SERVICES = INTERNAL_SERVICES + [DISCOVERY_B] INSTANCES_CONFIGURED = sum(len(s['protocol_bindings']) for s in SERVICES) @@ -85,8 +89,7 @@ CONTENT = 'some-content' CONTENT_BINDINGS_ONLY_STIX = [CB_STIX_XML_111] -CONTENT_BINDINGS_STIX_AND_CUSTOM = ( - CONTENT_BINDINGS_ONLY_STIX + [CUSTOM_CONTENT_BINDING]) +CONTENT_BINDINGS_STIX_AND_CUSTOM = CONTENT_BINDINGS_ONLY_STIX + [CUSTOM_CONTENT_BINDING] CONTENT_BINDING_SUBTYPE = 'custom-subtype' MESSAGE = 'test-message' @@ -99,35 +102,33 @@ COLLECTIONS_A = [ - entities.CollectionEntity(**x) for x in - [{ - 'name': COLLECTION_OPEN, - 'available': True, - 'accept_all_content': True - }] + entities.CollectionEntity(**x) + for x in [{'name': COLLECTION_OPEN, 'available': True, 'accept_all_content': True}] ] COLLECTIONS_B = [ - entities.CollectionEntity(**x) for x in - [{ - 'name': COLLECTION_OPEN, - 'available': True, - 'accept_all_content': True, - 'type': entities.CollectionEntity.TYPE_SET - }, { - 'name': COLLECTION_ONLY_STIX, - 'available': True, - 'accept_all_content': False, - 'supported_content': CONTENT_BINDINGS_ONLY_STIX - }, { - 'name': COLLECTION_STIX_AND_CUSTOM, - 'available': True, - 'accept_all_content': False, - 'supported_content': CONTENT_BINDINGS_STIX_AND_CUSTOM - }, { - 'name': COLLECTION_DISABLED, - 'available': False - }] + entities.CollectionEntity(**x) + for x in [ + { + 'name': COLLECTION_OPEN, + 'available': True, + 'accept_all_content': True, + 'type': entities.CollectionEntity.TYPE_SET, + }, + { + 'name': COLLECTION_ONLY_STIX, + 'available': True, + 'accept_all_content': False, + 'supported_content': CONTENT_BINDINGS_ONLY_STIX, + }, + { + 'name': COLLECTION_STIX_AND_CUSTOM, + 'available': True, + 'accept_all_content': False, + 'supported_content': CONTENT_BINDINGS_STIX_AND_CUSTOM, + }, + {'name': COLLECTION_DISABLED, 'available': False}, + ] ] USERNAME = "some-username" diff --git a/tests/services/test_collection_management.py b/tests/services/test_collection_management.py index 61dd5d82..0ebb065b 100644 --- a/tests/services/test_collection_management.py +++ b/tests/services/test_collection_management.py @@ -1,21 +1,30 @@ import pytest -from fixtures import (COLLECTION_DISABLED, COLLECTION_ONLY_STIX, - COLLECTION_OPEN, COLLECTION_STIX_AND_CUSTOM, - COLLECTIONS_B, MESSAGE_ID, SERVICES) -from opentaxii.taxii import entities +from fixtures import ( + COLLECTION_DISABLED, + COLLECTION_ONLY_STIX, + COLLECTION_OPEN, + COLLECTION_STIX_AND_CUSTOM, + COLLECTIONS_B, + MESSAGE_ID, + SERVICES, +) from utils import as_tm, persist_content, prepare_headers +from opentaxii.taxii import entities + ASSIGNED_SERVICES = ['collection-management-A', 'inbox-A', 'inbox-B', 'poll-A'] ASSIGNED_INBOX_INSTANCES = sum( len(s['protocol_bindings']) for s in SERVICES - if s['id'] in ASSIGNED_SERVICES and s['id'].startswith('inbox')) + if s['id'] in ASSIGNED_SERVICES and s['id'].startswith('inbox') +) ASSIGNED_SUBSCTRIPTION_INSTANCES = sum( len(s['protocol_bindings']) for s in SERVICES - if s['id'] in ASSIGNED_SERVICES and s['id'].startswith('collection-')) + if s['id'] in ASSIGNED_SERVICES and s['id'].startswith('collection-') +) @pytest.fixture(autouse=True) @@ -23,7 +32,8 @@ def prepare_server(server, services): for coll in COLLECTIONS_B: coll = server.servers.taxii1.persistence.create_collection(coll) server.servers.taxii1.persistence.set_collection_services( - coll.id, service_ids=ASSIGNED_SERVICES) + coll.id, service_ids=ASSIGNED_SERVICES + ) def prepare_request(version): @@ -47,8 +57,7 @@ def test_collections(server, version, https): names = [c.name for c in COLLECTIONS_B] if version == 11: - assert isinstance( - response, as_tm(version).CollectionInformationResponse) + assert isinstance(response, as_tm(version).CollectionInformationResponse) assert len(response.collection_informations) == len(COLLECTIONS_B) for c in response.collection_informations: @@ -111,18 +120,18 @@ def test_collection_supported_content(server, version, https): def get_coll(name): return next( - c for c in response.collection_informations - if c.collection_name == name) + c for c in response.collection_informations if c.collection_name == name + ) assert ( - get_coll(COLLECTION_OPEN).collection_type == - entities.CollectionEntity.TYPE_SET) + get_coll(COLLECTION_OPEN).collection_type + == entities.CollectionEntity.TYPE_SET + ) else: + def get_coll(name): - return next( - c for c in response.feed_informations - if c.feed_name == name) + return next(c for c in response.feed_informations if c.feed_name == name) assert len(get_coll(COLLECTION_OPEN).supported_contents) == 0 @@ -146,8 +155,10 @@ def test_collections_volume(server, https): response = service.process(headers, request) collection = next( - c for c in response.collection_informations - if c.collection_name == COLLECTION_OPEN) + c + for c in response.collection_informations + if c.collection_name == COLLECTION_OPEN + ) assert collection.collection_volume == 0 @@ -160,8 +171,10 @@ def test_collections_volume(server, https): response = service.process(headers, request) collection = next( - c for c in response.collection_informations - if c.collection_name == COLLECTION_OPEN) + c + for c in response.collection_informations + if c.collection_name == COLLECTION_OPEN + ) assert collection.collection_volume == blocks_amount diff --git a/tests/services/test_discovery.py b/tests/services/test_discovery.py index 0286ef7a..ee472bda 100644 --- a/tests/services/test_discovery.py +++ b/tests/services/test_discovery.py @@ -38,9 +38,7 @@ def test_content_bindings_present(server, version, https): assert len(response.service_instances) == INSTANCES_CONFIGURED assert response.in_response_to == MESSAGE_ID - inboxes = [ - s for s in response.service_instances - if s.service_type == SVC_INBOX] + inboxes = [s for s in response.service_instances if s.service_type == SVC_INBOX] assert len(inboxes) == 4 diff --git a/tests/services/test_inbox.py b/tests/services/test_inbox.py index da61850e..3370c53c 100644 --- a/tests/services/test_inbox.py +++ b/tests/services/test_inbox.py @@ -1,14 +1,22 @@ import pytest -from fixtures import (COLLECTION_ONLY_STIX, COLLECTION_OPEN, COLLECTIONS_A, - COLLECTIONS_B, CONTENT, CONTENT_BINDING_SUBTYPE, - CUSTOM_CONTENT_BINDING, INVALID_CONTENT_BINDING, - MESSAGE_ID) +from fixtures import ( + COLLECTION_ONLY_STIX, + COLLECTION_OPEN, + COLLECTIONS_A, + COLLECTIONS_B, + CONTENT, + CONTENT_BINDING_SUBTYPE, + CUSTOM_CONTENT_BINDING, + INVALID_CONTENT_BINDING, + MESSAGE_ID, +) from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 from libtaxii.constants import CB_STIX_XML_111, ST_SUCCESS -from opentaxii.taxii import exceptions from utils import as_tm, prepare_headers +from opentaxii.taxii import exceptions + def make_content( version, content_binding=CUSTOM_CONTENT_BINDING, content=CONTENT, subtype=None diff --git a/tests/services/test_poll.py b/tests/services/test_poll.py index eb9adea0..38a620c5 100644 --- a/tests/services/test_poll.py +++ b/tests/services/test_poll.py @@ -1,15 +1,21 @@ import pytest -from fixtures import (COLLECTION_DISABLED, COLLECTION_ONLY_STIX, - COLLECTION_OPEN, COLLECTION_STIX_AND_CUSTOM, - COLLECTIONS_B, CUSTOM_CONTENT_BINDING, MESSAGE_ID, - POLL_MAX_COUNT, POLL_RESULT_SIZE) +from fixtures import ( + COLLECTION_DISABLED, + COLLECTION_ONLY_STIX, + COLLECTION_OPEN, + COLLECTION_STIX_AND_CUSTOM, + COLLECTIONS_B, + CUSTOM_CONTENT_BINDING, + MESSAGE_ID, + POLL_MAX_COUNT, + POLL_RESULT_SIZE, +) from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 -from libtaxii.constants import (ACT_SUBSCRIBE, CB_STIX_XML_111, RT_COUNT_ONLY, - RT_FULL) +from libtaxii.constants import ACT_SUBSCRIBE, CB_STIX_XML_111, RT_COUNT_ONLY, RT_FULL +from utils import as_tm, persist_content, prepare_headers, prepare_subscription_request + from opentaxii.taxii import exceptions -from utils import (as_tm, persist_content, prepare_headers, - prepare_subscription_request) @pytest.fixture(autouse=True) @@ -18,12 +24,14 @@ def prepare_server(server, services): for coll in COLLECTIONS_B: coll = server.servers.taxii1.persistence.create_collection(coll) server.servers.taxii1.persistence.set_collection_services( - coll.id, service_ids=services) + coll.id, service_ids=services + ) return server -def prepare_request(collection_name, version, count_only=False, - bindings=[], subscription_id=None): +def prepare_request( + collection_name, version, count_only=False, bindings=[], subscription_id=None +): if version == 11: content_bindings = [tm11.ContentBinding(b) for b in bindings] @@ -31,14 +39,14 @@ def prepare_request(collection_name, version, count_only=False, poll_parameters = None else: poll_parameters = tm11.PollParameters( - response_type=( - RT_FULL if not count_only else RT_COUNT_ONLY), - content_bindings=content_bindings) + response_type=(RT_FULL if not count_only else RT_COUNT_ONLY), + content_bindings=content_bindings, + ) return tm11.PollRequest( message_id=MESSAGE_ID, collection_name=collection_name, subscription_id=subscription_id, - poll_parameters=poll_parameters + poll_parameters=poll_parameters, ) elif version == 10: content_bindings = bindings @@ -46,7 +54,7 @@ def prepare_request(collection_name, version, count_only=False, message_id=MESSAGE_ID, feed_name=collection_name, content_bindings=content_bindings, - subscription_id=subscription_id + subscription_id=subscription_id, ) @@ -55,14 +63,19 @@ def prepare_fulfilment_request(collection_name, result_id, part_number): message_id=MESSAGE_ID, collection_name=collection_name, result_id=result_id, - result_part_number=part_number + result_part_number=part_number, ) -@pytest.mark.parametrize(("https", "version", "count_blocks"), [ - (True, 11, True), (False, 11, False), - (True, 10, False), (False, 10, False), -]) +@pytest.mark.parametrize( + ("https", "version", "count_blocks"), + [ + (True, 11, True), + (False, 11, False), + (True, 10, False), + (False, 10, False), + ], +) def test_poll_empty_response(server, version, https, count_blocks): server.servers.taxii1.config['count_blocks_in_poll_responses'] = count_blocks @@ -70,8 +83,7 @@ def test_poll_empty_response(server, version, https, count_blocks): service = server.servers.taxii1.get_service('poll-A') headers = prepare_headers(version, https) - request = prepare_request( - collection_name=COLLECTION_OPEN, version=version) + request = prepare_request(collection_name=COLLECTION_OPEN, version=version) if version == 11: response = service.process(headers, request) @@ -93,15 +105,14 @@ def test_poll_empty_response(server, version, https, count_blocks): @pytest.mark.parametrize( - ("https", "version"), - [(True, 11), (False, 11), (True, 10), (False, 10)]) + ("https", "version"), [(True, 11), (False, 11), (True, 10), (False, 10)] +) def test_poll_collection_not_available(server, version, https): service = server.servers.taxii1.get_service('poll-A') headers = prepare_headers(version, https) - request = prepare_request( - collection_name=COLLECTION_DISABLED, version=version) + request = prepare_request(collection_name=COLLECTION_DISABLED, version=version) with pytest.raises(exceptions.StatusMessageException): service.process(headers, request) @@ -113,14 +124,17 @@ def test_poll_get_content(server, version, https): service = server.servers.taxii1.get_service('poll-A') original = persist_content( - server.servers.taxii1.persistence, COLLECTION_ONLY_STIX, - service.id, binding=CB_STIX_XML_111) + server.servers.taxii1.persistence, + COLLECTION_ONLY_STIX, + service.id, + binding=CB_STIX_XML_111, + ) # wrong collection headers = prepare_headers(version, https) request = prepare_request( - collection_name=COLLECTION_STIX_AND_CUSTOM, - version=version) + collection_name=COLLECTION_STIX_AND_CUSTOM, version=version + ) response = service.process(headers, request) @@ -129,9 +143,7 @@ def test_poll_get_content(server, version, https): # right collection headers = prepare_headers(version, https) - request = prepare_request( - collection_name=COLLECTION_ONLY_STIX, - version=version) + request = prepare_request(collection_name=COLLECTION_ONLY_STIX, version=version) response = service.process(headers, request) @@ -147,7 +159,9 @@ def test_poll_get_content(server, version, https): headers = prepare_headers(version, https) request = prepare_request( collection_name=COLLECTION_ONLY_STIX, - version=version, bindings=[CUSTOM_CONTENT_BINDING]) + version=version, + bindings=[CUSTOM_CONTENT_BINDING], + ) with pytest.raises(exceptions.StatusMessageException): service.process(headers, request) @@ -155,7 +169,8 @@ def test_poll_get_content(server, version, https): @pytest.mark.parametrize( ("https", "count_blocks"), - [(True, True), (False, True), (True, False), (False, False)]) + [(True, True), (False, True), (True, False), (False, False)], +) def test_poll_get_content_count(server, https, count_blocks): version = 11 server.servers.taxii1.config['count_blocks_in_poll_responses'] = count_blocks @@ -170,7 +185,8 @@ def test_poll_get_content_count(server, https, count_blocks): # count-only request request = prepare_request( - collection_name=COLLECTION_OPEN, count_only=True, version=version) + collection_name=COLLECTION_OPEN, count_only=True, version=version + ) response = service.process(headers, request) assert isinstance(response, tm11.PollResponse) @@ -186,7 +202,8 @@ def test_poll_get_content_count(server, https, count_blocks): @pytest.mark.parametrize( ("https", "count_blocks"), - [(True, True), (False, True), (True, False), (False, False)]) + [(True, True), (False, True), (True, False), (False, False)], +) def test_poll_max_count_max_size(server, https, count_blocks): version = 11 @@ -202,8 +219,9 @@ def test_poll_max_count_max_size(server, https, count_blocks): headers = prepare_headers(version, https) # count-only request - request = prepare_request(collection_name=COLLECTION_OPEN, - count_only=True, version=version) + request = prepare_request( + collection_name=COLLECTION_OPEN, count_only=True, version=version + ) response = service.process(headers, request) assert isinstance(response, tm11.PollResponse) @@ -233,7 +251,8 @@ def test_poll_max_count_max_size(server, https, count_blocks): @pytest.mark.parametrize( ("https", "count_blocks"), - [(True, True), (False, True), (True, False), (False, False)]) + [(True, True), (False, True), (True, False), (False, False)], +) def test_poll_fulfilment_request(server, https, count_blocks): server.servers.taxii1.config['count_blocks_in_poll_responses'] = count_blocks version = 11 @@ -265,8 +284,7 @@ def test_poll_fulfilment_request(server, https, count_blocks): # poll fullfilment request result_id = response.result_id part_number = 2 - request = prepare_fulfilment_request( - COLLECTION_OPEN, result_id, part_number) + request = prepare_fulfilment_request(COLLECTION_OPEN, result_id, part_number) response = service.process(headers, request) assert isinstance(response, tm11.PollResponse) @@ -285,8 +303,7 @@ def test_poll_fulfilment_request(server, https, count_blocks): # poll fullfilment request over the top result_id = response.result_id part_number = 3 - request = prepare_fulfilment_request( - COLLECTION_OPEN, result_id, part_number) + request = prepare_fulfilment_request(COLLECTION_OPEN, result_id, part_number) response = service.process(headers, request) assert isinstance(response, tm11.PollResponse) @@ -322,13 +339,12 @@ def test_subscribe_and_poll(server, version, https): params = dict( response_type=RT_COUNT_ONLY, - content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING]) + content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING], + ) subs_request = prepare_subscription_request( - collection=collection, - action=ACT_SUBSCRIBE, - version=version, - params=params) + collection=collection, action=ACT_SUBSCRIBE, version=version, params=params + ) subs_response = subs_service.process(headers, subs_request) @@ -343,7 +359,8 @@ def test_subscribe_and_poll(server, version, https): collection_name=collection, count_only=False, subscription_id=subscription.subscription_id, - version=version) + version=version, + ) poll_response = poll_service.process(headers, poll_request) diff --git a/tests/services/test_subscription_management.py b/tests/services/test_subscription_management.py index a4a89335..de2c8643 100644 --- a/tests/services/test_subscription_management.py +++ b/tests/services/test_subscription_management.py @@ -1,13 +1,26 @@ import pytest -from fixtures import (COLLECTION_OPEN, COLLECTIONS_B, CUSTOM_CONTENT_BINDING, - SUBSCRIPTION_MESSAGE) -from libtaxii.constants import (ACT_PAUSE, ACT_RESUME, ACT_SUBSCRIBE, - ACT_UNSUBSCRIBE, CB_STIX_XML_111, RT_FULL, - SS_ACTIVE, SS_PAUSED, SS_UNSUBSCRIBED) -from opentaxii.taxii import exceptions +from fixtures import ( + COLLECTION_OPEN, + COLLECTIONS_B, + CUSTOM_CONTENT_BINDING, + SUBSCRIPTION_MESSAGE, +) +from libtaxii.constants import ( + ACT_PAUSE, + ACT_RESUME, + ACT_SUBSCRIBE, + ACT_UNSUBSCRIBE, + CB_STIX_XML_111, + RT_FULL, + SS_ACTIVE, + SS_PAUSED, + SS_UNSUBSCRIBED, +) from utils import as_tm, prepare_headers from utils import prepare_subscription_request as prepare_request +from opentaxii.taxii import exceptions + ASSIGNED_SERVICES = ['collection-management-A', 'poll-A'] @@ -16,7 +29,8 @@ def prepare_server(server, services): for coll in COLLECTIONS_B: coll = server.servers.taxii1.persistence.create_collection(coll) server.servers.taxii1.persistence.set_collection_services( - coll.id, service_ids=ASSIGNED_SERVICES) + coll.id, service_ids=ASSIGNED_SERVICES + ) return server @@ -31,24 +45,20 @@ def test_subscribe(server, version, https): params = dict( response_type=RT_FULL, - content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING] + content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING], ) request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, - version=version, params=params) + collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, version=version, params=params + ) response = service.process(headers, request) if version == 11: - assert isinstance( - response, - as_tm(version).ManageCollectionSubscriptionResponse) + assert isinstance(response, as_tm(version).ManageCollectionSubscriptionResponse) assert response.collection_name == COLLECTION_OPEN else: - assert isinstance( - response, - as_tm(version).ManageFeedSubscriptionResponse) + assert isinstance(response, as_tm(version).ManageFeedSubscriptionResponse) assert response.feed_name == COLLECTION_OPEN assert response.message == SUBSCRIPTION_MESSAGE @@ -60,22 +70,19 @@ def test_subscribe(server, version, https): # 1 poll service * 2 protocol bindings assert len(subs.poll_instances) == 2 - assert ( - subs.poll_instances[0].poll_address == - poll_service.get_absolute_address( - subs.poll_instances[0].poll_protocol)) + assert subs.poll_instances[0].poll_address == poll_service.get_absolute_address( + subs.poll_instances[0].poll_protocol + ) if version == 11: assert subs.status == SS_ACTIVE response_bindings = [ - b.binding_id - for b in subs.subscription_parameters.content_bindings] + b.binding_id for b in subs.subscription_parameters.content_bindings + ] assert response_bindings == params['content_bindings'] - assert ( - subs.subscription_parameters.response_type == - params['response_type']) + assert subs.subscription_parameters.response_type == params['response_type'] @pytest.mark.parametrize("https", [True, False]) @@ -89,19 +96,17 @@ def test_subscribe_pause_resume(server, https): params = dict( response_type=RT_FULL, - content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING] + content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING], ) # Subscribing request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, - version=version, params=params) + collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, version=version, params=params + ) response = service.process(headers, request) - assert isinstance( - response, - as_tm(version).ManageCollectionSubscriptionResponse) + assert isinstance(response, as_tm(version).ManageCollectionSubscriptionResponse) assert response.collection_name == COLLECTION_OPEN assert len(response.subscription_instances) == 1 @@ -110,19 +115,21 @@ def test_subscribe_pause_resume(server, https): assert subs.status == SS_ACTIVE assert ( - server.servers.taxii1.persistence.get_subscription(subs.subscription_id).status == - SS_ACTIVE) + server.servers.taxii1.persistence.get_subscription(subs.subscription_id).status + == SS_ACTIVE + ) # Pausing request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_PAUSE, - subscription_id=subs.subscription_id, version=version) + collection=COLLECTION_OPEN, + action=ACT_PAUSE, + subscription_id=subs.subscription_id, + version=version, + ) response = service.process(headers, request) - assert isinstance( - response, - as_tm(version).ManageCollectionSubscriptionResponse) + assert isinstance(response, as_tm(version).ManageCollectionSubscriptionResponse) assert response.collection_name == COLLECTION_OPEN assert len(response.subscription_instances) == 1 @@ -132,19 +139,21 @@ def test_subscribe_pause_resume(server, https): assert subs.subscription_id assert subs.status == SS_PAUSED assert ( - server.servers.taxii1.persistence.get_subscription(subs.subscription_id).status == - SS_PAUSED) + server.servers.taxii1.persistence.get_subscription(subs.subscription_id).status + == SS_PAUSED + ) # Resume request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_RESUME, - subscription_id=subs.subscription_id, version=version) + collection=COLLECTION_OPEN, + action=ACT_RESUME, + subscription_id=subs.subscription_id, + version=version, + ) response = service.process(headers, request) - assert isinstance( - response, - as_tm(version).ManageCollectionSubscriptionResponse) + assert isinstance(response, as_tm(version).ManageCollectionSubscriptionResponse) assert response.collection_name == COLLECTION_OPEN assert len(response.subscription_instances) == 1 @@ -154,8 +163,9 @@ def test_subscribe_pause_resume(server, https): assert subs.subscription_id assert subs.status == SS_ACTIVE assert ( - server.servers.taxii1.persistence.get_subscription(subs.subscription_id).status == - SS_ACTIVE) + server.servers.taxii1.persistence.get_subscription(subs.subscription_id).status + == SS_ACTIVE + ) @pytest.mark.parametrize("https", [True, False]) @@ -168,13 +178,12 @@ def test_pause_resume_wrong_id(server, https): # Subscribing request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, - version=version) + collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, version=version + ) response = service.process(headers, request) - assert isinstance( - response, as_tm(version).ManageCollectionSubscriptionResponse) + assert isinstance(response, as_tm(version).ManageCollectionSubscriptionResponse) assert response.collection_name == COLLECTION_OPEN assert len(response.subscription_instances) == 1 @@ -186,15 +195,21 @@ def test_pause_resume_wrong_id(server, https): # Pausing with wrong subscription ID with pytest.raises(exceptions.StatusMessageException): request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_PAUSE, - subscription_id="RANDOM-WRONG-SUBSCRIPTION", version=version) + collection=COLLECTION_OPEN, + action=ACT_PAUSE, + subscription_id="RANDOM-WRONG-SUBSCRIPTION", + version=version, + ) response = service.process(headers, request) # Resuming with wrong subscription ID with pytest.raises(exceptions.StatusMessageException): request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_RESUME, - subscription_id="RANDOM-WRONG-SUBSCRIPTION", version=version) + collection=COLLECTION_OPEN, + action=ACT_RESUME, + subscription_id="RANDOM-WRONG-SUBSCRIPTION", + version=version, + ) response = service.process(headers, request) @@ -207,13 +222,13 @@ def test_unsubscribe(server, version, https): params = dict( response_type=RT_FULL, - content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING] + content_bindings=[CB_STIX_XML_111, CUSTOM_CONTENT_BINDING], ) # Subscribing request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, - version=version, params=params) + collection=COLLECTION_OPEN, action=ACT_SUBSCRIBE, version=version, params=params + ) response = service.process(headers, request) @@ -228,8 +243,11 @@ def test_unsubscribe(server, version, https): # return valid response INVALID_ID = "RANDOM-WRONG-SUBSCRIPTION" request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_UNSUBSCRIBE, - subscription_id=INVALID_ID, version=version) + collection=COLLECTION_OPEN, + action=ACT_UNSUBSCRIBE, + subscription_id=INVALID_ID, + version=version, + ) response = service.process(headers, request) assert len(response.subscription_instances) == 1 @@ -238,8 +256,11 @@ def test_unsubscribe(server, version, https): # Unsubscribing with valid subscription ID request = prepare_request( - collection=COLLECTION_OPEN, action=ACT_UNSUBSCRIBE, - subscription_id=subscription_id, version=version) + collection=COLLECTION_OPEN, + action=ACT_UNSUBSCRIBE, + subscription_id=subscription_id, + version=version, + ) response = service.process(headers, request) assert len(response.subscription_instances) == 1 @@ -250,5 +271,6 @@ def test_unsubscribe(server, version, https): assert subs.status == SS_UNSUBSCRIBED assert ( - server.servers.taxii1.persistence.get_subscription(subscription_id).status == - SS_UNSUBSCRIBED) + server.servers.taxii1.persistence.get_subscription(subscription_id).status + == SS_UNSUBSCRIBED + ) diff --git a/tests/taxii2/factories.py b/tests/taxii2/factories.py index 58d88794..02fec5c1 100644 --- a/tests/taxii2/factories.py +++ b/tests/taxii2/factories.py @@ -1,9 +1,11 @@ """Factories for taxii2 entities.""" + import datetime from uuid import uuid4 import factory import stix2 + from opentaxii.taxii2.entities import STIXObject diff --git a/tests/taxii2/test_taxii2_api_root.py b/tests/taxii2/test_taxii2_api_root.py index 80ea5b5f..3038424f 100644 --- a/tests/taxii2/test_taxii2_api_root.py +++ b/tests/taxii2/test_taxii2_api_root.py @@ -172,21 +172,25 @@ def test_api_root( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2, - "config", - config_override_func( - authenticated_client.application.taxii_server.servers.taxii2.config + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), ), - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - authenticated_client.application.taxii_server, - "servers", - server_mapping_override_func( - authenticated_client.application.taxii_server.servers + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server, + "servers", + server_mapping_override_func( + authenticated_client.application.taxii_server.servers + ), ), ): func = getattr(authenticated_client, method) diff --git a/tests/taxii2/test_taxii2_collection.py b/tests/taxii2/test_taxii2_collection.py index 061a6055..fef193b9 100644 --- a/tests/taxii2/test_taxii2_collection.py +++ b/tests/taxii2/test_taxii2_collection.py @@ -3,9 +3,14 @@ from uuid import uuid4 import pytest + from opentaxii.persistence.sqldb import taxii2models -from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_API_ROOT_MOCK, - GET_COLLECTION_MOCK) +from tests.taxii2.utils import ( + API_ROOTS, + COLLECTIONS, + GET_API_ROOT_MOCK, + GET_COLLECTION_MOCK, +) @pytest.mark.parametrize( @@ -169,27 +174,33 @@ def test_collection( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, - ), patch.object( - authenticated_client.account, - "permissions", - { - COLLECTIONS[0].id: ["read"], - COLLECTIONS[1].id: ["write"], - COLLECTIONS[2].id: ["read", "write"], - COLLECTIONS[4].id: ["read", "write"], - COLLECTIONS[5].id: ["write"], - }, + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), + patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write"], + }, + ), ): func = getattr(authenticated_client, method) - response = func(f"/taxii2/{api_root_id}/collections/{collection_id}/", headers=headers) + response = func( + f"/taxii2/{api_root_id}/collections/{collection_id}/", headers=headers + ) assert response.status_code == expected_status assert { key: response.headers.get(key) for key in expected_headers @@ -231,14 +242,17 @@ def test_collection_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), ): func = getattr(client, method) response = func( @@ -281,7 +295,15 @@ def test_collection_unauthenticated( ], ) def test_add_collection( - app, api_root_id, title, description, alias, is_public, is_public_write, db_api_roots, db_collections + app, + api_root_id, + title, + description, + alias, + is_public, + is_public_write, + db_api_roots, + db_collections, ): collection = app.taxii_server.servers.taxii2.persistence.api.add_collection( api_root_id=api_root_id, diff --git a/tests/taxii2/test_taxii2_collections.py b/tests/taxii2/test_taxii2_collections.py index e478361b..37f422ab 100644 --- a/tests/taxii2/test_taxii2_collections.py +++ b/tests/taxii2/test_taxii2_collections.py @@ -3,10 +3,16 @@ from uuid import uuid4 import pytest -from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_API_ROOT_MOCK, - GET_COLLECTIONS_MOCK, config_noop, - server_mapping_noop, - server_mapping_remove_fields) + +from tests.taxii2.utils import ( + API_ROOTS, + COLLECTIONS, + GET_API_ROOT_MOCK, + GET_COLLECTIONS_MOCK, + config_noop, + server_mapping_noop, + server_mapping_remove_fields, +) @pytest.mark.parametrize( @@ -189,36 +195,42 @@ def test_collections( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2, - "config", - config_override_func( - authenticated_client.application.taxii_server.servers.taxii2.config + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, ), - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_collections", - side_effect=GET_COLLECTIONS_MOCK, - ), patch.object( - authenticated_client.application.taxii_server, - "servers", - server_mapping_override_func( - authenticated_client.application.taxii_server.servers + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collections", + side_effect=GET_COLLECTIONS_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server, + "servers", + server_mapping_override_func( + authenticated_client.application.taxii_server.servers + ), + ), + patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write"], + }, ), - ), patch.object( - authenticated_client.account, - "permissions", - { - COLLECTIONS[0].id: ["read"], - COLLECTIONS[1].id: ["write"], - COLLECTIONS[2].id: ["read", "write"], - COLLECTIONS[4].id: ["read", "write"], - COLLECTIONS[5].id: ["write"], - }, ): func = getattr(authenticated_client, method) response = func(f"/taxii2/{api_root_id}/collections/", headers=headers) @@ -255,14 +267,17 @@ def test_collections_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_collections", - side_effect=GET_COLLECTIONS_MOCK, + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_collections", + side_effect=GET_COLLECTIONS_MOCK, + ), ): func = getattr(client, method) response = func( diff --git a/tests/taxii2/test_taxii2_discovery.py b/tests/taxii2/test_taxii2_discovery.py index 157477e6..3e309d77 100644 --- a/tests/taxii2/test_taxii2_discovery.py +++ b/tests/taxii2/test_taxii2_discovery.py @@ -2,9 +2,13 @@ from unittest.mock import patch import pytest -from tests.taxii2.utils import (API_ROOTS_WITH_DEFAULT, - API_ROOTS_WITHOUT_DEFAULT, config_noop, - config_remove_fields) + +from tests.taxii2.utils import ( + API_ROOTS_WITH_DEFAULT, + API_ROOTS_WITHOUT_DEFAULT, + config_noop, + config_remove_fields, +) @pytest.mark.parametrize( @@ -45,7 +49,9 @@ "title": "Some TAXII Server", "description": "This TAXII Server contains a listing of...", "contact": "string containing contact information", - "api_roots": [f"/taxii2/{item.id}/" for item in API_ROOTS_WITHOUT_DEFAULT], + "api_roots": [ + f"/taxii2/{item.id}/" for item in API_ROOTS_WITHOUT_DEFAULT + ], }, id="good, without default api root", ), @@ -126,19 +132,22 @@ def test_discovery( "description": "This TAXII Server contains a listing of...", "contact": "string containing contact information", } - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2, - "config", - config_override_func( - { - **authenticated_client.application.taxii_server.servers.taxii2.config, - **config_defaults, - } + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + { + **authenticated_client.application.taxii_server.servers.taxii2.config, + **config_defaults, + } + ), + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_roots", + return_value=api_roots, ), - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_roots", - return_value=api_roots, ): func = getattr(authenticated_client, method) response = func("/taxii2/", headers=headers) diff --git a/tests/taxii2/test_taxii2_manifest.py b/tests/taxii2/test_taxii2_manifest.py index 877cdb8c..0569ed73 100644 --- a/tests/taxii2/test_taxii2_manifest.py +++ b/tests/taxii2/test_taxii2_manifest.py @@ -5,10 +5,17 @@ from uuid import uuid4 import pytest + from opentaxii.taxii2.utils import taxii2_datetimeformat -from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_COLLECTION_MOCK, - GET_MANIFEST_MOCK, GET_NEXT_PARAM, NOW, - STIX_OBJECTS) +from tests.taxii2.utils import ( + API_ROOTS, + COLLECTIONS, + GET_COLLECTION_MOCK, + GET_MANIFEST_MOCK, + GET_NEXT_PARAM, + NOW, + STIX_OBJECTS, +) @pytest.mark.parametrize( @@ -701,24 +708,28 @@ def test_manifest( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_manifest", - side_effect=GET_MANIFEST_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, - ), patch.object( - authenticated_client.account, - "permissions", - { - COLLECTIONS[0].id: ["read"], - COLLECTIONS[1].id: ["write"], - COLLECTIONS[2].id: ["read", "write"], - COLLECTIONS[4].id: ["read", "write"], - COLLECTIONS[5].id: ["write", "read"], - }, + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_manifest", + side_effect=GET_MANIFEST_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), + patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ), ): func = getattr(authenticated_client, method) if filter_kwargs: @@ -762,14 +773,17 @@ def test_manifest_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_manifest", - side_effect=GET_MANIFEST_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_manifest", + side_effect=GET_MANIFEST_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), ): func = getattr(client, method) response = func( diff --git a/tests/taxii2/test_taxii2_object.py b/tests/taxii2/test_taxii2_object.py index 2a7c9fb3..0569d320 100644 --- a/tests/taxii2/test_taxii2_object.py +++ b/tests/taxii2/test_taxii2_object.py @@ -5,10 +5,18 @@ from uuid import uuid4 import pytest + from opentaxii.taxii2.utils import DATETIMEFORMAT, taxii2_datetimeformat -from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, DELETE_OBJECT_MOCK, - GET_COLLECTION_MOCK, GET_NEXT_PARAM, - GET_OBJECT_MOCK, NOW, STIX_OBJECTS) +from tests.taxii2.utils import ( + API_ROOTS, + COLLECTIONS, + DELETE_OBJECT_MOCK, + GET_COLLECTION_MOCK, + GET_NEXT_PARAM, + GET_OBJECT_MOCK, + NOW, + STIX_OBJECTS, +) @pytest.mark.parametrize( @@ -796,29 +804,34 @@ def test_object( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_object", - side_effect=GET_OBJECT_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, - ), patch.object( - authenticated_client.account, - "permissions", - { - COLLECTIONS[0].id: ["read"], - COLLECTIONS[1].id: ["write"], - COLLECTIONS[2].id: ["read", "write"], - COLLECTIONS[4].id: ["read", "write"], - COLLECTIONS[5].id: ["write", "read"], - }, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "delete_object", - side_effect=DELETE_OBJECT_MOCK, - ) as delete_object_mock: + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_object", + side_effect=GET_OBJECT_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), + patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "delete_object", + side_effect=DELETE_OBJECT_MOCK, + ) as delete_object_mock, + ): func = getattr(authenticated_client, method) if filter_kwargs: querystring = f"?{urlencode(filter_kwargs)}" @@ -832,16 +845,20 @@ def test_object( assert response.status_code == expected_status if method == "delete" and expected_status == 200: expected_kwargs = { - "match_version": [ - datetime.datetime.strptime( - filter_kwargs["match[version]"], DATETIMEFORMAT - ).replace(tzinfo=datetime.timezone.utc) - ] - if "match[version]" in filter_kwargs - else None, - "match_spec_version": [filter_kwargs["match[spec_version]"]] - if "match[spec_version]" in filter_kwargs - else None, + "match_version": ( + [ + datetime.datetime.strptime( + filter_kwargs["match[version]"], DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc) + ] + if "match[version]" in filter_kwargs + else None + ), + "match_spec_version": ( + [filter_kwargs["match[spec_version]"]] + if "match[spec_version]" in filter_kwargs + else None + ), } delete_object_mock.assert_called_once_with( collection_id=COLLECTIONS[5].id, object_id=object_id, **expected_kwargs @@ -886,14 +903,17 @@ def test_object_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_object", - side_effect=GET_OBJECT_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_object", + side_effect=GET_OBJECT_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), ): func = getattr(client, method) response = func( diff --git a/tests/taxii2/test_taxii2_objects.py b/tests/taxii2/test_taxii2_objects.py index c5a97010..4206a289 100644 --- a/tests/taxii2/test_taxii2_objects.py +++ b/tests/taxii2/test_taxii2_objects.py @@ -5,11 +5,22 @@ from uuid import uuid4 import pytest + from opentaxii.taxii2.utils import taxii2_datetimeformat -from tests.taxii2.utils import (ADD_OBJECTS_MOCK, API_ROOTS, COLLECTIONS, - GET_COLLECTION_MOCK, GET_JOB_AND_DETAILS_MOCK, - GET_NEXT_PARAM, GET_OBJECTS_MOCK, JOBS, NOW, - STIX_OBJECTS, config_noop, config_override) +from tests.taxii2.utils import ( + ADD_OBJECTS_MOCK, + API_ROOTS, + COLLECTIONS, + GET_COLLECTION_MOCK, + GET_JOB_AND_DETAILS_MOCK, + GET_NEXT_PARAM, + GET_OBJECTS_MOCK, + JOBS, + NOW, + STIX_OBJECTS, + config_noop, + config_override, +) from tests.utils import SKIP @@ -1169,38 +1180,45 @@ def test_objects( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2, - "config", - config_override_func( - authenticated_client.application.taxii_server.servers.taxii2.config + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_objects", + side_effect=GET_OBJECTS_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), + patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "add_objects", + side_effect=ADD_OBJECTS_MOCK, + ) as add_objects_mock, + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_job_and_details", + side_effect=GET_JOB_AND_DETAILS_MOCK, ), - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_objects", - side_effect=GET_OBJECTS_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, - ), patch.object( - authenticated_client.account, - "permissions", - { - COLLECTIONS[0].id: ["read"], - COLLECTIONS[1].id: ["write"], - COLLECTIONS[2].id: ["read", "write"], - COLLECTIONS[4].id: ["read", "write"], - COLLECTIONS[5].id: ["write", "read"], - }, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "add_objects", - side_effect=ADD_OBJECTS_MOCK, - ) as add_objects_mock, patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_job_and_details", - side_effect=GET_JOB_AND_DETAILS_MOCK, ): func = getattr(authenticated_client, method) if filter_kwargs: @@ -1271,19 +1289,23 @@ def test_objects_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_objects", - side_effect=GET_OBJECTS_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "add_objects", - side_effect=ADD_OBJECTS_MOCK, - ) as add_objects_mock: + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_objects", + side_effect=GET_OBJECTS_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "add_objects", + side_effect=ADD_OBJECTS_MOCK, + ) as add_objects_mock, + ): kwargs = { "headers": { "Accept": "application/taxii+json;version=2.1", diff --git a/tests/taxii2/test_taxii2_sqldb.py b/tests/taxii2/test_taxii2_sqldb.py index c993720b..a4247bcf 100644 --- a/tests/taxii2/test_taxii2_sqldb.py +++ b/tests/taxii2/test_taxii2_sqldb.py @@ -2,16 +2,27 @@ from uuid import uuid4 import pytest + from opentaxii.persistence.sqldb.taxii2models import Job, JobDetail, STIXObject from opentaxii.taxii2 import entities from opentaxii.taxii2.utils import DATETIMEFORMAT -from tests.taxii2.utils import (API_ROOTS, API_ROOTS_WITH_DEFAULT, - API_ROOTS_WITHOUT_DEFAULT, COLLECTIONS, - GET_API_ROOT_MOCK, GET_COLLECTION_MOCK, - GET_COLLECTIONS_MOCK, GET_JOB_AND_DETAILS_MOCK, - GET_MANIFEST_MOCK, GET_OBJECT_MOCK, - GET_OBJECTS_MOCK, GET_VERSIONS_MOCK, JOBS, NOW, - STIX_OBJECTS) +from tests.taxii2.utils import ( + API_ROOTS, + API_ROOTS_WITH_DEFAULT, + API_ROOTS_WITHOUT_DEFAULT, + COLLECTIONS, + GET_API_ROOT_MOCK, + GET_COLLECTION_MOCK, + GET_COLLECTIONS_MOCK, + GET_JOB_AND_DETAILS_MOCK, + GET_MANIFEST_MOCK, + GET_OBJECT_MOCK, + GET_OBJECTS_MOCK, + GET_VERSIONS_MOCK, + JOBS, + NOW, + STIX_OBJECTS, +) @pytest.mark.parametrize( diff --git a/tests/taxii2/test_taxii2_status.py b/tests/taxii2/test_taxii2_status.py index 660345b8..f698d698 100644 --- a/tests/taxii2/test_taxii2_status.py +++ b/tests/taxii2/test_taxii2_status.py @@ -3,12 +3,18 @@ from uuid import uuid4 import pytest + from opentaxii.persistence.sqldb import taxii2models from opentaxii.taxii2.utils import taxii2_datetimeformat -from tests.taxii2.utils import (API_ROOTS, GET_API_ROOT_MOCK, - GET_JOB_AND_DETAILS_MOCK, JOBS, config_noop, - server_mapping_noop, - server_mapping_remove_fields) +from tests.taxii2.utils import ( + API_ROOTS, + GET_API_ROOT_MOCK, + GET_JOB_AND_DETAILS_MOCK, + JOBS, + config_noop, + server_mapping_noop, + server_mapping_remove_fields, +) @pytest.mark.parametrize( @@ -242,29 +248,35 @@ def test_status( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2, - "config", - config_override_func( - authenticated_client.application.taxii_server.servers.taxii2.config + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_roots", + return_value=API_ROOTS, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, ), - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_roots", - return_value=API_ROOTS, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_job_and_details", - side_effect=GET_JOB_AND_DETAILS_MOCK, - ), patch.object( - authenticated_client.application.taxii_server, - "servers", - server_mapping_override_func( - authenticated_client.application.taxii_server.servers + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_job_and_details", + side_effect=GET_JOB_AND_DETAILS_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server, + "servers", + server_mapping_override_func( + authenticated_client.application.taxii_server.servers + ), ), ): func = getattr(authenticated_client, method) @@ -304,14 +316,17 @@ def test_status_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_api_root", - side_effect=GET_API_ROOT_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_job_and_details", - side_effect=GET_JOB_AND_DETAILS_MOCK, + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_job_and_details", + side_effect=GET_JOB_AND_DETAILS_MOCK, + ), ): func = getattr(client, method) response = func( diff --git a/tests/taxii2/test_taxii2_utils.py b/tests/taxii2/test_taxii2_utils.py index c83376bd..e6b3d8f9 100644 --- a/tests/taxii2/test_taxii2_utils.py +++ b/tests/taxii2/test_taxii2_utils.py @@ -1,6 +1,7 @@ import datetime import pytest + from opentaxii.taxii2.utils import taxii2_datetimeformat diff --git a/tests/taxii2/test_taxii2_versions.py b/tests/taxii2/test_taxii2_versions.py index 1a209b1f..e2a5de31 100644 --- a/tests/taxii2/test_taxii2_versions.py +++ b/tests/taxii2/test_taxii2_versions.py @@ -5,10 +5,17 @@ from uuid import uuid4 import pytest + from opentaxii.taxii2.utils import taxii2_datetimeformat -from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_COLLECTION_MOCK, - GET_NEXT_PARAM, GET_VERSIONS_MOCK, NOW, - STIX_OBJECTS) +from tests.taxii2.utils import ( + API_ROOTS, + COLLECTIONS, + GET_COLLECTION_MOCK, + GET_NEXT_PARAM, + GET_VERSIONS_MOCK, + NOW, + STIX_OBJECTS, +) @pytest.mark.parametrize( @@ -381,24 +388,28 @@ def test_versions( expected_headers, expected_content, ): - with patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_versions", - side_effect=GET_VERSIONS_MOCK, - ), patch.object( - authenticated_client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, - ), patch.object( - authenticated_client.account, - "permissions", - { - COLLECTIONS[0].id: ["read"], - COLLECTIONS[1].id: ["write"], - COLLECTIONS[2].id: ["read", "write"], - COLLECTIONS[4].id: ["read", "write"], - COLLECTIONS[5].id: ["write", "read"], - }, + with ( + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_versions", + side_effect=GET_VERSIONS_MOCK, + ), + patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), + patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ), ): func = getattr(authenticated_client, method) if filter_kwargs: @@ -445,14 +456,17 @@ def test_versions_unauthenticated( expected_status_code = 401 else: expected_status_code = 405 - with patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_versions", - side_effect=GET_VERSIONS_MOCK, - ), patch.object( - client.application.taxii_server.servers.taxii2.persistence.api, - "get_collection", - side_effect=GET_COLLECTION_MOCK, + with ( + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_versions", + side_effect=GET_VERSIONS_MOCK, + ), + patch.object( + client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), ): func = getattr(client, method) response = func( diff --git a/tests/taxii2/test_validation.py b/tests/taxii2/test_validation.py index 61e54e8f..99a5de14 100644 --- a/tests/taxii2/test_validation.py +++ b/tests/taxii2/test_validation.py @@ -2,6 +2,7 @@ import platform import pytest + from opentaxii.taxii2.exceptions import ValidationError from opentaxii.taxii2.validation import validate_envelope from tests.utils import conditional diff --git a/tests/taxii2/utils.py b/tests/taxii2/utils.py index a3c68039..322b1a03 100644 --- a/tests/taxii2/utils.py +++ b/tests/taxii2/utils.py @@ -4,9 +4,15 @@ from uuid import uuid4 from opentaxii.server import ServerMapping -from opentaxii.taxii2.entities import (ApiRoot, Collection, Job, JobDetail, - ManifestRecord, STIXObject, - VersionRecord) +from opentaxii.taxii2.entities import ( + ApiRoot, + Collection, + Job, + JobDetail, + ManifestRecord, + STIXObject, + VersionRecord, +) from opentaxii.taxii2.utils import DATETIMEFORMAT, taxii2_datetimeformat API_ROOTS_WITH_DEFAULT = ( @@ -146,7 +152,9 @@ False, False, ), - Collection(str(uuid4()), API_ROOTS[0].id, "4No description", "", None, False, False), + Collection( + str(uuid4()), API_ROOTS[0].id, "4No description", "", None, False, False + ), Collection( str(uuid4()), API_ROOTS[0].id, @@ -506,9 +514,11 @@ def GET_VERSIONS_MOCK( match_version=["all"], ) return ( - [VersionRecord(obj.date_added, obj.version) for obj in versions] - if versions is not None - else None, + ( + [VersionRecord(obj.date_added, obj.version) for obj in versions] + if versions is not None + else None + ), more, ) diff --git a/tests/test_auth.py b/tests/test_auth.py index 98e958f6..9b0336e7 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -2,23 +2,23 @@ import json import pytest +from fixtures import VID_TAXII_HTTP_10 from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 -from libtaxii.constants import (CB_STIX_XML_111, RT_FULL, ST_BAD_MESSAGE, - ST_UNAUTHORIZED) +from libtaxii.constants import CB_STIX_XML_111, RT_FULL, ST_BAD_MESSAGE, ST_UNAUTHORIZED +from utils import as_tm, is_headers_valid, prepare_headers + from opentaxii.taxii.http import HTTP_AUTHORIZATION from opentaxii.utils import sync_conf_dict_into_db -from fixtures import VID_TAXII_HTTP_10 -from utils import as_tm, is_headers_valid, prepare_headers - INBOX_OPEN = dict( id='inbox-A', type='inbox', description='inboxA description', address='/path/inbox', destination_collection_required=True, - authentication_required=False) + authentication_required=False, +) INBOX_CLOSED = dict( id='inbox-A', @@ -26,7 +26,8 @@ description='inboxA description', address='/path/inbox', destination_collection_required=True, - authentication_required=True) + authentication_required=True, +) DISCOVERY = dict( id='discovery-A', @@ -35,7 +36,8 @@ address='/path/discovery', advertised_services=['inbox-A', 'poll-A'], protocol_bindings=[VID_TAXII_HTTP_10], - authentication_required=True) + authentication_required=True, +) POLL_CLOSED = dict( id='poll-A', @@ -45,7 +47,8 @@ protocol_bindings=[VID_TAXII_HTTP_10], authentication_required=True, max_result_size=100, - max_result_count=10) + max_result_count=10, +) POLL_OPEN = dict( id='poll-B', @@ -55,40 +58,45 @@ protocol_bindings=[VID_TAXII_HTTP_10], authentication_required=False, # <- open for all max_result_size=100, - max_result_count=10) + max_result_count=10, +) CONTENT = 'inbox-message-content' COLLECTIONS = [ - {'name': 'collection-1', - 'available': True, - 'accept_all_content': True, - 'type': 'DATA_FEED', - 'service_ids': ['discovery-A', 'inbox-A', 'poll-A', 'poll-B']}, - {'name': 'collection-2', - 'available': True, - 'accept_all_content': True, - 'type': 'DATA_FEED', - 'service_ids': ['discovery-A', 'inbox-A', 'poll-A', 'poll-B']}] + { + 'name': 'collection-1', + 'available': True, + 'accept_all_content': True, + 'type': 'DATA_FEED', + 'service_ids': ['discovery-A', 'inbox-A', 'poll-A', 'poll-B'], + }, + { + 'name': 'collection-2', + 'available': True, + 'accept_all_content': True, + 'type': 'DATA_FEED', + 'service_ids': ['discovery-A', 'inbox-A', 'poll-A', 'poll-B'], + }, +] USERNAME = 'some-username' PASSWORD = 'some-password' ACCOUNTS = [ - {'username': 'johnny', - 'password': 'johnny', - 'permissions': { - 'collection-1': 'read', - 'collection-2': 'modify'}}, - {'username': 'billy', - 'password': 'billy', - 'permissions': { - 'collection-1': 'modify'}}, - {'username': 'wally', - 'password': 'wally', - 'is_admin': True}, - {'username': USERNAME, - 'password': PASSWORD}] + { + 'username': 'johnny', + 'password': 'johnny', + 'permissions': {'collection-1': 'read', 'collection-2': 'modify'}, + }, + { + 'username': 'billy', + 'password': 'billy', + 'permissions': {'collection-1': 'modify'}, + }, + {'username': 'wally', 'password': 'wally', 'is_admin': True}, + {'username': USERNAME, 'password': PASSWORD}, +] MESSAGE_ID = '123' @@ -100,11 +108,11 @@ def auth_fixtures(server): sync_conf_dict_into_db( server, config={ - 'services': [ - INBOX_OPEN, INBOX_CLOSED, - DISCOVERY, POLL_OPEN, POLL_CLOSED], + 'services': [INBOX_OPEN, INBOX_CLOSED, DISCOVERY, POLL_OPEN, POLL_CLOSED], 'collections': COLLECTIONS, - 'accounts': ACCOUNTS}) + 'accounts': ACCOUNTS, + }, + ) assert len(server.servers.taxii1.persistence.get_services()) == 4 assert len(server.servers.taxii1.persistence.get_collections()) == len(COLLECTIONS) @@ -113,6 +121,7 @@ def auth_fixtures(server): @pytest.fixture() def test_account(server): from opentaxii.entities import Account + account = Account(id=None, username=USERNAME, permissions={}) server.auth.update_account(account, PASSWORD) @@ -125,7 +134,8 @@ def test_unauthorized_request(app, client, version, https): INBOX_OPEN['address'], data='invalid-body', headers=prepare_headers(version, https), - base_url=base_url) + base_url=base_url, + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) @@ -134,6 +144,7 @@ def test_unauthorized_request(app, client, version, https): assert message.status_type == ST_UNAUTHORIZED from opentaxii import context + assert not hasattr(context, 'account') @@ -143,23 +154,18 @@ def test_get_token(client, version, https): base_url = '%s://localhost' % ('https' if https else 'http') # Invalid credentials response = client.post( - AUTH_PATH, - data={'username': 'dummy', 'password': 'wrong'}, - base_url=base_url) + AUTH_PATH, data={'username': 'dummy', 'password': 'wrong'}, base_url=base_url + ) assert response.status_code == 401 # Invalid auth data - response = client.post( - AUTH_PATH, - data={'other': 'somethind'}, - base_url=base_url) + response = client.post(AUTH_PATH, data={'other': 'somethind'}, base_url=base_url) assert response.status_code == 400 # Valid credentials as form data response = client.post( - AUTH_PATH, - data={'username': USERNAME, 'password': PASSWORD}, - base_url=base_url) + AUTH_PATH, data={'username': USERNAME, 'password': PASSWORD}, base_url=base_url + ) assert response.status_code == 200 @@ -171,7 +177,8 @@ def test_get_token(client, version, https): AUTH_PATH, data=json.dumps({'username': USERNAME, 'password': PASSWORD}), base_url=base_url, - content_type='application/json') + content_type='application/json', + ) assert response.status_code == 200 @@ -187,9 +194,8 @@ def test_get_token_and_send_request(client, version, https): # Get valid token response = client.post( - AUTH_PATH, - data={'username': USERNAME, 'password': PASSWORD}, - base_url=base_url) + AUTH_PATH, data={'username': USERNAME, 'password': PASSWORD}, base_url=base_url + ) assert response.status_code == 200 @@ -203,10 +209,8 @@ def test_get_token_and_send_request(client, version, https): # Get correct response for invalid body response = client.post( - INBOX_OPEN['address'], - data='invalid-body', - headers=headers, - base_url=base_url) + INBOX_OPEN['address'], data='invalid-body', headers=headers, base_url=base_url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) @@ -220,10 +224,8 @@ def test_get_token_and_send_request(client, version, https): # Get correct response for valid request response = client.post( - DISCOVERY['address'], - data=request.to_xml(), - headers=headers, - base_url=base_url) + DISCOVERY['address'], data=request.to_xml(), headers=headers, base_url=base_url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version=version, https=https) @@ -233,12 +235,12 @@ def test_get_token_and_send_request(client, version, https): assert isinstance(message, as_tm(version).DiscoveryResponse) from opentaxii import context + assert not hasattr(context, 'account') def basic_auth_token(username, password): - return base64.b64encode( - '{}:{}'.format(username, password).encode('utf-8')) + return base64.b64encode('{}:{}'.format(username, password).encode('utf-8')) @pytest.mark.parametrize("https", [True, False]) @@ -247,17 +249,16 @@ def test_request_with_basic_auth(client, version, https): base_url = '%s://localhost' % ('https' if https else 'http') basic_auth_header = 'Basic {}'.format( - basic_auth_token(USERNAME, PASSWORD).decode('utf-8')) + basic_auth_token(USERNAME, PASSWORD).decode('utf-8') + ) headers = prepare_headers(version, https) headers[HTTP_AUTHORIZATION] = basic_auth_header # Get correct response for invalid body response = client.post( - INBOX_OPEN['address'], - data='invalid-body', - headers=headers, - base_url=base_url) + INBOX_OPEN['address'], data='invalid-body', headers=headers, base_url=base_url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) @@ -271,10 +272,7 @@ def test_request_with_basic_auth(client, version, https): # Get correct response for valid request response = client.post( - DISCOVERY['address'], - data=request.to_xml(), - headers=headers, - base_url=base_url + DISCOVERY['address'], data=request.to_xml(), headers=headers, base_url=base_url ) assert response.status_code == 200 @@ -285,6 +283,7 @@ def test_request_with_basic_auth(client, version, https): assert isinstance(message, as_tm(version).DiscoveryResponse) from opentaxii import context + assert not hasattr(context, 'account') @@ -299,10 +298,8 @@ def test_invalid_basic_auth_request(client, version, https): request = as_tm(version).DiscoveryRequest(message_id=MESSAGE_ID) response = client.post( - DISCOVERY['address'], - data=request.to_xml(), - headers=headers, - base_url=base_url) + DISCOVERY['address'], data=request.to_xml(), headers=headers, base_url=base_url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) @@ -324,10 +321,8 @@ def test_invalid_auth_header_request(client, version, https): request = as_tm(version).DiscoveryRequest(message_id=MESSAGE_ID) response = client.post( - DISCOVERY['address'], - data=request.to_xml(), - headers=headers, - base_url=base_url) + DISCOVERY['address'], data=request.to_xml(), headers=headers, base_url=base_url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) @@ -340,7 +335,8 @@ def prepare_url_headers(version, https, username, password): base_url = '%s://localhost' % ('https' if https else 'http') headers = prepare_headers(version, https) basic_auth_header = 'Basic {}'.format( - basic_auth_token(username, password).decode('utf-8')) + basic_auth_token(username, password).decode('utf-8') + ) headers[HTTP_AUTHORIZATION] = basic_auth_header return base_url, headers @@ -351,14 +347,11 @@ def test_collection_access_private_poll(client, version, https): # POLL_CLOSED collection allowed read access url, headers = prepare_url_headers(version, https, 'johnny', 'johnny') - request = prepare_poll_request( - 'collection-1', version, bindings=[CB_STIX_XML_111]) + request = prepare_poll_request('collection-1', version, bindings=[CB_STIX_XML_111]) response = client.post( - POLL_CLOSED['address'], - data=request.to_xml(), - headers=headers, - base_url=url) + POLL_CLOSED['address'], data=request.to_xml(), headers=headers, base_url=url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) message = as_tm(version).get_message_from_xml(response.data) @@ -366,14 +359,11 @@ def test_collection_access_private_poll(client, version, https): # POLL_CLOSED collection disallowed read access url, headers = prepare_url_headers(version, https, 'billy', 'billy') - request = prepare_poll_request( - 'collection-2', version, bindings=[CB_STIX_XML_111]) + request = prepare_poll_request('collection-2', version, bindings=[CB_STIX_XML_111]) response = client.post( - POLL_CLOSED['address'], - data=request.to_xml(), - headers=headers, - base_url=url) + POLL_CLOSED['address'], data=request.to_xml(), headers=headers, base_url=url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) message = as_tm(version).get_message_from_xml(response.data) @@ -381,14 +371,11 @@ def test_collection_access_private_poll(client, version, https): # POLL_CLOSED collection admin access url, headers = prepare_url_headers(version, https, 'wally', 'wally') - request = prepare_poll_request( - 'collection-2', version, bindings=[CB_STIX_XML_111]) + request = prepare_poll_request('collection-2', version, bindings=[CB_STIX_XML_111]) response = client.post( - POLL_CLOSED['address'], - data=request.to_xml(), - headers=headers, - base_url=url) + POLL_CLOSED['address'], data=request.to_xml(), headers=headers, base_url=url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) message = as_tm(version).get_message_from_xml(response.data) @@ -401,15 +388,12 @@ def test_collection_access_private_inbox(client, version, https): # INBOX read-only collection access url, headers = prepare_url_headers(version, https, 'johnny', 'johnny') request = prepare_inbox_message( - version, - dest_collection='collection-1', - blocks=[make_inbox_content(version)]) + version, dest_collection='collection-1', blocks=[make_inbox_content(version)] + ) response = client.post( - INBOX_CLOSED['address'], - data=request.to_xml(), - headers=headers, - base_url=url) + INBOX_CLOSED['address'], data=request.to_xml(), headers=headers, base_url=url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) message = as_tm(version).get_message_from_xml(response.data) @@ -417,9 +401,7 @@ def test_collection_access_private_inbox(client, version, https): if version == 11: assert message.status_type == 'UNAUTHORIZED' - assert ( - message.message == - 'User can not write to collection collection-1') + assert message.message == 'User can not write to collection collection-1' else: # Because in TAXII 1.0 destination collection can not be specified # so it impossible to verify access @@ -427,15 +409,12 @@ def test_collection_access_private_inbox(client, version, https): # INBOX modify collection access request = prepare_inbox_message( - version, - dest_collection='collection-2', - blocks=[make_inbox_content(version)]) + version, dest_collection='collection-2', blocks=[make_inbox_content(version)] + ) response = client.post( - INBOX_CLOSED['address'], - data=request.to_xml(), - headers=headers, - base_url=url) + INBOX_CLOSED['address'], data=request.to_xml(), headers=headers, base_url=url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) message = as_tm(version).get_message_from_xml(response.data) @@ -445,15 +424,12 @@ def test_collection_access_private_inbox(client, version, https): # INBOX modify collection access url, headers = prepare_url_headers(version, https, 'wally', 'wally') request = prepare_inbox_message( - version, - dest_collection='collection-2', - blocks=[make_inbox_content(version)]) + version, dest_collection='collection-2', blocks=[make_inbox_content(version)] + ) response = client.post( - INBOX_CLOSED['address'], - data=request.to_xml(), - headers=headers, - base_url=url) + INBOX_CLOSED['address'], data=request.to_xml(), headers=headers, base_url=url + ) assert response.status_code == 200 assert is_headers_valid(response.headers, version, https) message = as_tm(version).get_message_from_xml(response.data) @@ -461,8 +437,7 @@ def test_collection_access_private_inbox(client, version, https): assert message.status_type == 'SUCCESS' -def prepare_poll_request( - collection_name, version, bindings=[], subscription_id=None): +def prepare_poll_request(collection_name, version, bindings=[], subscription_id=None): if version == 11: content_bindings = [tm11.ContentBinding(b) for b in bindings] @@ -470,41 +445,39 @@ def prepare_poll_request( poll_parameters = None else: poll_parameters = tm11.PollParameters( - response_type=RT_FULL, - content_bindings=content_bindings) + response_type=RT_FULL, content_bindings=content_bindings + ) return tm11.PollRequest( message_id=MESSAGE_ID, collection_name=collection_name, subscription_id=subscription_id, - poll_parameters=poll_parameters) + poll_parameters=poll_parameters, + ) elif version == 10: content_bindings = bindings return tm10.PollRequest( message_id=MESSAGE_ID, feed_name=collection_name, content_bindings=content_bindings, - subscription_id=subscription_id) + subscription_id=subscription_id, + ) -def make_inbox_content( - version, content_binding=CB_STIX_XML_111, content=CONTENT): +def make_inbox_content(version, content_binding=CB_STIX_XML_111, content=CONTENT): if version == 10: return tm10.ContentBlock(content_binding, content) elif version == 11: - return tm11.ContentBlock( - tm11.ContentBinding(content_binding), content) + return tm11.ContentBlock(tm11.ContentBinding(content_binding), content) else: raise ValueError('Unknown TAXII message version: %s' % version) def prepare_inbox_message(version, blocks=None, dest_collection=None): if version == 10: - inbox_message = tm10.InboxMessage( - message_id=MESSAGE_ID, content_blocks=blocks) + inbox_message = tm10.InboxMessage(message_id=MESSAGE_ID, content_blocks=blocks) elif version == 11: - inbox_message = tm11.InboxMessage( - message_id=MESSAGE_ID, content_blocks=blocks) + inbox_message = tm11.InboxMessage(message_id=MESSAGE_ID, content_blocks=blocks) if dest_collection: inbox_message.destination_collection_names.append(dest_collection) else: diff --git a/tests/test_cli.py b/tests/test_cli.py index 0ef6eb4c..dcb40721 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -81,8 +81,9 @@ ], ) def test_sync_data_configuration(app, capsys, argv, raises, message, stdout, stderr): - with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( - "sys.argv", [""] + argv + with ( + mock.patch("opentaxii.cli.persistence.app", app), + mock.patch("sys.argv", [""] + argv), ): with conditional_raises(raises) as exception: sync_data_configuration() @@ -135,8 +136,9 @@ def test_sync_data_configuration(app, capsys, argv, raises, message, stdout, std def test_delete_content_blocks( app, collections, capsys, argv, raises, message, stdout, stderr ): - with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( - "sys.argv", [""] + argv + with ( + mock.patch("opentaxii.cli.persistence.app", app), + mock.patch("sys.argv", [""] + argv), ): with conditional_raises(raises) as exception: delete_content_blocks() @@ -377,11 +379,15 @@ def test_update_account(app, account, capsys, argv, raises, message, stdout, std def test_add_api_root( app, capsys, argv, raises, message, stdout, stderr, expected_call ): - with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( - "sys.argv", [""] + argv - ), mock.patch.object( - app.taxii_server.servers.taxii2.persistence.api, "add_api_root", autospec=True - ) as mock_add_api_root: + with ( + mock.patch("opentaxii.cli.persistence.app", app), + mock.patch("sys.argv", [""] + argv), + mock.patch.object( + app.taxii_server.servers.taxii2.persistence.api, + "add_api_root", + autospec=True, + ) as mock_add_api_root, + ): with conditional_raises(raises) as exception: add_api_root() if raises: @@ -556,11 +562,13 @@ def test_add_collection( "ROOTIDS", ",".join([api_root.id for api_root in db_api_roots]), ) - with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( - "sys.argv", [""] + argv - ), mock.patch.object( - app.taxii_server.servers.taxii2.persistence.api, "add_collection" - ) as mock_add_collection: + with ( + mock.patch("opentaxii.cli.persistence.app", app), + mock.patch("sys.argv", [""] + argv), + mock.patch.object( + app.taxii_server.servers.taxii2.persistence.api, "add_collection" + ) as mock_add_collection, + ): with conditional_raises(raises) as exception: add_collection() if raises: @@ -575,9 +583,14 @@ def test_add_collection( def test_job_cleanup(app, capsys): - with mock.patch("opentaxii.cli.persistence.app", app), mock.patch.object( - app.taxii_server.servers.taxii2.persistence.api, "job_cleanup", return_value=2 - ) as mock_cleanup: + with ( + mock.patch("opentaxii.cli.persistence.app", app), + mock.patch.object( + app.taxii_server.servers.taxii2.persistence.api, + "job_cleanup", + return_value=2, + ) as mock_cleanup, + ): job_cleanup() mock_cleanup.assert_called_once_with() captured = capsys.readouterr() diff --git a/tests/test_config.py b/tests/test_config.py index 0a085726..54095f34 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,6 +2,7 @@ import tempfile import pytest + from opentaxii.config import ServerConfig BACKWARDS_COMPAT_CONFIG = """ diff --git a/tests/test_converters.py b/tests/test_converters.py index 8541d2aa..e11cd90c 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -1,5 +1,4 @@ import pytest - from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 from libtaxii.constants import VID_TAXII_XML_10, VID_TAXII_XML_11 @@ -16,12 +15,11 @@ def test_parse_message(content_type): with pytest.raises(exceptions.BadMessageStatus): parse_message(content_type, 'invalid-body', do_validate=True) - tm = (tm10 if content_type == VID_TAXII_XML_10 else tm11) + tm = tm10 if content_type == VID_TAXII_XML_10 else tm11 parsed = parse_message( - content_type, - tm.DiscoveryRequest(MESSAGE_ID).to_xml(), - do_validate=True) + content_type, tm.DiscoveryRequest(MESSAGE_ID).to_xml(), do_validate=True + ) assert isinstance(parsed, tm.DiscoveryRequest) assert parsed.message_id == MESSAGE_ID diff --git a/tests/test_delete_content_blocks.py b/tests/test_delete_content_blocks.py index 709088e5..19aca42d 100644 --- a/tests/test_delete_content_blocks.py +++ b/tests/test_delete_content_blocks.py @@ -1,7 +1,6 @@ import datetime import pytest - from fixtures import COLLECTION_OPEN, COLLECTIONS_A diff --git a/tests/test_health.py b/tests/test_health.py index 2577e2f2..5303e0f6 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -1,6 +1,6 @@ import json -import pytest +import pytest HEALTH_PATH = '/management/health' @@ -10,10 +10,7 @@ def test_get_health(client, https): base_url = '%s://localhost' % ('https' if https else 'http') # Invalid credentials - response = client.get( - HEALTH_PATH, - base_url=base_url - ) + response = client.get(HEALTH_PATH, base_url=base_url) assert response.status_code == 200 data = json.loads(response.get_data(as_text=True)) diff --git a/tests/test_http.py b/tests/test_http.py index 674c8292..35f01b58 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,10 +1,10 @@ import pytest from libtaxii.constants import ST_BAD_MESSAGE, ST_FAILURE +from utils import as_tm, is_headers_valid, prepare_headers + from opentaxii.taxii.converters import dict_to_service_entity from opentaxii.taxii.http import HTTP_X_TAXII_SERVICES -from utils import as_tm, is_headers_valid, prepare_headers - INBOX = dict( id='inbox-A', type='inbox', @@ -14,7 +14,8 @@ accept_all_content=True, protocol_bindings=[ 'urn:taxii.mitre.org:protocol:http:1.0', - 'urn:taxii.mitre.org:protocol:https:1.0'] + 'urn:taxii.mitre.org:protocol:https:1.0', + ], ) DISCOVERY = dict( @@ -23,7 +24,7 @@ description='discoveryA description', address='/relative/discovery', advertised_services=['inbox-A', 'discovery-A', 'discovery-B'], - protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'] + protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'], ) DISCOVERY_NOT_AVAILABLE = dict( @@ -33,7 +34,7 @@ address='/relative/discovery-b', advertised_services=['inbox-A', 'discovery-A'], protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'], - available=False + available=False, ) SERVICES = [INBOX, DISCOVERY, DISCOVERY_NOT_AVAILABLE] @@ -44,7 +45,9 @@ @pytest.fixture(autouse=True) def local_services(server): for service in SERVICES: - server.servers.taxii1.persistence.update_service(dict_to_service_entity(service)) + server.servers.taxii1.persistence.update_service( + dict_to_service_entity(service) + ) def test_root_get(client): @@ -71,7 +74,7 @@ def test_status_message_response(client, version, https): INBOX['address'], data='invalid-body', headers=prepare_headers(version, https), - base_url=base_url + base_url=base_url, ) assert response.status_code == 200 @@ -93,7 +96,7 @@ def test_successful_response(client, version, https): DISCOVERY['address'], data=request.to_xml(), headers=prepare_headers(version=version, https=https), - base_url=base_url + base_url=base_url, ) assert response.status_code == 200 @@ -116,10 +119,7 @@ def test_post_parse_verification(client, version, https): base_url = '%s://localhost' % ('https' if https else 'http') response = client.post( - DISCOVERY['address'], - data=request.to_xml(), - headers=headers, - base_url=base_url + DISCOVERY['address'], data=request.to_xml(), headers=headers, base_url=base_url ) assert response.status_code == 200 @@ -144,7 +144,7 @@ def test_services_available(client, version, https): DISCOVERY_NOT_AVAILABLE['address'], data=request.to_xml(), headers=headers, - base_url=base_url + base_url=base_url, ) assert response.status_code == 200 diff --git a/tests/test_options.py b/tests/test_options.py index 7f068da5..9a34bdfe 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -1,4 +1,5 @@ import pytest + from opentaxii.taxii.converters import dict_to_service_entity from opentaxii.taxii.http import HTTP_X_TAXII_CONTENT_TYPES @@ -16,7 +17,9 @@ @pytest.fixture(autouse=True) def local_services(server): for service in [DISCOVERY]: - server.servers.taxii1.persistence.update_service(dict_to_service_entity(service)) + server.servers.taxii1.persistence.update_service( + dict_to_service_entity(service) + ) @pytest.mark.parametrize("https", [True, False]) @@ -25,10 +28,7 @@ def test_options_request(server, client, version, https): base_url = '%s://localhost' % ('https' if https else 'http') - response = client.options( - DISCOVERY['address'], - base_url=base_url - ) + response = client.options(DISCOVERY['address'], base_url=base_url) assert response.status_code == 200 assert HTTP_X_TAXII_CONTENT_TYPES in response.headers diff --git a/tests/test_server.py b/tests/test_server.py index 9fc12806..31095456 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,14 +1,13 @@ import concurrent.futures import pytest -from opentaxii.persistence import (OpenTAXII2PersistenceAPI, - Taxii2PersistenceManager) +from fixtures import DOMAIN + +from opentaxii.persistence import OpenTAXII2PersistenceAPI, Taxii2PersistenceManager from opentaxii.persistence.sqldb import Taxii2SQLDatabaseAPI from opentaxii.server import TAXII2Server from opentaxii.taxii.converters import dict_to_service_entity -from fixtures import DOMAIN - INBOX = dict( id='inbox-A', type='inbox', @@ -18,7 +17,8 @@ accept_all_content='yes', protocol_bindings=[ 'urn:taxii.mitre.org:protocol:http:1.0', - 'urn:taxii.mitre.org:protocol:https:1.0'], + 'urn:taxii.mitre.org:protocol:https:1.0', + ], ) DISCOVERY = dict( @@ -27,7 +27,7 @@ description='discoveryA description', address='/relative/discovery', advertised_services=['inboxA', 'discoveryA'], - protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'] + protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'], ) DISCOVERY_EXTERNAL = dict( @@ -35,7 +35,7 @@ type='discovery', description='discoveryB description', address='http://example.com/a/b/c', - protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'] + protocol_bindings=['urn:taxii.mitre.org:protocol:http:1.0'], ) INTERNAL_SERVICES = [INBOX, DISCOVERY] @@ -45,19 +45,18 @@ @pytest.fixture() def local_services(server): for service in SERVICES: - server.servers.taxii1.persistence.update_service(dict_to_service_entity(service)) + server.servers.taxii1.persistence.update_service( + dict_to_service_entity(service) + ) def test_services_configured(server, local_services): assert len(server.servers.taxii1.get_services()) == len(SERVICES) - with_paths = [ - s for s in server.servers.taxii1.get_services() - if s.path] + with_paths = [s for s in server.servers.taxii1.get_services() if s.path] assert len(with_paths) == len(INTERNAL_SERVICES) - assert all([ - p.address.startswith(DOMAIN) for p in with_paths]) + assert all([p.address.startswith(DOMAIN) for p in with_paths]) def test_taxii2_configured(server): diff --git a/tests/utils.py b/tests/utils.py index 98ff5b04..34b639e8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,18 +1,21 @@ import re import pytest +from fixtures import CB_STIX_XML_111, CONTENT, MESSAGE, MESSAGE_ID from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 + from opentaxii.taxii import entities -from opentaxii.taxii.http import (HTTP_ACCEPT, HTTP_CONTENT_XML, - TAXII_10_HTTP_HEADERS, - TAXII_10_HTTPS_HEADERS, - TAXII_11_HTTP_HEADERS, - TAXII_11_HTTPS_HEADERS) +from opentaxii.taxii.http import ( + HTTP_ACCEPT, + HTTP_CONTENT_XML, + TAXII_10_HTTP_HEADERS, + TAXII_10_HTTPS_HEADERS, + TAXII_11_HTTP_HEADERS, + TAXII_11_HTTPS_HEADERS, +) from opentaxii.taxii.utils import get_utc_now -from fixtures import CB_STIX_XML_111, CONTENT, MESSAGE, MESSAGE_ID - JWT_RE = re.compile(r'[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*') @@ -164,4 +167,5 @@ def assert_str_equal_no_formatting(str1, str2): class SKIP: """Used as signalling value to skip check""" + pass